Comments (7)
Btw, I think the bug is how we use argpartition. It should receive ne-1 instead of ne here.
Basically, argpartition puts the kth value in its sorted spot (kth is zero-based). So to get the indices of the smallest two items you would do argpartition(a, kth=1)[:2]. If kth equals the number of experts, it's out of bounds, so argpartition crashes.
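The zero-based kth rule can be checked with NumPy, assuming np.argpartition has the same semantics as mx.argpartition for this purpose (both reject an out-of-range kth with a ValueError). top_k_indices is a hypothetical helper for illustration, not code from mlx_lm.

```python
import numpy as np


def top_k_indices(gates, k):
    """Return indices of the k largest gate values along the last axis (unordered)."""
    # kth is a zero-based position, so the k-th largest value belongs at
    # position k - 1. Passing kth=k would point past the end of the axis
    # whenever k equals the number of experts.
    return np.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]


# Router scores for one token over 8 experts; pick the top 2.
gates8 = np.array([[0.1, 0.9, 0.2, 0.05, 0.7, 0.3, 0.0, 0.4]])
print(sorted(top_k_indices(gates8, 2)[0].tolist()))  # [1, 4]

# The failing case from the traceback: gates of shape (1, 2) with ne = 2.
gates2 = np.array([[0.3, 0.7]])
print(sorted(top_k_indices(gates2, 2)[0].tolist()))  # [0, 1]

try:
    # kth == axis length: out of range, mirroring kth=ne in the bug.
    np.argpartition(-gates2, kth=2, axis=-1)
except ValueError as err:
    print("ValueError:", err)
```

With kth=ne-1 the 2-expert case works, and the 8-expert case returns the same result as before.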
from mlx-examples.
Thank you @hurongliang. The fix is in #515
Could you share the command you ran to get that issue? (Also, that argpartition error message looks a bit funky; we should probably fix it.)
The current MoE implementation can't handle configurations smaller than 2x3, so 2x2 won't work with TopK. I have encountered a similar error before but haven't had a chance to look into the details.
Huh, so that's generation with 2 experts? I can check it; maybe there is a bug in topk.
cloudyu/Yi-34Bx2-MoE-60B-DPO
Yeah, the model has 2 local experts and routes each token to 2 of them, which doesn't work with the current MoE sparse block.
https://huggingface.co/cloudyu/Yi-34Bx2-MoE-60B-DPO/blob/main/config.json#L17-L20
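To make the 2-local-experts, 2-per-token case concrete, here is a minimal NumPy sketch of top-k expert routing. It assumes a Mixtral-style router that applies a softmax over the selected gate logits; moe_forward and the plain linear "experts" are hypothetical stand-ins, not mlx_lm's actual MoE block.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def moe_forward(x, gate_w, expert_ws, k):
    """Route each token to its top-k experts and mix their outputs.

    x:         (tokens, d) input activations
    gate_w:    (d, n_experts) router weights
    expert_ws: list of n_experts (d, d) matrices standing in for expert MLPs
    k:         experts per token; must satisfy 1 <= k <= n_experts
    """
    n_experts = gate_w.shape[1]
    if not 1 <= k <= n_experts:
        raise ValueError(f"k={k} must be between 1 and {n_experts}")
    gates = x @ gate_w  # (tokens, n_experts)
    # kth=k-1 (zero-based) puts the k largest gates in the first k slots,
    # and stays in range even when k == n_experts.
    inds = np.argpartition(-gates, kth=k - 1, axis=-1)[:, :k]
    scores = softmax(np.take_along_axis(gates, inds, axis=-1))  # (tokens, k)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for j in range(k):
            out[t] += scores[t, j] * (x[t] @ expert_ws[inds[t, j]])
    return out


# Hypothetical 2x2 setup: 2 local experts, 2 experts per token.
rng = np.random.default_rng(0)
d = 4
x = rng.standard_normal((3, d))
gate_w = rng.standard_normal((d, 2))
experts = [rng.standard_normal((d, d)) for _ in range(2)]
y = moe_forward(x, gate_w, experts, k=2)
print(y.shape)  # (3, 4)
```

When k equals the number of experts, every token uses all experts, so the result reduces to a softmax-weighted sum over all of them; the top-k selection only has to avoid an out-of-range kth.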
Steps to reproduce the issue:
- Download cloudyu/Yi-34Bx2-MoE-60B-DPO:
    huggingface-cli download --resume-download cloudyu/Yi-34Bx2-MoE-60B-DPO
- Convert the weights to MLX format with 4-bit quantization:
    from mlx_lm import convert
    convert('cloudyu/Yi-34Bx2-MoE-60B-DPO', mlx_path='mlx-community/Yi-34Bx2-MoE-60B-DPO-4bit-mlx', quantize=True, q_bits=4)
- Run:
    python test_hello.py mlx-community/Yi-34Bx2-MoE-60B-DPO-4bit-mlx -mlx
# file test_hello.py
from mlx_lm import load, generate
import argparse
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Generate a response to a prompt with an MLX or Hugging Face model.")
    parser.add_argument("repo_id", help="repo_id on huggingface.co")
    parser.add_argument("-mlx", "--mlx", help="True if model is mlx format", action="store_true", default=False)
    parser.add_argument("-chat", "--chat", help="True if model is a chat model", action="store_true", default=False)
    parser.add_argument("--prompt", help="Prompt", default="hello")
    args = parser.parse_args()
    is_mlx = args.mlx
    repo_id = args.repo_id
    is_chat_model = args.chat
    prompt = args.prompt
    if is_mlx:
        # MLX path (this is the branch that fails for the 2-expert MoE model)
        model, tokenizer = load(repo_id)
        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        print(response)
    else:
        if is_chat_model:
            # Chat models with custom code expose model.chat(); run on the Apple GPU via MPS
            tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
            model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).to('mps')
            model = model.eval()
            response, history = model.chat(tokenizer, prompt, history=[])
            print(response)
        else:
            # Plain causal LM through transformers
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            model = AutoModelForCausalLM.from_pretrained(repo_id)
            input_ids = tokenizer(prompt, return_tensors="pt")
            outputs = model.generate(**input_ids, max_new_tokens=10000)
            print(tokenizer.decode(outputs[0]))
Console output
==========
Prompt: hello
Traceback (most recent call last):
  File "/Users/hurongliang/git/llmresearch/test_hello.py", line 20, in <module>
    response = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 214, in generate
    for (token, prob), n in zip(
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/utils.py", line 157, in generate_step
    logits, cache = model(y[None], cache=cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 260, in __call__
    out, cache = self.model(inputs, cache)
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 243, in __call__
    h, cache[e] = layer(h, mask, cache[e])
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 209, in __call__
    r = self.block_sparse_moe(self.post_attention_layernorm(h))
  File "/Users/hurongliang/miniconda3/envs/llm/lib/python3.10/site-packages/mlx_lm/models/mixtral.py", line 160, in __call__
    mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
ValueError: [argpartition] Received invalid kth 2along axis -1 for array with shape: (1,2)