In my work porting the Phi-3-vision model to MLX
PS that is a very cool project! Is it functional? Do you mind if I share it more broadly?
from mlx-examples.
Could you please help me figure out where this fits in the library code? I have seen your code, but I want to use the original library and fit this code into it. Can you please help? Thanks...
@mustangs0786, I'm currently working on integrating su-RoPE scaling directly into the Phi-3 model and plan to submit a pull request (PR) soon. In the meantime, you can try this temporary workaround in the Attention module's `__init__` method:
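```python
import mlx.nn as nn


class Attention(nn.Module):
    def __init__(self, args):
        # ... (existing projection setup unchanged)
        if args.rope_scaling is not None:
            if args.rope_scaling["type"] == "linear":
                # nn.RoPE's `scale` multiplies positions, so linear scaling uses 1 / factor
                rope_scale = 1 / args.rope_scaling["factor"]
                self.rope = nn.RoPE(
                    args.head_dim,
                    traditional=args.rope_traditional,
                    base=args.rope_theta,
                    scale=rope_scale,
                )
            elif args.rope_scaling["type"] == "su":
                self.rope = Phi3SuScaledRotaryEmbedding(args.head_dim, args)
        else:
            # No rope_scaling in the config: fall back to plain RoPE
            self.rope = nn.RoPE(
                args.head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
            )
```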
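Why the linear branch above uses `1 / factor`: assuming, as MLX's `nn.RoPE` documentation describes, that `scale` multiplies the positions before they are combined with the frequencies, a scale of `1 / factor` squeezes a longer context back into the position range the model was trained on. The pure-Python sketch below (the helper `rope_angles` and the `head_dim=96` value are illustrative assumptions, not mlx-lm code) shows that with factor 2, position 4096 gets exactly the rotation angles of unscaled position 2048.

```python
import math


def rope_angles(position, head_dim, base=10000.0, scale=1.0):
    """Rotation angle for each of the head_dim // 2 frequency pairs.

    Standard RoPE: theta_i = base ** (-2i / head_dim); the (scaled)
    position multiplies every frequency.
    """
    scaled_pos = position * scale
    return [scaled_pos * base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


# Linear scaling with factor 2 -> scale = 1 / 2.
factor = 2.0
scaled = rope_angles(4096, head_dim=96, scale=1 / factor)
unscaled = rope_angles(2048, head_dim=96)
print(all(math.isclose(a, b) for a, b in zip(scaled, unscaled)))  # True
```

Passing `factor` itself as `scale` would instead stretch positions further apart, which is the opposite of what linear scaling is meant to do.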
Sorry for the delayed response. I definitely think we should add this to the model files, as they are presumably incorrect for long context right now. Would you mind sending a PR, @JosefAlbers?
@awni My pleasure, I'll begin working on the PR shortly.
At the moment, running very long contexts might hit memory limitations, though I'm hopeful our forthcoming fused attention will help there.
That would be fantastic.
PS that is a very cool project!
Wow, thank you!
Is it functional?
At this point the project is functional for several key tasks, including image captioning, batched generation, LoRA training, and model/cache quantization. You can find more details in my README.md.
Do you mind if I share it more broadly?
That would be very kind of you, thank you so much!
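To put rough numbers on the long-context memory concern raised above, here is a back-of-the-envelope sketch. The shapes (32 layers, 32 KV heads, head dimension 96, fp16, batch size 1, 131072 tokens) are assumed Phi-3-mini-like values chosen for illustration; the point is the scaling, not exact figures for any checkpoint. The KV cache grows linearly with sequence length, while the full attention-score matrix that an unfused implementation materializes grows quadratically, which is the part fused attention avoids storing.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Key + value cache size for one sequence (batch size 1)."""
    # Factor 2: one tensor for keys and one for values in every layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem


def attention_scores_bytes(num_heads, seq_len, bytes_per_elem=2):
    """One layer's full (seq_len x seq_len) score matrices across all heads."""
    return num_heads * seq_len * seq_len * bytes_per_elem


GiB = 1024 ** 3
# Assumed shapes: 32 layers, 32 KV heads, head_dim 96, fp16, 131072 tokens.
print(kv_cache_bytes(32, 32, 96, 131072) / GiB)   # 48.0
print(attention_scores_bytes(32, 131072) / GiB)   # 1024.0
```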
Oh, and the su-RoPE is a bit different from how it was when I originally posted it last week. It now looks like this:
```python
import math

import mlx.core as mx


class Phi3SuScaledRotaryEmbedding:
    def __init__(self, dim, config, **kwargs):
        # Per-pair inverse frequencies, each divided by its short- or long-context rescale factor
        freq_exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
        self.inv_freq_short = 1.0 / (mx.array(config.rope_scaling["short_factor"], dtype=mx.float32) * config.rope_theta**freq_exponents)
        self.inv_freq_long = 1.0 / (mx.array(config.rope_scaling["long_factor"], dtype=mx.float32) * config.rope_theta**freq_exponents)
        self.original_max_position_embeddings = config.original_max_position_embeddings
        # Magnitude correction applied to cos/sin for the extended context
        self.scaling_factor = math.sqrt(
            1 + math.log(config.max_position_embeddings / config.original_max_position_embeddings)
            / math.log(config.original_max_position_embeddings)
        )

    def _get_cos_sin(self, offset, L, pids):
        def _get_pids(offset, L, pids):
            if offset < 1:
                return pids
            return pids[:, -1][:, None] + offset - pids.shape[1] + 2 + mx.arange(L)[None, :]

        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None] if pids is None else _get_pids(offset, L, pids)
        # Switch to the long factors once any position falls outside the original window
        inv_freq = self.inv_freq_long if position_ids.max() + 1 > self.original_max_position_embeddings else self.inv_freq_short
        inv_freq_expanded = mx.repeat(inv_freq[None, :, None], position_ids.shape[0], axis=0)
        position_ids_expanded = position_ids[:, None, :]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        # Add a head axis so cos/sin broadcast against (B, n_heads, L, head_dim)
        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)

    def __call__(self, q, k=None, offset=0, pids=None):
        def _rotate_half(x):
            midpoint = x.shape[-1] // 2
            x1, x2 = x[..., :midpoint], x[..., midpoint:]
            return mx.concatenate([-x2, x1], axis=-1)

        cos, sin = self._get_cos_sin(offset, q.shape[2], pids)
        q_embed = (q * cos) + (_rotate_half(q) * sin)
        if k is None:
            return q_embed
        k_embed = (k * cos) + (_rotate_half(k) * sin)
        return q_embed, k_embed
```
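To check the math in `Phi3SuScaledRotaryEmbedding` without MLX, here is a pure-Python sketch of the same steps for a single head vector at a single position: per-pair inverse frequencies divided by the rescale factors, the short/long factor switch, the `scaling_factor` correction, and `x * cos + rotate_half(x) * sin`. The helper names are hypothetical (not mlx-lm API), and the context lengths 131072 and 4096 are assumed example values for a 128k-style config.

```python
import math


def su_inv_freq(dim, base, ext_factors):
    """Inverse frequencies 1 / (factor_i * base ** (2i / dim)), one per pair."""
    assert len(ext_factors) == dim // 2, "need one rescale factor per frequency pair"
    return [1.0 / (f * base ** (2 * i / dim)) for i, f in enumerate(ext_factors)]


def su_pick_factors(last_position, original_max, short_factor, long_factor):
    """Long factors once the highest position id exceeds the original window."""
    return long_factor if last_position + 1 > original_max else short_factor


def su_scaling_factor(max_position_embeddings, original_max_position_embeddings):
    """Magnitude correction multiplied into cos and sin."""
    scale = max_position_embeddings / original_max_position_embeddings
    return math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))


def rotate_half(x):
    """[x1, x2] -> [-x2, x1], matching _rotate_half in the class above."""
    mid = len(x) // 2
    return [-v for v in x[mid:]] + x[:mid]


def apply_su_rope(x, position, inv_freq, scaling_factor):
    """Rotate one head vector x at one position: x*cos + rotate_half(x)*sin."""
    freqs = [position * f for f in inv_freq]
    emb = freqs + freqs  # same layout as mx.concatenate([freqs, freqs], axis=-1)
    cos = [math.cos(a) * scaling_factor for a in emb]
    sin = [math.sin(a) * scaling_factor for a in emb]
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotate_half(x), cos, sin)]


# Assumed 128k-style config: 131072 extended vs. 4096 original positions.
print(round(su_scaling_factor(131072, 4096), 4))  # 1.1902
```

With all rescale factors equal to 1.0, `su_inv_freq` reduces to the standard RoPE frequencies, and with `scaling_factor = 1.0` the rotation preserves the vector's norm; the su variant changes only those two knobs.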
@JosefAlbers Hi, I tried implementing it:
```python
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        if args.rope_scaling is not None and args.rope_scaling["type"] == "linear":
            rope_scale = args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )
        else:
            print("test")
            self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)
```
For `Phi3SuScaledRotaryEmbedding`, I'm using your code above. This is the error I get:
```
File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:176, in <listcomp>(.0)
    173 assert self.vocab_size > 0
    174 self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
    175 self.layers = [
--> 176 TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
    177 ]
    178 self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:146, in TransformerBlock.__init__(self, args)
    144 self.num_attention_heads = args.num_attention_heads
    145 self.hidden_size = args.hidden_size
--> 146 self.self_attn = Attention(args)
    147 self.mlp = MLP(args.hidden_size, args.intermediate_size)
    148 self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:85, in Attention.__init__(self, args)
     83 else:
     84 print("deepak")
---> 85 self.rope = Phi3SuScaledRotaryEmbedding(head_dim, args)

File ~/virtual_env_all/mlx_env/lib/python3.9/site-packages/mlx_lm/models/phi3.py:43, in Phi3SuScaledRotaryEmbedding.__init__(self, dim, config)
     41 self.dim = dim
     42 self.base = config.rope_theta
---> 43 self.short_factor = config.rope_scaling["short_factor"]
     44 self.long_factor = config.rope_scaling["long_factor"]
     45 self.original_max_position_embeddings = config.original_max_position_embeddings

TypeError: 'NoneType' object is not subscriptable
```
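The traceback ends at `config.rope_scaling["short_factor"]` with `'NoneType' object is not subscriptable`, which suggests `rope_scaling` was `None` for the checkpoint being loaded. In the `Attention` snippet above, the `else` branch builds `Phi3SuScaledRotaryEmbedding` whenever the scaling is *not* linear, and that includes the case where there is no scaling at all. A minimal sketch of a safer dispatch, written as a hypothetical standalone helper `select_rope` (not an mlx-lm function), checks for `None` first; raising on unknown types is an added assumption.

```python
def select_rope(rope_scaling):
    """Decide which RoPE variant to build from a config's rope_scaling entry.

    Returns (kind, scale). A missing/null rope_scaling falls through to plain
    RoPE instead of the su branch, which would otherwise index into None.
    """
    if rope_scaling is None:
        return "default", 1.0
    kind = rope_scaling["type"]
    if kind == "linear":
        # Following the workaround earlier in this thread: scale = 1 / factor
        return "linear", 1.0 / rope_scaling["factor"]
    if kind == "su":
        return "su", None
    raise ValueError(f"unsupported rope_scaling type: {kind!r}")


print(select_rope(None))                               # ('default', 1.0)
print(select_rope({"type": "linear", "factor": 4.0}))  # ('linear', 0.25)
print(select_rope({"type": "su"}))                     # ('su', None)
```

Inside `Attention.__init__`, the `"default"` and `"linear"` results would map to `nn.RoPE(...)` and `"su"` to `Phi3SuScaledRotaryEmbedding(head_dim, args)`.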
@mustangs0786, it turns out that incorporating su-RoPE into mlx-lm required a bit more work than initially expected. I've just submitted a pull request with a modified implementation that seems to work well for phi-3-mini-128k: #813
> though I'm hopeful our forthcoming fused attention will help there.

I was also thinking along the lines of having various attention implementations, like fused attention. If this is already in the works, could you link me to it, or suggest anything specific in this direction if it's needed? Thanks. cc @awni
We are already working on fused attention. What other variations did you have in mind?
Feel free to teach me here; I'm not an expert at all.

- I might have written in the wrong repo. I only see MHA (`MultiHeadAttention`) in the ml-explore/mlx repo and thought we should have `MultiQueryAttention` and `GroupedQueryAttention` as well. DeepSeek-V2 also introduced Multi-head Latent Attention, I suppose.
- From what I found, fused attention is probably this, and you're working on a CUDA/Triton implementation?
- Also, off-topic and maybe a dumb question: I see we can train a decoder model in MLX like here, so with some architecture changes we could train Llama-style models natively in MLX, right? It would be a great addition to the examples, because currently we mostly have inference and LoRA.
Our fused attention supports MQA and GQA as well.

> you're working on cuda/triton implementation

No CUDA backend for MLX; everything is for Apple silicon.

We have a transformer LM training example in this repo.
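As a rough illustration of the MHA/MQA/GQA distinction discussed above (this is not MLX's fused kernel, whose implementation isn't shown in this thread), here is a pure-Python sketch: each group of `n_heads // n_kv_heads` query heads reads the same key/value head, so `n_kv_heads == n_heads` is MHA, `n_kv_heads == 1` is MQA, and values in between are GQA. The function name and toy inputs are hypothetical, and the causal mask is omitted for brevity.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def grouped_query_attention(q, k, v):
    """q: [n_heads][seq][d]; k, v: [n_kv_heads][seq][d]. No causal mask."""
    n_heads, n_kv_heads = len(q), len(k)
    assert n_heads % n_kv_heads == 0, "n_kv_heads must divide n_heads"
    group = n_heads // n_kv_heads
    d = len(q[0][0])
    out = []
    for h in range(n_heads):
        kh, vh = k[h // group], v[h // group]  # K/V head shared by this query group
        head_out = []
        for qi in q[h]:
            scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in kh]
            weights = softmax(scores)
            head_out.append([sum(w * vj[c] for w, vj in zip(weights, vh)) for c in range(d)])
        out.append(head_out)
    return out


# Toy example: 2 query heads sharing 1 K/V head (MQA), sequence length 2, d = 2.
q = [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [1.0, -1.0]]]
k = [[[1.0, 2.0], [0.0, 1.0]]]
v = [[[1.0, 0.0], [0.0, 1.0]]]
out = grouped_query_attention(q, k, v)
print(len(out), len(out[0]), len(out[0][0]))  # 2 2 2
```

Duplicating the single K/V head once per query head gives the same result as MQA, which is why GQA/MQA saves KV-cache memory without changing the attention math for a given set of shared heads.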
Could you please point me to the attention implementations in your fused attention work? I'm interested in diving in and potentially helping. cc @awni