
Comments (8)

Xingzhi107 commented on September 15, 2024

> > It's not clear to me what "input and output that need to be constrained" means
> > The input and output features have to be of the same length for the experts.
>
> This means the feature dimension of input / output tensors of experts should be equal. I think your code fulfills this requirement.
>
> I am not able to reproduce your issue with your code provided using randn or ones as input data. Can you please provide more information about the error? E.g. the shape of x that you print in your expert module. You may also add -g flag to cxx_flags in setup.py, recompile fastmoe, and run the program with CUDA_LAUNCH_BLOCKING=1 to see which line of cuda code gives this error.
>
> Also, you can try turning off some features, for example FMOE_FASTER_SHADOW_ENABLE=0 or FMOE_FASTER_GROUP_SIZE=4, and see if any of these changes can bypass the error. If so, we will be able to further inspect specific functions.

Thank you for your reply!
I suspected that my expert module was not written correctly, so I switched to fastmoe's LinearExpert, but the error is the same.
I use FMoE inside a transformer, and my code is as follows:
class TorchTransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = TorchAttention(args)
        self.feed_forward = TorchFFN(
            dim=args.dim,
            hidden_dim=4 * args.dim,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.
        """
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        norm = self.ffn_norm(h)  # ([5, 1, 1024])
        print(norm.shape)
        out = h + self.feed_forward.forward(norm)
        return out

class MoETorchTransformerBlock(TorchTransformerBlock):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)

        self.attention = TorchAttention(args)
        assert args.moe["num_experts"] % args.num_gpus == 0, "num_experts must be divisible by num_gpus"
        # print(int(os.environ['WORLD_SIZE']))
        self.feed_forward = FastMoe(
            num_expert=args.moe["num_experts"],
            d_model=args.dim,
            d_hidden=args.hidden_dim,
            activation=torch.nn.SiLU(),
            world_size=int(os.environ['WORLD_SIZE']),
            top_k=args.moe["num_experts_per_tok"],
        )

class FastMoe(FMoE):
    def __init__(self,
                 num_expert=4,
                 d_model=1024,
                 d_hidden=4096,
                 activation=torch.nn.SiLU(),
                 world_size=1,
                 top_k=2,
                 # moe_group = 1,
                 ):
        # def one_expert(d_model):
        #     return Expert(d_model, d_hidden)
        # expert = one_expert
        super().__init__(num_expert, d_model, world_size,
                         top_k=top_k, expert=LinearExpert)
        # self.mark_parallel_comm()

    def forward(self, inp: torch.tensor):
        original_shape = inp.shape
        # print("original_shape:", original_shape)  # [bsz, seq, d]
        inp = inp.reshape(-1, self.d_model)  # [bsz*seq, d]
        output = super().forward(inp)

        return output.reshape(original_shape)

When I don't turn on smart schedule, no error occurs, but when I add FMOE_FASTER_SCHEDULE_ENABLE=1 I get:

attention: torch.Size([5, 1, 1024]) torch.Size([5, 1, 1024])
attention: torch.Size([5, 1, 1024]) torch.Size([5, 1, 1024])
[ubuntu:2697 :0:2697] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7f817590a600)
[ubuntu:2698 :0:2698] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7fee15d0a600)
==== backtrace (tid: 2697) ====

The attention output shape comes from output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1).

torchrun --standalone --nproc_per_node=4 tools/example.py -m ./ckpts -t ckpts/tokenizer.model passes, while
FMOE_FASTER_SCHEDULE_ENABLE=1 torchrun --standalone --nproc_per_node=4 tools/example.py -m ./ckpts -t ckpts/tokenizer.model fails.


Xingzhi107 commented on September 15, 2024

> stored_models_ is the output of the policy_fn, which should be a boolean tensor on CPU. As you report that you cannot access its value, I am wondering if you are setting the default device of PyTorch to GPU, which may be problematic in current FastMoE.

Yes, it was indeed for this reason. Thank you very much for your help!


laekov commented on September 15, 2024

> It's not clear to me what "input and output that need to be constrained" means
> The input and output features have to be of the same length for the experts.

This means the feature dimension of input / output tensors of experts should be equal. I think your code fulfills this requirement.
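
In other words (a minimal sketch with hypothetical sizes, not your exact module): whatever hidden width an expert uses internally, it must map its input feature dimension back to the same output feature dimension.

import torch.nn as nn

d_model = 1024          # feature dimension seen by the MoE layer
# An expert may use any hidden width, but must map d_model -> d_model:
expert = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.SiLU(),
    nn.Linear(4 * d_model, d_model),   # output features == input features
)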

I am not able to reproduce your issue with your code provided using randn or ones as input data. Can you please provide more information about the error? E.g. the shape of x that you print in your expert module. You may also add -g flag to cxx_flags in setup.py, recompile fastmoe, and run the program with CUDA_LAUNCH_BLOCKING=1 to see which line of cuda code gives this error.

Also, you can try turning off some features, for example FMOE_FASTER_SHADOW_ENABLE=0 or FMOE_FASTER_GROUP_SIZE=4, and see if any of these changes can bypass the error. If so, we will be able to further inspect specific functions.
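
If it is more convenient than exporting the variables in the shell, one possible way (just a sketch, assuming the variables are set before FastMoE reads them) is to assign them at the very top of the training script:

import os

# Toggle smart-schedule sub-features before fastmoe is imported / initialized.
os.environ["FMOE_FASTER_SHADOW_ENABLE"] = "0"  # disable expert shadowing
os.environ["FMOE_FASTER_GROUP_SIZE"] = "4"     # set the smart-schedule group size to 4
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"       # make CUDA errors synchronous for debugging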


Xingzhi107 commented on September 15, 2024

I'm sorry, I have one more question I would like to ask.
I am also puzzled by the global_policy function:

def global_policy(local_expert_count, _gec, num_expert, world_size):
    r"""
    This is the policy for two-layer MLPs, using the formula in the PPoPP paper.
    A few parameters are used in this policy.
    * `d_model`: feature length of the MLP input and output.
    * `alpha`: the ratio of the MLP's hidden size to `d_model`.
    * `bw_net`: bandwidth of the network (GBps)
    * `bw_mm`: computation throughput of performing GeMM (FLOPs)
    """
    bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)
    bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12)
    alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2)
    d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048)

    moe_group = get_moe_group()
    local_expert_count = local_expert_count.cuda()
    agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)]
    dist.all_gather(agecs, local_expert_count, group=moe_group)
    all_global_expert_count = torch.stack(agecs)

    # TODO: data type other than float
    data_size = 4

    fwd_expert_counts = all_global_expert_count.sum(1).cpu()
    B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True)

Should local_expert_count be the same or different on each card (world_size)? I found that after the gather, the result on every rank was exactly the same, so after the sum the resulting res was an all-false tensor.
If local_expert_count should be different, maybe I am confusing some concepts and that causes the miscalculation; local_expert_count is calculated here: fmoe_cuda.expert_count(gate, local_expert_count).
In addition, why would an all-false res tensor lead to the segmentation fault?
Can you give me some advice? Thanks.


laekov commented on September 15, 2024

> Should local_expert_count be the same or different on each card (world_size)?

local_expert_count differs on each GPU, because it counts how many samples in the local batch go to each expert.
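
For intuition, here is a rough Python equivalent (just a sketch, not the actual fmoe_cuda.expert_count kernel), assuming gate_idx holds the flattened top-k expert indices chosen for the local batch on one rank:

import torch

num_expert, world_size = 1, 4               # hypothetical setup: 1 expert per GPU, 4 GPUs
gate_idx = torch.tensor([0, 2, 2, 3, 1])    # local tokens routed to global experts 0..3

# For each global expert, count how many local tokens are sent to it.
local_expert_count = torch.bincount(gate_idx, minlength=num_expert * world_size)
print(local_expert_count)  # tensor([1, 1, 2, 1]); differs across ranks when the inputs differ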

> In addition, why would an all-false res tensor lead to the segmentation fault? Can you give me some advice? Thanks!

res in this function indicates which experts are to be shadowed. All false in res means that no expert is shadowed, which is a common case when the workload is relatively balanced across the experts. I do not think this can lead to a seg fault.


Xingzhi107 commented on September 15, 2024

> > Should local_expert_count be the same or different on each card (world_size)?
>
> local_expert_count differs on each GPU, because it counts how many samples in the local batch go to each expert.
>
> > In addition, why would an all-false res tensor lead to the segmentation fault? Can you give me some advice? Thanks!
>
> res in this function indicates which experts are to be shadowed. All false in res means that no expert is shadowed, which is a common case when the workload is relatively balanced across the experts. I do not think this can lead to a seg fault.
Thank you very much for your answer. However, I get an error when I access stored_models_[i]: there is no way to read its value, although I can print its size, and I don't know what went wrong. Its size is always 8, and my stored_models is an all-false tensor of size num_expert * world_size.

    std::vector<torch::Tensor> params;
    auto stored_models_ = stored_models.data_ptr<bool>();
    for (long i = 0; i < num_expert * n_workers; ++i) {
        if (stored_models_[i]) {
            torch::Tensor t = input_buf.new_empty({expert_size});
            if (i / num_expert == rank) {
                get_param_fn(t, i % num_expert);
            }
            params.push_back(t);
        }
    }

In addition, local_expert_count is calculated by a function inside FMoE. Is it because my use of FMoE is written incorrectly that every local_expert_count ends up the same? My num_expert is set to 1, world_size is set from os.environ['WORLD_SIZE'], my nnode is 1, and nproc_per_node=4.

class Expert(nn.Module):
    def __init__(
        self,
        d_model, d_hidden,
        rank = 0,
    ):
        super().__init__()

        self.w1 = nn.Linear(
            d_model, d_hidden, bias=False
        )
        self.w2 = nn.Linear(
            d_hidden, d_model, bias=False
        )
        self.w3 = nn.Linear(
            d_model, d_hidden, bias=False
        )

    def forward(self, x, fec=None):
        # device = x.device
        # x = x.to(self.w1.weight.device)
        out = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # print(out.shape)
        return out

class FastMoe(FMoE):
    def __init__(self,
                 num_expert=4,
                 d_model = 1024,
                 d_hidden=4096,
                 activation=torch.nn.SiLU(),
                 world_size =1,
                 top_k = 2,
                 # moe_group = 1,
        ):
        def one_expert(d_model):
            return Expert( d_model,d_hidden)
        expert = one_expert
        super().__init__(num_expert, d_model, world_size,
                         top_k=top_k,expert=expert,gate=NaiveGate)
        self.mark_parallel_comm("dp")

    def forward(self, inp: torch.tensor):
        original_shape = inp.shape
        #print("original_shape:",original_shape) #[bsz,seq,d]
        inp = inp.reshape(-1, self.d_model) #[bsz*seq,d]


        # pdb.set_trace()
        output = super().forward(inp)

        return output.reshape(original_shape)

Thank you very much for your guidance again


laekov commented on September 15, 2024

> but I get an error when I access stored_models_[i]: there is no way to read its value, although I can print its size

stored_models_ is the output of the policy_fn, which should be a boolean tensor on CPU. As you report that you cannot access its value, I am wondering if you are setting the default device of PyTorch to GPU, which may be problematic in current FastMoE.
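
To illustrate the pitfall (a sketch, not FastMoE code, assuming PyTorch 2.x's torch.set_default_device; an old-style CUDA default tensor type has the same effect): with a CUDA default device, a tensor allocated without an explicit device silently lands on the GPU, and host-side code that dereferences its data pointer then faults with exactly this kind of "invalid permissions" segfault.

import torch

policy = torch.zeros(8, dtype=torch.bool)    # default device is CPU: host code can read it
print(policy.is_cuda)                        # False

torch.set_default_device("cuda")             # e.g. set somewhere early in the training script
policy = torch.zeros(8, dtype=torch.bool)    # now allocated on the GPU
print(policy.is_cuda)                        # True -> a host-side data_ptr<bool>() dereference faults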

> In addition, local_expert_count is calculated by a function inside FMoE. Is it because my use of FMoE is written incorrectly that every local_expert_count ends up the same? My num_expert is set to 1, world_size is set from os.environ['WORLD_SIZE'], my nnode is 1, and nproc_per_node=4.

local_expert_count being the same is not unusual if your input on each GPU is the same. You may inspect the output of the gate module and see if they select the same experts.
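
One way to check (a minimal sketch with hypothetical names, assuming a NaiveGate-style gate that returns top-k expert indices and scores, run inside an initialized torch.distributed process group) is to print what each rank selects:

import torch.distributed as dist

def inspect_gate(gate, inp):
    # NaiveGate-style gates return (top-k expert indices, gate scores).
    topk_idx, topk_score = gate(inp)
    print(f"rank {dist.get_rank()}: selected experts {topk_idx.flatten().tolist()}")

If every rank prints the same indices even though the inputs are supposed to differ, the data is probably being replicated rather than sharded across the GPUs.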

world_size should be equal to the number of GPUs you use. So, in your case, world_size=4 should be correct.


laekov commented on September 15, 2024

Well, thank you very much for reporting and debugging this issue. I think we should explicitly specify the device of tensors when we allocate them in our library. We will update the codebase before closing this issue.
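
The general idea (a sketch of the intent, not the actual patch) is to pin such bookkeeping tensors to the CPU explicitly instead of relying on the default device:

import torch

num_expert, world_size = 1, 4   # hypothetical sizes
# Allocate the shadowing decision explicitly on the CPU so that host-side code
# can read it even if the user script changes torch's default device to CUDA.
res = torch.zeros(num_expert * world_size, dtype=torch.bool, device="cpu")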

