
Comments (15)

Sakits commented on July 20, 2024

Hi all, thank you for your patience! We have added support for Bloom in the dev/more_models branch. You are welcome to try it out, and feel free to raise any issues you encounter.

It's worth noting that when scaling out_proj, we should only scale the v part of the attention. However, Bloom fuses q, k, and v into a single Linear layer, and the fusion is carried out by concatenating q, k, and v within each attention head and then concatenating all the heads. This layout is different from MPT's.

This makes the implementation somewhat complex. Based on our tests, the performance degradation caused by not scaling out_proj in Bloom models is relatively minor, so we decided to skip scaling this layer to keep the code simple.
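
To make the layout concrete, here is a small, hypothetical sketch (not from the repo) of where the v rows sit under Bloom's per-head fusion, compared with a plain [q; k; v] concatenation; the sizes are made up for illustration:

import torch

num_heads, head_dim = 4, 16                       # illustrative sizes
hidden = num_heads * head_dim
qkv_out = 3 * hidden                              # output features of the fused qkv Linear

# Row indices of the v part under Bloom-style per-head fusion
# (q_head0, k_head0, v_head0, q_head1, k_head1, v_head1, ...):
rows = torch.arange(qkv_out).reshape(num_heads, 3, head_dim)
v_rows_bloom = rows[:, 2, :].reshape(-1)

# Row indices of the v part if q, k, v were simply concatenated as whole blocks:
v_rows_concat = torch.arange(2 * hidden, 3 * hidden)

print(torch.equal(v_rows_bloom, v_rows_concat))   # False: the v rows are interleaved per head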

If you want to try scaling out_proj in Bloom models, you can replace the corresponding two blocks of code in auto_scale.py with the following.

@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales, num_heads=None):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)
    # Fused-qkv case (e.g. Bloom's query_key_value): only the v rows should be scaled.
    if fc1.out_features == fc2.in_features * 3:
        # Bloom lays out the fused output dim per head as (q, k, v) blocks of head_dim,
        # so transpose and view that dim as (..., 3, head_dim) to reach the v slice of every head.
        fc1.weight.t_()
        org_shape = fc1.weight.shape
        fc1.weight.data = fc1.weight.data.reshape(org_shape[0] * num_heads, 3, -1)
        value = fc1.weight.data[:, 2, :].reshape(org_shape[0], -1)
        fc1.weight.data[:, 2, :] = value.div(scales.view(-1)).reshape(fc1.weight[:, 2, :].shape)
        fc1.weight.data = fc1.weight.data.reshape(org_shape).t_()
        
        if fc1.bias is not None:
            fc1.bias.data = fc1.bias.data.reshape(num_heads, 3, -1)
            value = fc1.bias.data[:, 2, :].reshape(-1)
            fc1.bias.data[:, 2, :] = value.div(scales.view(-1)).reshape(fc1.bias[:, 2, :].shape)
            fc1.bias.data = fc1.bias.data.reshape(-1)

    else:
        assert fc1.out_features == fc2.in_features
        
        fc1.weight.div_(scales.view(-1, 1))
        if fc1.bias is not None:
            fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0

def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]
        
        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales, module.num_heads if hasattr(module, 'num_heads') else None)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")
            
        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:  
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))
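
The snippet above relies on helpers that already live in the repo (get_op_by_name, set_op_by_name, scale_ln_fcs, scale_gelu_fc, ScaledActivation). For readers following along, ScaledActivation and scale_gelu_fc look roughly like the sketch below; treat it as an approximation of the upstream definitions rather than the exact code:

class ScaledActivation(nn.Module):
    # Wraps an activation and divides its output by per-channel scales,
    # acting as the "extra scaling node" in front of the next Linear layer.
    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    # The activation output is divided by `scales` (via ScaledActivation above),
    # so the following Linear layer's input channels are multiplied back to compensate.
    assert isinstance(fc, nn.Linear)
    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0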

from llm-awq.

tonylins commented on July 20, 2024

Please refer to awq/quantize/pre_quant.py and awq/quantize/auto_scale.py to see how we support different LLM architectures. We will probably integrate BLOOM support soon.

from llm-awq.

Tracin commented on July 20, 2024

Please refer to awq/quantize/pre_quant.py and awq/quantize/auto_scale.py to see how we support different LLM architectures. We will probably integrate BLOOM support soon.

I tried, and I got very bad results. I figured out there is something wrong with the scales of self_attention.dense and mlp.4h_to_h, but I don't know why. Can I get any help?

from llm-awq.

tonylins commented on July 20, 2024

Hi, I think there might be some bug in the implementation that leads to low performance. There are a few potential issues:

  • Firstly, the BLOOM FFN uses a GeLU activation function, so we cannot fuse the scaling of fc2 into the activation function (we may need to add an extra scaling node, which can be fused on the kernel implementation side; see the sketch after this list).

  • Secondly, BLOOM uses a fused qkv linear, so when you scale out_proj in the attention, you should only scale the v part.
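
For the first point, a minimal, hypothetical PyTorch check (not from the repo) shows why the scaling cannot be folded through a GeLU, which is why an explicit scaling node after the activation is needed:

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, ffn = 8, 32                           # illustrative sizes
x = torch.randn(2, hidden)
fc1 = nn.Linear(hidden, ffn)
act = nn.GELU()
scales = torch.rand(ffn) + 0.5                # per-channel scales for fc2's input

# AWQ wants fc2's input divided by `scales` (fc2's weight columns are multiplied back).
target = act(fc1(x)) / scales

# Attempt to fold the division into fc1 instead (valid only for a linear op):
fc1.weight.data.div_(scales.view(-1, 1))
fc1.bias.data.div_(scales)
print(torch.allclose(act(fc1(x)), target))    # False: GeLU(x / s) != GeLU(x) / s
# Hence the division must remain as an explicit scaling node after the activation,
# which the kernel implementation can fuse.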

from llm-awq.

Tracin commented on July 20, 2024

Hi, I think there might be some bug in the implementation that leads to low performance. There are a few potential issues:

  • Firstly, the BLOOM FFN uses a GeLU activation function, so we cannot fuse the scaling of fc2 into the activation function (we may need to add an extra scaling node, which can be fused on the kernel implementation side).

  • Secondly, BLOOM uses a fused qkv linear, so when you scale out_proj in the attention, you should only scale the v part.

Thanks for your reply. You are right! After I fixed these two issues (only scaling v in attention and not scaling 4h_to_h), the results seem much more reasonable, but PPL is still about twice as large as FP16 (70 for bloom560M and 24 for bloom3b). I have to keep digging...

from llm-awq.

songkq commented on July 20, 2024

@Tracin @tonylins Hi, I met the same problem. I'm confused about how to only scale the v part of a fused qkv linear.
I guess it works like this, right?

# def scale_fc_fc(fc1, fc2, scales):
if fc1.out_features == 3 * fc2.in_features:
    fc1.weight[-fc2.in_features:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias[-fc2.in_features:].div_(scales.view(-1))

As for fc2, how do I add an extra scaling node? Could you please give some advice?

from llm-awq.

Tracin commented on July 20, 2024

@Tracin @tonylins Hi, I met the same problem. I'm confused about how to only scale the v part of a fused qkv linear. I guess it works like this, right?

# def scale_fc_fc(fc1, fc2, scales):
if fc1.out_features == 3 * fc2.in_features:
    fc1.weight[-fc2.in_features:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias[-fc2.in_features:].div_(scales.view(-1))

As for fc2, how do I add an extra scaling node? Could you please give some advice?

Yes, it is exactly what I wrote, and I simply skip fc2.

from llm-awq.

songkq commented on July 20, 2024

@Tracin Thanks. When skipping fc2, the quantized model has poor performance and generates lots of repetitive and meaningless answers. @tonylins Could you please give some advice on this issue?

from llm-awq.

moonlightian commented on July 20, 2024

Interested in support for Bloom, and looking forward to further AWQ improvements for Bloom models~

from llm-awq.

moonlightian commented on July 20, 2024

@Tracin Thanks. When skipping fc2, the quantized model has poor performance and generates lots of repetitive and meaningless answers. @tonylins Could you please give some advice on this issue?

Hi, I would like to quantize Bloom and met the same problem. Could you please give some advice on organizing the scales_list for a BloomBlock, which is built with the _auto_get_scale() function in auto_scale.py? Thanks a lot!
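
For what it's worth, one possible arrangement is sketched below. It assumes the standard transformers BloomBlock module names (input_layernorm, self_attention.query_key_value, self_attention.dense, post_attention_layernorm, mlp.dense_h_to_4h, mlp.gelu_impl, mlp.dense_4h_to_h) and the _auto_get_scale(prev_op, layers, inp, ...) helper used for the other architectures in auto_scale.py; treat it as a starting point, not the official support:

# Hypothetical BloomBlock branch inside auto_scale_block() in auto_scale.py
scales_list = [
    # attention input: layernorm -> fused qkv
    _auto_get_scale(
        prev_op=module.input_layernorm,
        layers=[module.self_attention.query_key_value],
        inp=input_feat['self_attention.query_key_value'],
        module2inspect=module, kwargs=module_kwargs,
    ),
    # attention output: fused qkv -> dense (only the v part should be scaled, as discussed above)
    _auto_get_scale(
        prev_op=module.self_attention.query_key_value,
        layers=[module.self_attention.dense],
        inp=input_feat['self_attention.dense'],
    ),
    # MLP input: layernorm -> dense_h_to_4h
    _auto_get_scale(
        prev_op=module.post_attention_layernorm,
        layers=[module.mlp.dense_h_to_4h],
        inp=input_feat['mlp.dense_h_to_4h'],
        module2inspect=module, kwargs=module_kwargs,
    ),
    # MLP output: GELU -> dense_4h_to_h (needs the extra scaling node discussed above)
    _auto_get_scale(
        prev_op=module.mlp.gelu_impl,
        layers=[module.mlp.dense_4h_to_h],
        inp=input_feat['mlp.dense_4h_to_h'],
    ),
]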

from llm-awq.

Niko-zyf commented on July 20, 2024

Hello, I tried to reproduce the code on the Bloom model, but the accuracy drops to zero on Lambada. I'm wondering when the official support for Bloom will be available.

from llm-awq.

tonylins commented on July 20, 2024

Hi all, thanks for the interest in our work. We should be able to add support for BLOOM this week. Please stay tuned.

from llm-awq.

shaochangxu commented on July 20, 2024

@Tracin Hi, could you please tell me why it works for LLaMA, which also has SiLU in its MLP, but does not work for GeLU?

from llm-awq.

tonylins commented on July 20, 2024

Hi @shaochangxu, the LLaMA model uses the SwiGLU activation function, which is y = x1 * act(x2), so we can fuse the scaling into x1. For the GeLU activation y = act(x), it is hard to fuse the scaling when act is not linear.
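
A minimal, hypothetical check (not from the repo) illustrating the difference: with a gated MLP y = act(gate(x)) * up(x), the per-channel scaling of the down projection's input can be folded exactly into up's weights, because that branch is linear; a plain GeLU MLP has no such branch.

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, inter = 8, 32                                 # illustrative sizes
x = torch.randn(2, hidden)
gate = nn.Linear(hidden, inter, bias=False)           # produces x2 (goes through the activation)
up = nn.Linear(hidden, inter, bias=False)             # produces x1 (the linear branch)
act = nn.SiLU()
scales = torch.rand(inter) + 0.5

# Desired: the down projection's input divided by `scales`.
target = act(gate(x)) * up(x) / scales

# Fold the division into the x1 branch by scaling up's output channels:
up.weight.data.div_(scales.view(-1, 1))
print(torch.allclose(act(gate(x)) * up(x), target))   # True: the fusion is exact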

from llm-awq.

shaochangxu commented on July 20, 2024

Hi @shaochangxu, the LLaMA model uses the SwiGLU activation function, which is y = x1 * act(x2), so we can fuse the scaling into x1. For the GeLU activation y = act(x), it is hard to fuse the scaling when act is not linear.

I see! Thanks for your reply!

from llm-awq.
