
Comments (12)

awni commented on July 23, 2024

In my work porting the Phi-3-vision model to MLX

PS that is a very cool project! Is it functional? Do you mind if I share it more broadly?

from mlx-examples.

JosefAlbers commented on July 23, 2024

Can you please help me figure out where I can fit this into the library code? I have seen your code, but I want to use the original library and fit this code into it. Can you please help? Thanks...

@mustangs0786, I'm currently working on integrating su-RoPE scaling directly into the Phi-3 model and plan to submit a pull request (PR) soon. In the meantime, you can try this temporary workaround in the Attention module's __init__ method:

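import mlx.nn as nn

class Attention(nn.Module):
    def __init__(self, args):
        # ... 
        if args.rope_scaling is None:
            # No rope_scaling in the config: fall back to plain RoPE
            self.rope = nn.RoPE(
                args.head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
            )
        elif args.rope_scaling["type"] == "linear":
            rope_scale = 1 / args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                args.head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )
        elif args.rope_scaling["type"] == "su":
            # Phi3SuScaledRotaryEmbedding must be defined elsewhere in the module
            self.rope = Phi3SuScaledRotaryEmbedding(args.head_dim, args)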
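The following is a small pure-Python sketch of what the `linear` branch does. It assumes, based on the MLX docs, that `nn.RoPE`'s `scale` multiplies the positions before the rotation angles are computed. Under that assumption, passing `scale = 1 / factor` with a factor of 2 makes position 8 rotate exactly like position 4 without scaling. `rope_angles` is a hypothetical helper for illustration and is not part of MLX.

```python
def rope_angles(position, dim, base=10000.0, scale=1.0):
    """Hypothetical helper: RoPE rotation angle for each of the dim // 2 pairs.

    Assumes linear scaling multiplies the position, i.e. scale = 1 / factor.
    """
    if dim % 2:
        raise ValueError("dim must be even")
    scaled_pos = scale * position
    # pair i rotates with frequency base ** (-2i / dim)
    return [scaled_pos * base ** (-2 * i / dim) for i in range(dim // 2)]


factor = 2.0
print(rope_angles(8, 8, scale=1 / factor))  # same angles as the line below
print(rope_angles(4, 8))
```

This compresses long positions into the range the model saw during training, at the cost of finer positional resolution. That trade-off is why su-RoPE instead rescales each frequency pair separately.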

Sorry for the delayed response. I definitely think we should add this to the model files as they are presumably incorrect now for long context. Would you mind sending a PR, @JosefAlbers?

@awni My pleasure, I'll begin working on the PR shortly.

At the moment, running very long contexts might hit memory limitations, though I'm hopeful our forthcoming fused attention will help there.

That would be fantastic.

PS that is a very cool project!

Wow, thank you!

Is it functional?

At this point the project is functional for several key tasks, including image captioning, batched generation, LoRA training, and model/cache quantization. You can find more details in my README.md.

Do you mind if I share it more broadly?

That would be very kind of you, thank you so much!


mustangs0786 commented on July 23, 2024

Can you please help me figure out where I can fit this into the library code? I have seen your code, but I want to use the original library and fit this code into it. Can you please help? Thanks...


awni commented on July 23, 2024

Sorry for the delayed response. I definitely think we should add this to the model files as they are presumably incorrect now for long context. Would you mind sending a PR, @JosefAlbers?

At the moment, running very long contexts might hit memory limitations, though I'm hopeful our forthcoming fused attention will help there.
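A fused attention kernel can help with long contexts because it never has to materialize the full row of attention scores. One common way to achieve this is an online softmax, as used by FlashAttention-style kernels. We are not claiming this is how MLX's kernel is written. The idea is to stream over the keys in blocks while keeping a running max and normaliser, and to rescale the partial output whenever the max grows. Below is a minimal pure-Python sketch for a single query vector, checked against the naive version.

```python
import math

def naive_attention(q, keys, values):
    """Reference: compute every score, softmax them, then mix the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(w[t] * values[t][d] for t in range(len(keys))) / z
            for d in range(len(values[0]))]

def streaming_attention(q, keys, values, block=2):
    """Online-softmax attention: visits keys block by block, never holding all scores."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in ks]
        new_max = max(running_max, max(scores))
        # rescale what was accumulated under the old max (exp(-inf) == 0.0 on the first block)
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, vs):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]
```

In a real kernel, the block loop runs on-chip. The number of scores held at any time therefore scales with the block size rather than with the context length.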


JosefAlbers commented on July 23, 2024

Oh, and the su-RoPE implementation is a bit different from the version I originally posted last week. It now looks like this:

import math
import mlx.core as mx

class Phi3SuScaledRotaryEmbedding:
    def __init__(self, dim, config, **kwargs):
        self.inv_freq_short = 1.0 / (mx.array(config.rope_scaling["short_factor"], dtype=mx.float32) * config.rope_theta**(mx.arange(0, dim, 2, dtype=mx.float32) / dim))
        self.inv_freq_long = 1.0 / (mx.array(config.rope_scaling["long_factor"], dtype=mx.float32) * config.rope_theta**(mx.arange(0, dim, 2, dtype=mx.float32) / dim))
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.scaling_factor = math.sqrt(1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings) / math.log(config.original_max_position_embeddings))

    def _get_cos_sin(self, offset, L, pids):
        def _get_pids(offset, L, pids):
            if offset < 1:
                return pids
            return pids[:, -1][:, None] + offset - pids.shape[1] + 2 + mx.arange(L)[None, :]
        position_ids = mx.arange(offset, offset+L, dtype=mx.float32)[None] if pids is None else _get_pids(offset, L, pids)
        inv_freq = self.inv_freq_long if position_ids.max()+1 > self.original_max_position_embeddings else self.inv_freq_short
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        position_ids_expanded = position_ids[:, None, :]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)  
        emb = mx.concatenate([freqs, freqs], axis=-1)  
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1) 

    def __call__(self, q, k=None, offset=0, pids=None):
        def _rotate_half(x):
            midpoint = x.shape[-1] // 2  
            x1, x2 = x[..., :midpoint], x[..., midpoint:]  
            return mx.concatenate([-x2, x1], axis = -1) 
        cos, sin = self._get_cos_sin(offset, q.shape[2], pids)
        # Rotate q, and also k when it is given
        q_rot = (q * cos) + (_rotate_half(q) * sin)
        if k is None:
            return q_rot
        return q_rot, (k * cos) + (_rotate_half(k) * sin)
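The class above combines three pieces of math. Here is a dependency-free sketch of each one:

* the per-pair inverse frequencies, divided by the short or long rescale factors;
* the `sqrt(1 + log(scale) / log(original_max))` magnitude factor applied to cos/sin;
* the rotate-half application.

The function names are hypothetical. The switch to the long factors once the sequence length exceeds `original_max_position_embeddings` mirrors the `position_ids.max()+1 > ...` check in the class. The `scale <= 1.0` guard is an addition here, so that configs without an extended context get a factor of 1.

```python
import math

def su_inv_freq(dim, base, short_factor, long_factor, seq_len, original_max):
    """Inverse frequencies, each divided by its su-RoPE rescale factor."""
    factors = long_factor if seq_len > original_max else short_factor
    if len(factors) != dim // 2:
        raise ValueError("need one rescale factor per frequency pair")
    return [1.0 / (factors[i] * base ** (2 * i / dim)) for i in range(dim // 2)]

def su_scaling_factor(max_pos, original_max):
    """Magnitude correction applied to cos/sin when the context is extended."""
    scale = max_pos / original_max
    if scale <= 1.0:
        return 1.0
    return math.sqrt(1 + math.log(scale) / math.log(original_max))

def apply_rotary(x, position, inv_freq, mscale=1.0):
    """Rotate-half RoPE on one vector: x[i] is paired with x[i + dim // 2]."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        theta = position * inv_freq[i]
        c = math.cos(theta) * mscale
        s = math.sin(theta) * mscale
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * c - x2 * s         # first half:  x1*cos - x2*sin
        out[i + half] = x2 * c + x1 * s  # second half: x2*cos + x1*sin
    return out
```

Consider assumed values for a 128k-context model: `max_position_embeddings = 131072` and `original_max_position_embeddings = 4096`. Then `scale = 32`, and the magnitude factor is `sqrt(1 + 5/12) ≈ 1.19`.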


mustangs0786 commented on July 23, 2024

@JosefAlbers Hi, I tried implementing it like this:

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        if args.rope_scaling is not None and args.rope_scaling["type"] == "linear":
            rope_scale = args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )
        else:
            print("test")
            self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)

For Phi3SuScaledRotaryEmbedding I used your code above, and got this error:

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:176, in <listcomp>(.0)
    173 assert self.vocab_size > 0
    174 self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
    175 self.layers = [
--> 176     TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
    177 ]
    178 self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:146, in TransformerBlock.__init__(self, args)
    144 self.num_attention_heads = args.num_attention_heads
    145 self.hidden_size = args.hidden_size
--> 146 self.self_attn = Attention(args)
    147 self.mlp = MLP(args.hidden_size, args.intermediate_size)
    148 self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:85, in Attention.__init__(self, args)
     83 else:
     84     print("deepak")
---> 85     self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:43, in Phi3SuScaledRotaryEmbedding.__init__(self, dim, config)
     41 self.dim = dim
     42 self.base = config.rope_theta
---> 43 self.short_factor = config.rope_scaling["short_factor"]
     44 self.long_factor = config.rope_scaling["long_factor"]
     45 self.original_max_position_embeddings = config.original_max_position_embeddings

TypeError: 'NoneType' object is not subscriptable
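The `TypeError` comes from the `else` branch. When a config has no `rope_scaling` entry, `args.rope_scaling` is `None`. Every non-linear case, including `None`, therefore ends up in `Phi3SuScaledRotaryEmbedding`, which then subscripts `None`. Below is a sketch of a hypothetical helper that checks for `None` first and selects su-RoPE only when the config explicitly asks for it. It is plain Python and not part of mlx-lm.

```python
def rope_kind(rope_scaling):
    """Hypothetical helper: decide which RoPE variant a config's rope_scaling asks for."""
    if rope_scaling is None:
        return "default"  # no scaling entry: plain nn.RoPE, no su-RoPE factors needed
    kind = rope_scaling.get("type")
    if kind == "linear":
        return "linear"
    if kind == "su":
        return "su"  # only now is it safe to read short_factor / long_factor
    raise ValueError(f"unsupported rope_scaling type: {kind!r}")


print(rope_kind(None))
print(rope_kind({"type": "su", "short_factor": [1.0], "long_factor": [1.0]}))
```

Separately, the linear branch in the code above passes `factor` directly as `scale`, whereas the earlier workaround uses `1 / factor`.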


JosefAlbers commented on July 23, 2024

@mustangs0786, it turns out that incorporating su-RoPE into mlx-lm required a bit more work than initially expected. I've just submitted a Pull Request with a modified implementation that seems to work well for phi-3-mini-128k: #813


sujantkumarkv commented on July 23, 2024

though I'm hopeful our forthcoming fused attention will help there.

I was also thinking about adding various attention implementations, such as fused attention. If this is already in the works, could you link me to it, or suggest anything specific in this direction if it's needed?
Thanks. cc @awni


awni commented on July 23, 2024

We are already working on fused attention. What other variations did you have in mind?


sujantkumarkv commented on July 23, 2024

Feel free to correct me here; I'm not an expert at all.

  • I might have written in the wrong repo. I only see MHA (MultiHeadAttention) in the ml-explore/mlx repo and thought we should have MultiQueryAttention and GroupedQueryAttention as well. DeepSeek-V2 also introduced Multi-head Latent Attention, I believe.

  • As far as I can tell, fused attention is probably this, and you're working on a CUDA/Triton implementation?

  • Also, off-topic and maybe a dumb question: I see we can train a decoder model in MLX like here, so with some architecture changes we could train a Llama-style model natively in MLX, right? It would be a great addition to the examples, because currently we mostly have inference and LoRA.
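For reference, MHA, GQA and MQA differ only in how many key/value heads the query heads share:

* with `n_kv_heads == n_heads`, it is ordinary multi-head attention;
* with `n_kv_heads == 1`, it is multi-query attention;
* in between, each group of `n_heads // n_kv_heads` query heads reads the same K/V head.

Below is a minimal pure-Python sketch for a single query position. The names are illustrative and are not an MLX API.

```python
import math

def attend(q, keys, values):
    """Scaled dot-product attention for one query over L keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / z
            for d in range(len(values[0]))]

def grouped_query_attention(queries, keys, values):
    """queries: [n_heads][D]; keys, values: [n_kv_heads][L][D].

    n_kv_heads == n_heads -> MHA, n_kv_heads == 1 -> MQA, otherwise GQA.
    """
    n_heads, n_kv_heads = len(queries), len(keys)
    if n_heads % n_kv_heads:
        raise ValueError("n_heads must be a multiple of n_kv_heads")
    group = n_heads // n_kv_heads
    # query head h reads K/V head h // group
    return [attend(queries[h], keys[h // group], values[h // group])
            for h in range(n_heads)]
```

Sharing K/V heads shrinks the KV cache by a factor of `n_heads // n_kv_heads`, which is the main reason GQA and MQA matter for long-context inference.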


awni commented on July 23, 2024

Our fused attention supports MQA and GQA as well.

you're working on cuda/triton implementation


sujantkumarkv commented on July 23, 2024

Can you point me to the attention implementations in your fused attention work, please? I'm interested in diving in and potentially helping. cc @awni
