naklecha / llama3-from-scratch
llama3 implementation one matrix multiplication at a time
License: MIT License
This project uses the Meta-Llama-3-8B model by default, but when I switch to Meta-Llama-3-70B-Instruct, an error occurs during token embedding. It is caused by a mismatch in the dimensions of tok_embeddings.weight between Meta-Llama-3-8B and Meta-Llama-3-70B-Instruct, even though both models have the same vocab_size. Why is this happening?
RuntimeError Traceback (most recent call last)
Cell In[39], line 3
1 embedding_layer = torch.nn.Embedding(vocab_size, dim)
2 print(model["tok_embeddings.weight"].shape)
----> 3 embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])
4 token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
5 token_embeddings_unnormalized.shape
RuntimeError: The size of tensor a (128256) must match the size of tensor b (16032) at non-singleton dimension 0
Note that 16032 × 8 = 128256, so the 70B checkpoint most likely stores tok_embeddings.weight split across its 8 model-parallel shard files rather than as a single full-vocab tensor.
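If that is the cause, concatenating the shards back together along the vocab dimension should restore the full embedding. Below is a minimal sketch under that assumption; the consolidated.XX.pth file names follow Meta's usual checkpoint layout, and the dim-0 split is inferred from the error message (16032 × 8 = 128256), not verified against the official loader:

import torch

# Hedged sketch: reassemble the vocab-sharded embedding from the 8 shard files.
# Assumes Meta-Llama-3-70B-Instruct ships as consolidated.00.pth .. consolidated.07.pth
# with tok_embeddings.weight split along dim 0.
shards = [
    torch.load(f"Meta-Llama-3-70B-Instruct/consolidated.{i:02d}.pth", map_location="cpu")
    for i in range(8)
]
tok_embeddings_weight = torch.cat([s["tok_embeddings.weight"] for s in shards], dim=0)
print(tok_embeddings_weight.shape)  # expected: torch.Size([128256, 8192])

Note that other tensors in a model-parallel checkpoint are sharded along different dimensions, so a full 70B load would need per-tensor merge rules, not a single cat.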
First, thank you for sharing this project with us!
Could you please add an explicit LICENSE file to the repo, so that it's clear under what terms the content is provided and under what terms user contributions are licensed?
[...] without a license, the default copyright laws apply, meaning that you
retain all rights to your source code and no one may reproduce, distribute,
or create derivative works from your work. If you're creating an open source
project, we strongly encourage you to include an open source license.
Thanks!
I added a for loop to this program and appended each next token to the end of the prompt, but it never generates <|end_of_text|>; it just generates the same context again and again. For example, with 2x9= as the original prompt, it generates a sequence like this: ['<|begin_of_text|>', '2', 'x', '9', '=', '18', '\n', '2', 'x', '9', '=', '18', '\n', '2', 'x']
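For reference, here is a minimal greedy-decoding loop of the kind described (a sketch using the Hugging Face model rather than the hand-rolled forward pass below; variable names are illustrative). A base model like Meta-Llama-3-8B is trained to continue text, not to answer and stop, so under pure argmax decoding <|end_of_text|> is rarely the most likely token, and short repetitive prompts tend to loop; sampling, a repetition penalty, or the Instruct variant usually breaks the cycle:

import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
model.eval()

ids = tokenizer("2x9=", return_tensors="pt").input_ids
with torch.no_grad():
    for _ in range(16):
        logits = model(ids).logits[0, -1]             # logits for the last position
        next_id = torch.argmax(logits)                # greedy pick
        if next_id.item() == tokenizer.eos_token_id:  # <|end_of_text|>
            break
        ids = torch.cat([ids, next_id.view(1, 1)], dim=1)
print(tokenizer.decode(ids[0]))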
import torch
import transformers
import json
model = transformers.AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokenizer = transformers.AutoTokenizer.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    model_max_length=512,
    padding_side="right",
    use_fast=False,
)
dim = model.config.hidden_size
n_layers = model.config.num_hidden_layers
n_heads = model.config.num_attention_heads
n_kv_heads = model.config.num_key_value_heads
vocab_size = model.config.vocab_size
# These two values are not used anywhere below; guard the lookups, since the HF
# LlamaConfig does not necessarily define multiple_of / ffn_dim_multiplier.
multiple_of = getattr(model.config, "multiple_of", 1024)
ffn_dim_multiplier = getattr(model.config, "ffn_dim_multiplier", 1.3)
norm_eps = model.config.rms_norm_eps
rope_theta = model.config.rope_theta
prompt = "the answer to the ultimate question of life, the universe, and everything is "
tokens = tokenizer.encode(prompt)
print(tokens)
tokens = torch.tensor(tokens)
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]
print(prompt_split_as_tokens)
embedding_layer = torch.nn.Embedding(vocab_size, dim)
tok_embeddings_weight = model.model.embed_tokens.weight
embedding_layer.weight.data.copy_(tok_embeddings_weight)
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)
token_embeddings_unnormalized.shape
def rms_norm(tensor, norm_weights):
return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
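# Added sanity check (illustrative, not in the original post): rms_norm scales each
# row to unit RMS before applying the learned weights, so with unit weights the
# mean square of every row comes out as ~1.0 (up to norm_eps).
_x = torch.randn(3, dim)
print(rms_norm(_x, torch.ones(dim)).pow(2).mean(-1))  # ~tensor([1., 1., 1.])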
# RoPE (rotary position embeddings)
zero_to_one_split_into_64_parts = torch.arange(64) / 64  # head_dim = 128, i.e. 64 (x, y) pairs to rotate
zero_to_one_split_into_64_parts
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
freqs
freqs_for_each_token = torch.outer(torch.arange(len(tokens)), freqs)  # one angle per (position, frequency); len(tokens) is 17 for this prompt
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
freqs_cis.shape
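# Added sanity check (illustrative): multiplying a pair viewed as a complex number
# by exp(i*theta) is exactly a 2D rotation by theta.
_theta = freqs_for_each_token[1, 0]  # rotation angle for position 1, lowest-index frequency
_pair = torch.tensor([1.0, 0.0])
_rotated = torch.view_as_real(torch.view_as_complex(_pair) * torch.polar(torch.ones(()), _theta))
print(_rotated)  # ~[cos(theta), sin(theta)]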
final_embedding = token_embeddings_unnormalized
for layer in range(n_layers):
    qkv_attention_store = []
    # NOTE: in HF Llama, input_layernorm is applied before self-attention and
    # post_attention_layernorm before the MLP; the original post had the two
    # swapped, which is enough to garble the output.
    layer_embedding_norm = rms_norm(final_embedding, model.model.layers[layer].input_layernorm.weight)
    q_layer = model.model.layers[layer].self_attn.q_proj.weight
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model.model.layers[layer].self_attn.k_proj.weight
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model.model.layers[layer].self_attn.v_proj.weight
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model.model.layers[layer].self_attn.o_proj.weight.to(torch.bfloat16)
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        # grouped-query attention: every n_heads // n_kv_heads (= 4 here) query heads share one KV head
        k_layer_head = k_layer[head // (n_heads // n_kv_heads)]
        v_layer_head = v_layer[head // (n_heads // n_kv_heads)]
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T).to(torch.bfloat16)
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5  # scale by sqrt(head_dim)
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = torch.triu(mask, diagonal=1)  # causal mask: -inf above the diagonal
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    # in the original repo this weight is model[f"layers.{layer}.attention.wo.weight"]
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    # post_attention_layernorm precedes the MLP in HF Llama (the original post used input_layernorm here)
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model.model.layers[layer].post_attention_layernorm.weight)
    w1 = model.model.layers[layer].mlp.gate_proj.weight
    w2 = model.model.layers[layer].mlp.down_proj.weight
    w3 = model.model.layers[layer].mlp.up_proj.weight
    # SwiGLU feed-forward (the original post called torch.functional.F.silu; torch.nn.functional.silu is the documented path)
    output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit + output_after_feedforward
final_embedding = rms_norm(final_embedding, model.model.norm.weight)
final_embedding.shape
logits = torch.matmul(final_embedding[-1], model.lm_head.weight.T)
logits.shape
next_token = torch.argmax(logits, dim=-1)
next_token
>>> tensor(8168)
tokenizer.decode([next_token.item()])
>>> ';&'
My question: why is the result not 42? (Most likely because of the swapped layer norms noted in the code above: with post_attention_layernorm feeding attention and input_layernorm feeding the MLP, garbage output is expected. With the corrected assignment the forward pass matches the original repo, which does produce "42" for this prompt.)
Thanks for the awesome repository! After going through it step-by-step, I have a better understanding of Llama3 techniques such as rotary position embeddings and grouped-query attention.
I found what might be a minor mistake in the skip-connection visualization. The corresponding code is in the section "WE FINALLY HAVE NEW EDITED EMBEDDINGS FOR EACH TOKEN AFTER THE FIRST LAYER":
layer_0_embedding = embedding_after_edit+output_after_feedforward
layer_0_embedding.shape
Since embedding_after_edit (not embedding_after_edit_normalized) is what feeds the skip connection here, the visualization should show the residual branch coming from the unnormalized tensor.
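To make the data flow explicit, here is a compact restatement of the block structure as the code implements it (function and argument names are illustrative, not from the repo):

# Per-layer data flow in the repo's forward pass (sketch):
#   h   = x + attention(rms_norm(x))      <- skip connection carries x itself
#   out = h + feed_forward(rms_norm(h))   <- skip connection carries h
#         (embedding_after_edit), not the normalized tensor
def transformer_block(x, attention, feed_forward, attn_norm_w, ffn_norm_w):
    h = x + attention(rms_norm(x, attn_norm_w))
    return h + feed_forward(rms_norm(h, ffn_norm_w))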
Hello @naklecha,
I hope this message finds you well.
I am a contributor to the Datawhale community, specifically working on the LLMs-from-Scratch-CN project. Our goal is to recreate large language models from scratch, offering detailed discussions and implementations in the Chinese language. We have successfully completed the implementation of ChatGLM3 and have plans to implement other models such as Mamba, RWKV, Phi, MiniCPM, Qwen, among others.
We came across your impressive work on the Llama3-from-scratch repository. The detailed and precise implementation you have provided is exactly what we are looking for to include in our project. We believe that integrating your Llama3 codebase will not only enhance our project but also provide a valuable resource for the community.
Therefore, we would like to formally request your authorization to integrate your Llama3 implementation into our project. We assure you that we will provide proper attribution and adhere to any licensing terms you specify.
Please let us know if you have any conditions or requirements for this integration. We look forward to your positive response.
Best regards,
Ethan-Chen-Plus
Datawhale Community
In "using dot product of complex numbers to rotate a vector" section, the code "freqs_cis.shape" doesn't have its output, which may make reader confused.
In "multi head attention" section, the line "qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)" repeat twice.
Thank you.
Can't run on free Colab due to insufficient RAM.
Very interesting work. Thank you so much.
I've implemented a repo on this basis that supports weights in the Hugging Face format; if anyone is interested, you can refer to it.
https://github.com/ZiQiangXie/llm-from-scratch
Can you create the training/fine-tuning part from scratch, too?
That would make this complete.
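Until that exists in the repo, a minimal next-token training step in the same spirit could look like the sketch below (all names are illustrative and not from the repo; a real fine-tune would add data batching, a learning-rate schedule, and checkpointing):

import torch

# Hypothetical setup: `model` is any causal LM returning logits of shape
# (batch, seq, vocab); `tokens` is a (batch, seq) LongTensor of training text.
def training_step(model, tokens, optimizer):
    logits = model(tokens).logits  # (batch, seq, vocab)
    # shift by one: predict token t+1 from tokens <= t
    loss = torch.nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.shape[-1]),
        tokens[:, 1:].reshape(-1),
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()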
Thanks for this great article!