
Jaxformer

A JAX library for training large language models on TPU-v3/v4, with data and model parallelism based on the pjit() operator.
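
The sharding primitive the library builds on can be illustrated with a minimal, self-contained sketch (not code from this repo; it assumes 8 devices, e.g. a v3-8, and the jax 0.3.x pjit API pinned below):

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental import maps
from jax.experimental.pjit import pjit

# Arrange the 8 cores of a v3-8 into a 2x4 (data, model) mesh.
devices = np.array(jax.devices()).reshape(2, 4)
mesh = maps.Mesh(devices, ('dp', 'mp'))

# Shard activations over the data axis and weights over the model axis,
# the Megatron-style pattern referenced in the feature list below.
matmul = pjit(
    lambda x, w: jnp.dot(x, w),
    in_axis_resources=(P('dp', None), P(None, 'mp')),
    out_axis_resources=P('dp', 'mp'))

with mesh:
    x = jnp.ones((8, 512))
    w = jnp.ones((512, 1024))
    y = matmul(x, w)   # y is laid out (dp, mp) across the mesh

print(y.shape)  # (8, 1024)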

Citation

Please cite:

@misc{Jaxformer,
  title={Jaxformer: A minimal library for training LLMs on TPU},
  author={Nijkamp, Erik},
  howpublished = {\url{https://github.com/salesforce/jaxformer}},
  year={2022}
}

Acknowledgments: Ben Wang, James Bradbury, Zak Stone, Bo Pang.

Models

CodeGen

350M

gs://sfr-codegen-research/checkpoints/codegen-350M-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-350M-multi/150000
gs://sfr-codegen-research/checkpoints/codegen-350M-mono/150000

2B

gs://sfr-codegen-research/checkpoints/codegen-2B-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-2B-multi/150000
gs://sfr-codegen-research/checkpoints/codegen-2B-mono/100000

6B

gs://sfr-codegen-research/checkpoints/codegen-6B-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-6B-multi/100000
gs://sfr-codegen-research/checkpoints/codegen-6B-mono/140000

Sanity TPU

import jax

# On a single v3-8 host, both counts should be 8.
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())

# psum over a vector of ones returns the device count on every device,
# i.e. [8. 8. 8. 8. 8. 8. 8. 8.] on a v3-8.
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
print('pmap result:', r)

gcloud compute tpus tpu-vm ssh erik.nijkamp@<tpu-name> --zone=us-east1-d --internal-ip --worker=all --command="pip install 'jax[tpu]==0.3.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
gcloud compute tpus tpu-vm scp test.py erik.nijkamp@<tpu-name>:/home/erik.nijkamp/ --zone=us-east1-d --internal-ip --worker=all
gcloud compute tpus tpu-vm ssh erik.nijkamp@<tpu-name> --zone=us-east1-d --internal-ip --worker=all --command="python3 /home/erik.nijkamp/test.py"

Training

Mode 1: CPU local

brew install python@3.9
apt install --yes python3.9 python3.9-venv

git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer

python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt

python3 -m jaxformer.train --config config/debug_cpu.json

Mode 2: TPU local

gcloud compute tpus list --zone=europe-west4-a

gcloud compute tpus tpu-vm delete sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --quiet

gcloud compute tpus tpu-vm create sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --accelerator-type=v3-8 --version=v2-alpha

gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --project <project> --worker 0

export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>

git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer

./jaxformer/env/env_tpu_v3.sh
pip install -r requirements.txt

source .venv/bin/activate

python3
import jax
jax.devices()
quit()

python3 -m jaxformer.train --config config/debug_tpu_v3_8.json

Mode 3: TPU remote

gcloud beta compute --project=<project> instances create sfr-<username>-cpu-small-us-east1-d-1 --zone=us-east1-d --machine-type=e2-standard-4 --network-tier=PREMIUM --maintenance-policy=MIGRATE --service-account=<account> --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --image=ubuntu-minimal-2004-focal-v20210720 --image-project=ubuntu-os-cloud --boot-disk-size=50GB --boot-disk-type=pd-balanced --boot-disk-device-name=sfr-cpu-small --no-shielded-secure-boot --shielded-vtpm --shielded-integrity-monitoring --reservation-affinity=any

gcloud beta compute ssh sfr-<username>-cpu-small-us-east1-d-1 --project=<project> --zone=us-east1-d

sudo apt update
sudo apt install --yes git screen python3.9 python3.9-venv

screen -S codegen_350M_nl

curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init
ssh-keygen -t rsa -f ~/.ssh/google_compute_engine -N ''

export WANDB_API_KEY=<secret>
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>

git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer

python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt

python3 -m jaxformer.train --config config/codegen_350M_nl.json

gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-64-us-east1-d-1 --zone us-east1-d --internal-ip --worker=0

Fine-tuning

TPU fine-tune

gcloud beta compute --project=<project> instances create sfr-<username>-cpu-small-us-east1-d-1 --zone=us-east1-d --machine-type=e2-standard-4 --network-tier=PREMIUM --maintenance-policy=MIGRATE --service-account=<account> --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append --image=ubuntu-minimal-2004-focal-v20210720 --image-project=ubuntu-os-cloud --boot-disk-size=50GB --boot-disk-type=pd-balanced --boot-disk-device-name=sfr-cpu-small --no-shielded-secure-boot --shielded-vtpm --shielded-integrity-monitoring --reservation-affinity=any

gcloud beta compute ssh sfr-<username>-cpu-small-us-east1-d-1 --project=<project> --zone=us-east1-d

sudo apt update
sudo apt install --yes git screen python3.9 python3.9-venv

screen -S codegen_350M_mono

curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init
ssh-keygen -t rsa -f ~/.ssh/google_compute_engine -N ''

export WANDB_API_KEY=<secret>
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>

git clone https://<username>:<secret>@github.com/salesforce/jaxformer.git/
cd jaxformer

python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt

python3 -m jaxformer.train --config config/codegen_350M_multi.json

gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-64-us-east1-d-1 --zone us-east1-d --internal-ip --worker=0

A100 fine-tune

apt install python3.8 python3.8-venv python3.8-dev

curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init

export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials/<username>/adc.json
export GCLOUD_PROJECT=<project>

python3.8 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
pip install transformers==4.21.1 datasets==1.16.1 deepspeed==0.7.0 tensorflow-cpu==2.5.0

pip install -e .

deepspeed --num_gpus=1 jaxformer/hf/train.py
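
train.py configures DeepSpeed internally; as a rough, hedged sketch of what a deepspeed.initialize() call with an inline ZeRO config looks like (generic DeepSpeed usage, not this repo's exact settings):

import torch
import deepspeed

# Stand-in module; in the repo this would be the CodeGen model.
model = torch.nn.Linear(16, 16)

ds_config = {
    'train_batch_size': 8,
    'fp16': {'enabled': True},
    'zero_optimization': {'stage': 2},  # shard optimizer state across ranks
    'optimizer': {'type': 'AdamW', 'params': {'lr': 1e-5}},
}

# Launched via the deepspeed CLI (as above), this returns a wrapped engine
# whose .backward()/.step() handle the ZeRO sharding.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config)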

Conversion

python3 -m jaxformer.hf.convert --config=config/codegen_1B_mono.json --step=150000
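
The converted checkpoint should then load with the Hugging Face transformers API (a sketch; '<converted_dir>' is a stand-in for wherever the conversion script wrote pytorch_model.bin and config.json, with the tokenizer pulled from the published 350M model):

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('Salesforce/codegen-350M-mono')
model = transformers.AutoModelForCausalLM.from_pretrained('<converted_dir>')

inputs = tokenizer('def fibonacci(n):', return_tensors='pt')
out = model.generate(**inputs, max_length=64)
print(tokenizer.decode(out[0]))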

Features

v1

  • Data
    • Stateful, resumable data loading based on tfrecords without skip() (see the sketch after this list)
  • TPU
    • Provisioning of TPU clusters and virtual environment
    • Code paths for both TPU-v3 and TPU-v4
    • ...
  • Parallelism
    • Push-based single port TCP/IP protocol for orchestration and data-parallelism
    • Megatron pjit() based sharding pattern across TPU boards for up to 6B parameter LLMs
    • xmap() emulation mode through pjit() sharding
    • Distributed checkpointing with full state recovery
    • scan() for time-efficient jit'ing
    • ...
  • Debugging
    • Abstraction layer for local/remote workers
    • Local CPU debugging with TPU emulation
    • Mock data iterators
    • ...
  • Training
    • Fully resumable state and checkpointing
    • WandB integration
    • ...
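
As a hedged illustration of the resumable-data-loading feature (generic tf.data usage, not this repo's loader): tf.data iterators are checkpointable objects, so an input pipeline can be saved and restored mid-stream instead of replaying records with skip().

import tensorflow as tf

# Dataset.range keeps the sketch self-contained; a TFRecordDataset over the
# training shards behaves identically.
ds = tf.data.Dataset.range(10)
it = iter(ds)
print(next(it).numpy())             # 0

ckpt = tf.train.Checkpoint(iterator=it)
path = ckpt.save('/tmp/data_ckpt')  # persist the iterator's position

it2 = iter(tf.data.Dataset.range(10))
tf.train.Checkpoint(iterator=it2).restore(path)
print(next(it2).numpy())            # 1 -- resumes where the saved iterator stopped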


jaxformer's Issues

TPU finetuned model corrupts

Hi, thank you for your awesome work!

I would like to finetune the 350M model using TPU remote mode on a single TPU v3-8 node. I used the config file https://github.com/salesforce/jaxformer/blob/main/config/codegen_350M_multi.json, switched data_train_set to my own dataset of only ~100 texts, finetuned the model for a few epochs, and converted it to pytorch for sampling.

However, the finetuned model seems to be corrupted and repeats garbage output such as 'TETETETE', failing to produce any meaningful generations. Even though the training loss quickly drops to around 0.01, the model cannot reproduce the content of the training set. I swept hyperparameters such as the learning rate and number of training steps, but the problem persists.

It seems that no matter what the model is trained on, the weights get corrupted quickly. Does this mean there is a discrepancy between the model you pretrained and the code you published in this repo? Could you please release the original training code as well?

cc: @amerine @mgomes @awaterma @ekashida @paultyng

Can we use Jaxformer for Nvidia A100 GPU currently?

Awesome. I am curious about Jaxformer. The README indicates that Jaxformer targets the TPU-v3/v4 environment.
Nonetheless, I noticed the "A100 fine-tune" section in the README.

I'm unsure whether the Jaxformer library can be used to fine-tune in an NVIDIA A100 GPU environment. Only JSON files for CPU and TPU are present in the ./config/ directory. @enijkamp, where can I find an example JSON file for the NVIDIA A100 GPU card?

Thanks a lot.

BRs,
Geunsik Lim.

Fine-tuning on conversations (format of conversations)

Hello

I have a dataset consisting of dialogues between two people which I would like to use for fine-tuning GPT-J. Please see below for two example dialogues. The dialogues vary in length and can be longer than the examples.

Is the format of the conversations ok? For fine-tuning, should I just concatenate all conversations into one big file or do I have to use a separator between the conversations (if yes, which separator)?

First Dialogue:

user1:
Hey there. What’s up?

user2:
Not much, just hanging out. What about you?

user1:
Just thinking about what I’m going to do this weekend. You?

user2:
Probably just relaxing. What do you have planned?

user1:
I’m thinking about going to the beach. It’s supposed to be nice this weekend.

user2:
That sounds like a great plan! Have you been to the beach recently?

user1:
Not in a while. It would be nice to get out and enjoy the sun.

user2:
Definitely! I’m sure it’ll be a great time. Do you have any other ideas for the weekend?

Second Dialogue:

user1:
Good morning. What is your profession?

user2:
Good morning. I’m an accountant. What about you?

user1:
I’m a software engineer. How long have you been an accountant?

user2:
I’ve been an accountant for about five years now. What about you? How long have you been a software engineer?

user1:
I’ve been a software engineer for three years. What do you like most about accounting?

user2:
I like how challenging it can be. There’s always something to learn or something new to figure out. What do you like most about software engineering?

user1:
I like how creative it can be. I get to come up with new ideas and new ways of solving problems. It’s a great feeling when you can come up with something that works.

350M Mono not found

ckpt.json not found in gs://sfr-codegen-research/checkpoints/codegen-350M-mono/150000, but it is found in gs://sfr-codegen-research/checkpoints/codegen-350M-multi/150000.

Request for training configuration of CodeGen 16B

I want to finetune the 16B scale codegen checkpoint using TPU.

In the config directory, there is no configuration for that.

Could you share the configuration, or some documentation on scaling the model parameters?

Out-of-memory running with Deepspeed

I am trying to start a training session (Codegen 16B mono) with one GPU (RTX 3090) using the Deepspeed training script.

After model initialization, during Deepspeed initialization I get the following:

RuntimeError: CUDA out of memory. Tried to allocate 288.00 MiB (GPU 0; 24.00 GiB total capacity; 23.19 GiB already allocated; 0 bytes free; 23.21 GiB reserved in total by PyTorch)

Running jaxformer/hf/train.py with the command deepspeed --num_gpus=1 jaxformer/hf/train.py

I also set debug_mock_data to True just to see if I could get it started.

No matter how I try to tweak the parameters I always end up with OOM.

Also tried setting the PYTORCH_CUDA_ALLOC_CONF environment variable as well as running a lower version of torch (pip3 install torch==1.8.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html)

Any idea on how to solve this?

How to load CodeGen models in jaxformer locally?

I've tried to extend the local master with a load_ckpt function, but ran into trouble with:

assert jax.process_count() == ckpt['process_count']

Any pointers for this? Do we have to use jax.distributed.initialize, or is there another way to do it within a single process?

Edit: Is there a conversion script from JAX to an HF model? Can one adapt the GPT-J conversion script?

Setup: single node with 8 A100s

cc @enijkamp

[Mismatched_sizes] Got mismatched_size exception when loading the finetuned model

Thank you for sharing the finetuning scripts for CodeGen.
However, I encountered a problem when attempting to load the finetuned model using the following code, where pretrain_dir refers to the path of the pytorch_model.bin and config.json.

tokenizer = transformers.AutoTokenizer.from_pretrained("Salesforce/codegen-350M-multi")
model = transformers.CodeGenForCausalLM.from_pretrained(pretrain_dir,config=os.path.join(pretrain_dir,"config.json"))   

An exception was thrown:

Traceback (most recent call last):
  File "/home/User/code-models/get_retrained_model_distribution.py", line 54, in <module>
    tokenizer, retrained_model = model_utils.load_retrained_model(f"output/finetune_codegen/20230316-{args.ds_type}-Epoch30/final_checkpoint-{args.ds_type}-1",model_name)
  File "/home/User/code-models/utils/model_utils.py", line 68, in load_retrained_model
    model = transformers.CodeGenForCausalLM.from_pretrained(pretrain_dir,config=os.path.join(pretrain_dir,"config.json"), local_files_only=True)    
  File "/home/User/.conda/envs/transformers/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2379, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/User/.conda/envs/transformers/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2695, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for CodeGenForCausalLM:
        size mismatch for transformer.h.0.attn.qkv_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([3072, 1024]).
        size mismatch for transformer.h.0.attn.out_proj.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
        size mismatch for transformer.h.0.mlp.fc_in.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([4096, 1024]).
        size mismatch for transformer.h.0.mlp.fc_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([1024, 4096]).
        (analogous size mismatches follow for transformer.h.1 through transformer.h.19)
        size mismatch for lm_head.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([51200, 1024]).
        size mismatch for lm_head.bias: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([51200]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

Adding ignore_mismatched_sizes=True avoids the exception but makes the model produce nonsense output. I am wondering how to properly load a model finetuned with the deepspeed scripts.

Thanks in advance. : )

[Suggestion]: Code Notes

Hello! Want to preface this 'issue' by sincerely thanking the owners of this repo & those that were responsible in creating the codegen model for taking the time to publish about your process, open source the model & create accompanying repos + publishing the models on HuggingFace. Your contributions to the community are invaluable.

I wanted to inquire about potentially adding code notes to some (or all) of the .py files included in this repo. While I know this is a big ask, I was curious about whether this might be a good place to put your fine tuned codegen_6b_mono model to work since it was trained on a large corpus of Python code, specifically. While program synthesis via NLP is a remarkable breakthrough, I also believe increasing the literacy of code for general observers is equally important.

Given all you guys have done in creating the relevant models, fine tuning & benchmarking them, documenting the process, publishing your findings & results as well as publishing the relevant code publicly, it would be borderline rapacious to demand the deployed code be annotated. Thus, I'm wondering if this could be considered an additional "real-world" use case for the fine tuned model you all created.

Please let me know what you guys think. I believe such a task is in line with this project's core ethos, which seems to be lowering the barriers to entry for programming and development, whether that be through an expedited workflow for experienced programmers or by providing a 'bridge' for those who have a strong semantic understanding of programming or a programming task but lack the technical knowledge to write the necessary code from scratch. Specifically, the latter scenario represents a democratization of involvement in the coding and programming process.

In that same vein, I believe leveraging this model to annotate published code would promote a more comprehensive understanding among all those who interact with it, which can ultimately lead to fewer mistakes, errors, and misconfigurations, while also saving developers time answering questions or clarifying misunderstandings that the annotations would preempt.

6.1B Config Produces 7B Parameters

This is a question rather than a bug or issue.

In the paper I see these parameters for the 6.1B parameter model:
[screenshot of the paper's hyperparameter table for the 6.1B model]

I see the same parameters reflected in config/codegen_6B_nl.json.

However, when I run the model with those config parameters, I see 7B parameters rather than 6.1B:

Model params_num: 7064217600, params_size: 28256.87MB
Model params_num_total: 7064217600, params_size_total: 28256.87MB
{'params_num': 7064217600, 'params_size': 28256870400, 'params_num_total': 7064217600, 'params_size_total': 28256870400}

The 350M and 2.7B parameter config files produce the expected numbers of parameters, but this one comes out roughly 1B larger than 6B?
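
For reference, a back-of-the-envelope count reproduces the printed figure. Assuming the 6B config uses n_layer=33 and d_model=4096 with a 51200-token vocabulary, and the per-block weight shapes visible in the mismatched-sizes issue above (qkv_proj 3d×d, out_proj d×d, fc_in 4d×d, fc_out d×4d, i.e. 12·d² per block):

d, n_layer, vocab = 4096, 33, 51200
params = n_layer * 12 * d * d   # transformer blocks: 6,643,777,536
params += vocab * d             # token embedding:      209,715,200
params += vocab * d + vocab     # lm_head (+ bias):     209,766,400
print(params)                   # 7,063,259,136 -- within ~1M (layer norms,
                                # biases) of the reported 7,064,217,600

So the printed ~7.06B appears consistent with the config; the paper's "6.1B" label would then reflect a different accounting.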

facing issue while running jaxformer locally.

Hi, I am having issues running jaxformer locally. I am trying to run it in a Windows environment, and after installing many libraries I am now stuck at the deepspeed installation. Below is the error I am getting. Could someone please help me?

PS C:\Users\xxx\Desktop\jaxformer-main> pip install deepspeed==0.7.0
Collecting deepspeed==0.7.0
Using cached deepspeed-0.7.0.tar.gz (629 kB)
Preparing metadata (setup.py) ... error
error: subprocess-exited-with-error

× python setup.py egg_info did not run successfully.
│ exit code: 1
╰─> [17 lines of output]
Traceback (most recent call last):
  File "<string>", line 2, in <module>
  File "<pip-setuptools-caller>", line 34, in <module>
  File "C:\Users\xxx\AppData\Local\Temp\pip-install-k2dvd7mz\deepspeed_6053b117392a43fd86418730c0157507\setup.py", line 166, in <module>
    ext_modules.append(builder.builder())
                       ^^^^^^^^^^^^^^^^^
  File "C:\Users\xxx\AppData\Local\Temp\pip-install-k2dvd7mz\deepspeed_6053b117392a43fd86418730c0157507\op_builder\builder.py", line 604, in builder
    assert_no_cuda_mismatch()
  File "C:\Users\xxx\AppData\Local\Temp\pip-install-k2dvd7mz\deepspeed_6053b117392a43fd86418730c0157507\op_builder\builder.py", line 88, in assert_no_cuda_mismatch
    cuda_major, cuda_minor = installed_cuda_version()
                             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\xxx\AppData\Local\Temp\pip-install-k2dvd7mz\deepspeed_6053b117392a43fd86418730c0157507\op_builder\builder.py", line 40, in installed_cuda_version
    assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
           ^^^^^^^^^^^^^^^^^^^^^
AssertionError: CUDA_HOME does not exist, unable to compile CUDA op(s)
[WARNING] Torch did not find cuda available, if cross-compiling or running with cpu only you can ignore this message. Adding compute capability for Pascal, Volta, and Turing (compute capabilities 6.0, 6.1, 6.2)
DS_BUILD_OPS=1
[end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

Cannot find 16B jax checkpoint

Would you be able to provide the 16B JAX checkpoint? We cannot find it under gs://sfr-codegen-research/checkpoints/.

Is the BigQuery dataset publicly available?

In the paper it was mentioned

The multi-lingual dataset BIGQUERY is a subset of Google’s publicly available BigQuery dataset,
which consists of code (under open-source license) in multiple programming languages. 

Just wondering, is this subset publicly available?

can't find paper

I can't find the paper referenced in the README:
@Article{Jaxformer,
  title={Jaxformer: A minimal library for training LLMs on TPU},
  author={Nijkamp, Erik},
  journal={arXiv},
  year={2022}
}
