
gemma_pytorch's Introduction

Gemma in PyTorch

Gemma is a family of lightweight, state-of-the-art open models built from research and technology used to create Google Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:

This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.

Updates

[April 9th] Support CodeGemma. You can find the checkpoints on Kaggle and Hugging Face.
[April 5th] Support Gemma v1.1. You can find the v1.1 checkpoints on Kaggle and Hugging Face.

Download Gemma model checkpoint

You can find the model checkpoints on Kaggle here.

Alternatively, you can find the model checkpoints on the Hugging Face Hub here. To download the models, go to the model repository of the model of interest, click the Files and versions tab, and download the model and tokenizer files. For programmatic downloading, if you have huggingface_hub installed, you can also run:

huggingface-cli download google/gemma-7b-it-pytorch
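
Or, from Python, using the huggingface_hub library (a small sketch; local_dir below is just an example destination):

from huggingface_hub import snapshot_download

# Download the model and tokenizer files for the chosen variant.
snapshot_download(
    repo_id="google/gemma-7b-it-pytorch",
    local_dir="/tmp/ckpt",  # example destination; any local directory works
)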

Note that you can choose among the 2B, 7B, and 7B int8 quantized variants.

VARIANT=<2b or 7b>
CKPT_PATH=<Insert ckpt path here>

Try it free on Colab

Follow the steps at https://ai.google.dev/gemma/docs/pytorch_gemma.

Try it out with PyTorch

Prerequisite: make sure you have set up Docker permissions properly as a non-root user.

sudo usermod -aG docker $USER
newgrp docker

Build the docker image.

DOCKER_URI=gemma:${USER}

docker build -f docker/Dockerfile ./ -t ${DOCKER_URI}

Run Gemma inference on CPU.

PROMPT="The meaning of life is"

docker run -t --rm \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run.py \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    --prompt="${PROMPT}"
    # add `--quant` for the int8 quantized model.

Run Gemma inference on GPU.

PROMPT="The meaning of life is"

docker run -t --rm \
    --gpus all \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run.py \
    --device=cuda \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    --prompt="${PROMPT}"
    # add `--quant` for the int8 quantized model.

Try it out with PyTorch/XLA

Build the docker image (CPU, TPU).

DOCKER_URI=gemma_xla:${USER}

docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}

Build the docker image (GPU).

DOCKER_URI=gemma_xla_gpu:${USER}

docker build -f docker/xla_gpu.Dockerfile ./ -t ${DOCKER_URI}

Run Gemma inference on CPU.

docker run -t --rm \
    --shm-size 4gb \
    -e PJRT_DEVICE=CPU \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run_xla.py \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    # add `--quant` for the int8 quantized model.

Run Gemma inference on TPU.

Note: be sure to use the docker container built from xla.Dockerfile.

docker run -t --rm \
    --shm-size 4gb \
    -e PJRT_DEVICE=TPU \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run_xla.py \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    # add `--quant` for the int8 quantized model.

Run Gemma inference on GPU.

Note: be sure to use the docker container built from xla_gpu.Dockerfile.

docker run -t --rm --privileged \
    --shm-size=16g --net=host --gpus all \
    -e USE_CUDA=1 \
    -e PJRT_DEVICE=CUDA \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run_xla.py \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    # add `--quant` for the int8 quantized model.

Tokenizer Notes

99 unused tokens are reserved in the pretrained tokenizer model to assist with more efficient training/fine-tuning. The unused tokens have the string form <unused[0-98]> and occupy the token id range [7-105].

"<unused0>": 7,
"<unused1>": 8,
"<unused2>": 9,
...
"<unused98>": 105,

Disclaimer

This is not an officially supported Google product.

gemma_pytorch's People

Contributors

danielhanchen, eltociear, joselpart, k-nar, mddct, michaelmoynihan, mon-ius, osanseviero, pengchongjin, qubitium, r-gheda

gemma_pytorch's Issues

Is it possible to load 7b-it using quantization config

Newbie here.
The 7b-it model can be loaded on a low-memory device via a quantization config, without using the quant version of the model, by passing a BitsAndBytes config to Hugging Face's AutoModelForCausalLM, like below:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "/kaggle/input/gemma/transformers/7b-it/2",
    device_map="auto",
    trust_remote_code=True,
    quantization_config=quantization_config,
)

Is this type of loading feasible with your current package?

I got empty result while using 7b-it model

I use WSL2 on Windows 11 to run gemma_pytorch. The device is an i9-13900 + RTX A6000, and I use the 7b-it model. But when I try to run Gemma inference, I always get an empty result. What might be the reason and how can I solve it?

/opt/conda/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Model loading done
======================================
PROMPT: The meaning of life is
RESULT:
======================================

Quantised weights are bfloat16 not int8

According to the Kaggle model card it is a 7B int8 quantized parameter base model.

I'm trying to run the model using the torch.uint8 dtype for the quantised version. However, the parameters are in torch.bfloat16 dtype. Is it possible to get int8 parameters, or to convert the bfloat16 ones to int8?
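
For reference, a generic per-tensor symmetric int8 quantization looks like the sketch below. This only illustrates the general idea; it is not necessarily the scheme used to produce the released quant checkpoint, so weights produced this way may not be loadable by scripts/run.py --quant.

import torch

def quantize_int8(weight: torch.Tensor):
    # Per-tensor symmetric quantization: scale so the max magnitude maps to 127.
    scale = weight.abs().max() / 127.0
    q = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(4, 4, dtype=torch.bfloat16)
q, scale = quantize_int8(w.float())
w_hat = q.float() * scale  # dequantized approximation of the original weight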

How to save memory when loading weights?

OS: Windows 11 22631.3296
Python: 3.11.8
PyTorch: 2.2.1 (installed in conda env)
CUDA: 12.1 (installed in conda env)
NV Driver: 551.76
Gemma Model: 7b-it

I was trying to run inference. Before I started, about 6GB of memory was in use and 26GB was free.

I observed that when the code reaches the load_weights function, memory usage goes up to 98% of my total 32GB RAM, stays there for about a minute, and then drops back to normal. During that time, I had not yet called the to(device) function on the next line.

From the Task Manager, during the period of high usage, I saw that python.exe took about a 28GB working set, while the active private working set was about 14GB. At that time, the Windows page file was involved to keep the system working.

[Task Manager screenshot]

However, the 7b-it model (16-bit float) should not exceed 16GB in size, so allocating 28GB of memory in this process seems pointless. As noted above, the memory usage eventually dropped back to normal without to(device) ever being called, which shows that this much memory isn't actually required.

Sorry, I don't know the details of how Python or PyTorch manage memory, but I'm wondering whether this line could be improved to smooth out the memory usage spike.
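
One possible mitigation (a sketch, not the repository's official recommendation) is to memory-map the checkpoint when loading it, so tensors are paged in from disk lazily rather than fully materialized in RAM; the path below is illustrative only.

import torch

# Illustrative path; point it at your downloaded checkpoint file.
ckpt_path = "/tmp/ckpt/gemma-7b-it.ckpt"

# mmap=True (PyTorch >= 2.1) maps the file lazily instead of copying it into RAM,
# and weights_only=True avoids running arbitrary pickled code.
checkpoint = torch.load(ckpt_path, mmap=True, map_location="cpu", weights_only=True)
state_dict = checkpoint["model_state_dict"]
# model.load_state_dict(state_dict, strict=False)
# model = model.to("cuda")  # move to the target device afterwards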

How to use gemma for multi-round conversations

Thanks a lot for your great work!
I deployed gemma-2b locally and would like to understand how to have multiple rounds of dialogue effectively.

I searched the internet and found that I can feed in previous conversation turns to get an answer for the next round, but I don't know exactly how this works inside Gemma. I'd appreciate any pointers or recommendations for existing tutorials.

I'm not a native English speaker and may have some grammatical problems. Thank you for your attention.
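
For the instruction-tuned checkpoints, multi-turn dialogue is typically handled by concatenating the previous turns with Gemma's chat markers and prompting the model for the next turn. Below is a minimal sketch assuming the documented <start_of_turn>/<end_of_turn> format; `model` is a loaded GemmaForCausalLM and the generate() arguments mirror scripts/run.py, so adjust them to your version.

USER_TEMPLATE = "<start_of_turn>user\n{msg}<end_of_turn>\n"
MODEL_TEMPLATE = "<start_of_turn>model\n{msg}<end_of_turn>\n"

history = []  # list of (role, text) pairs accumulated across turns

def chat(model, device, user_msg, output_len=128):
    # Rebuild the full conversation as a single prompt string.
    history.append(("user", user_msg))
    prompt = ""
    for role, text in history:
        template = USER_TEMPLATE if role == "user" else MODEL_TEMPLATE
        prompt += template.format(msg=text)
    prompt += "<start_of_turn>model\n"  # cue the model to produce the next reply
    reply = model.generate(prompt, device, output_len=output_len)
    history.append(("model", reply))
    return reply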

why some prompt doesn't work, the hidden_states will be nan after GemmaModel.forward

Question as in the title above: some prompts work, for example the default prompt "The meaning of life is", but the prompt below does not work.

"the self-attention is important for transformer because"

Some basic debug info is below:
===DEBUG===: after model -> hidden_states = tensor([[[-10.6953, 3.7734, 0.0226, ..., -0.6284, -1.8652, -1.2998]]],
device='cuda:0', dtype=torch.float16)
===DEBUG===: hidden_states = tensor([[[-21.8125, -1.1279, 4.8867, ..., -6.3945, 0.7524, 4.8867]]],
device='cuda:0', dtype=torch.float16) kv_write_indices= tensor([12], device='cuda:0')
===DEBUG===: after model -> hidden_states = tensor([[[-6.8164, 2.2676, 0.6655, ..., 1.5391, -2.5996, -2.0840]]],
device='cuda:0', dtype=torch.float16)
===DEBUG===: hidden_states = tensor([[[-18.0000, -0.4390, 5.7070, ..., -5.7070, 1.7559, 0.8779]]],
device='cuda:0', dtype=torch.float16) kv_write_indices= tensor([13], device='cuda:0')
===DEBUG===: after model -> hidden_states = tensor([[[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',
dtype=torch.float16)
Traceback (most recent call last):

After running for a while, the hidden_states become nan after GemmaModel.forward.

Unable to reproduce MATH results

Hi there, thanks for sharing Gemma. It seems I can't reproduce the MATH 24% 4-shot accuracy; I'm only getting 20%. Has anyone else tried to reproduce it? What's the prompt?

Is max_position_embeddings=8096 necessary in the 2b model?

I tried some small changes on the '2b' model:
1. Limit max_position_embeddings from 8096 to 256. :)
2. Trim the kv-cache in GemmaAttention to max_position_embeddings (256).
3. Remove the limit on the output length of model.generate.
Generation still works fine and can produce about 400 tokens for the question "The life meaning is".

Does this mean:
1. The oldest kv-cache entries are not necessary, and the model can store and compress long-context info into a 256-entry kv-cache (18 layers)?
2. It could be worth trying to train the model this way (with at most 256 kv-cache entries)?
3. If the above is true, can we decrease the training and generation complexity tremendously from O(L*L*D) to O(256*L*D) = O(L*D)?

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

I get the error RuntimeError: probability tensor contains either inf, nan or element < 0 when running inference with gemma-7b-quant.ckpt. The environment is as follows:

fairscale==0.4.13
filelock==3.9.0
fsspec==2023.4.0
gemma==0.1
immutabledict==4.1.0
Jinja2==3.1.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.24.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==8.7.0.84
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.19.3
nvidia-nvtx-cu11==11.8.86
pillow==10.2.0
sentencepiece==0.1.99
sympy==1.12
torch==2.2.1+cu118
torchaudio==2.2.1+cu118
torchvision==0.17.1+cu118
triton==2.2.0
typing_extensions==4.8.0

Cannot run on v4-16 worker 0 TPU VM: "Failed to get global TPU topology"

markusheimerl@t1v-n-a16d1e4e-w-0:~/gimli$ cd ~/gemma_cktp/ && curl -o archive.tar.gz "https://storage.googleapis.com/kaggle-models-data/5305/11357/bundle/archive.tar.gz?X-Goog-Algorithm=GOOG4-RSA-SHA256..." && tar -xf archive.tar.gz && cd ~/gimli
markusheimerl@t1v-n-a16d1e4e-w-0:~/gimli$ cd ../gemma_pytorch/
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ VARIANT=2b
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ CKPT_PATH=/home/markusheimerl/gemma_ckpt/
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ sudo usermod -aG docker $USER
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ newgrp docker
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ DOCKER_URI=gemma_xla:${USER}
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
[+] Building 109.0s (19/19) FINISHED                                                                                                          
 => [internal] load build definition from xla.Dockerfile                                                                                 0.0s
 => => transferring dockerfile: 1.36kB                                                                                                   0.0s
 => [internal] load .dockerignore                                                                                                        0.0s
 => => transferring context: 2B                                                                                                          0.0s
 => [internal] load metadata for us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128                   0.5s
 => [internal] load build context                                                                                                        0.1s
 => => transferring context: 6.49MB                                                                                                      0.1s
 => [ 1/14] FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128@sha256:5851322d5728f4b43f6f068f  45.8s
 => => resolve us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128@sha256:5851322d5728f4b43f6f068fa5c  0.0s
 => => sha256:577ff23cfe55ac8872bc433ce99971a34011e7a15f7c8afa3d6492c78d6d23e5 15.76MB / 15.76MB                                         0.5s
 => => sha256:5851322d5728f4b43f6f068fa5c69444db370f2cac8222183036666971f41846 4.12kB / 4.12kB                                           0.0s
 => => sha256:d1da99c2f14827498c4a9bb3623ae909b44564bdabad1802f064169069df81fb 55.06MB / 55.06MB                                         0.9s
 => => sha256:986e2cf4d9a25b7f49a2703932cad3cda01a6382bd6d38d902ad163bcc40af66 12.37kB / 12.37kB                                         0.0s
 => => sha256:c7b1e60e9d5a0f16eb1f998245666f7a64a44f8b1f2317bd31e8a658150c23d3 54.60MB / 54.60MB                                         1.3s
 => => sha256:beefab36cbfedf8896b5f9f0bc33336fa13c0f01a8cb2333128dd247895a5f3b 196.88MB / 196.88MB                                       3.3s
 => => extracting sha256:d1da99c2f14827498c4a9bb3623ae909b44564bdabad1802f064169069df81fb                                                1.1s
 => => sha256:de3224efe7269100000f1d5f451a8a6e5320b18160642c38bb97326907ecddea 6.29MB / 6.29MB                                           1.6s
 => => sha256:610099c6791eacee1e5a6d88d9ffc593485db3cc1e9bbdb9f7d112fa8fdd2725 17.54MB / 17.54MB                                         1.7s
 => => sha256:2c692cd1c1ae41798a701af5700224eea39e19736623388748e3a9a47782ba85 243B / 243B                                               1.8s
 => => sha256:67440d657e4fe10bca66a9fba87510371db3891e394766a7fdebaa8d19c6a062 2.85MB / 2.85MB                                           3.3s
 => => extracting sha256:577ff23cfe55ac8872bc433ce99971a34011e7a15f7c8afa3d6492c78d6d23e5                                                0.3s
 => => sha256:9e0ab466566e8a57436fd287a230ca562cc8aa0dd531504a2bae86f46c5d400a 130B / 130B                                               3.5s
 => => sha256:e47d3f9f3b24251c2ac7d5b04932ff23bf0099ec7c8676b6ebf86f061e1012fc 7.50kB / 7.50kB                                           3.5s
 => => sha256:9df685ee4175698291626f61c3afdd6618cdcd426b0273c4888b04607b972085 105.02MB / 105.02MB                                       4.9s
 => => sha256:3824c472a674d4e015ce79c8435392edc09924cca442f483835bbe9eae9ea52f 572.91MB / 572.91MB                                      11.1s
 => => sha256:03551a4901c735361e9e7e447828ecf77beeb2336f49fb788fd907cdb5fca972 153B / 153B                                               3.7s
 => => extracting sha256:c7b1e60e9d5a0f16eb1f998245666f7a64a44f8b1f2317bd31e8a658150c23d3                                                1.3s
 => => sha256:8ea82a97bc6ae43a7aed49d861ce05c3ed9757801016770a1101e784a5e5bc45 125.35MB / 125.35MB                                       5.9s
 => => sha256:d408b33f81ce05b78eed03e23d0081e7cdb3972c57c1103565f04f7332ed87fd 375.47MB / 375.47MB                                      15.7s
 => => sha256:3a387ede7ef122b7ad44078e16b8df873c87fa29cb1ef20e225b480be4769d34 201.39MB / 201.39MB                                      12.7s
 => => extracting sha256:beefab36cbfedf8896b5f9f0bc33336fa13c0f01a8cb2333128dd247895a5f3b                                                3.9s
 => => extracting sha256:de3224efe7269100000f1d5f451a8a6e5320b18160642c38bb97326907ecddea                                                0.2s
 => => extracting sha256:610099c6791eacee1e5a6d88d9ffc593485db3cc1e9bbdb9f7d112fa8fdd2725                                                0.4s
 => => extracting sha256:2c692cd1c1ae41798a701af5700224eea39e19736623388748e3a9a47782ba85                                                0.0s
 => => sha256:9407cf7758b440fb6f94e6159ac5e30436976a89995ea3f49bb22079ba9f206c 150B / 150B                                              15.9s
 => => sha256:7df537d35e3203cfb1a67c224e5b7f7769c6f47e7024705c26ce4a387402baad 653.48MB / 653.48MB                                      24.5s
 => => extracting sha256:67440d657e4fe10bca66a9fba87510371db3891e394766a7fdebaa8d19c6a062                                                0.2s
 => => extracting sha256:9e0ab466566e8a57436fd287a230ca562cc8aa0dd531504a2bae86f46c5d400a                                                0.0s
 => => extracting sha256:9df685ee4175698291626f61c3afdd6618cdcd426b0273c4888b04607b972085                                                4.9s
 => => extracting sha256:e47d3f9f3b24251c2ac7d5b04932ff23bf0099ec7c8676b6ebf86f061e1012fc                                                0.0s
 => => extracting sha256:3824c472a674d4e015ce79c8435392edc09924cca442f483835bbe9eae9ea52f                                                6.1s
 => => extracting sha256:03551a4901c735361e9e7e447828ecf77beeb2336f49fb788fd907cdb5fca972                                                0.0s
 => => extracting sha256:8ea82a97bc6ae43a7aed49d861ce05c3ed9757801016770a1101e784a5e5bc45                                                0.6s
 => => extracting sha256:3a387ede7ef122b7ad44078e16b8df873c87fa29cb1ef20e225b480be4769d34                                                0.9s
 => => extracting sha256:d408b33f81ce05b78eed03e23d0081e7cdb3972c57c1103565f04f7332ed87fd                                                7.6s
 => => extracting sha256:9407cf7758b440fb6f94e6159ac5e30436976a89995ea3f49bb22079ba9f206c                                                0.0s
 => => extracting sha256:7df537d35e3203cfb1a67c224e5b7f7769c6f47e7024705c26ce4a387402baad                                                3.1s
 => [ 2/14] RUN apt-get update                                                                                                          39.2s
 => [ 3/14] RUN apt-get install -y --no-install-recommends apt-utils                                                                     1.8s
 => [ 4/14] RUN apt-get install -y --no-install-recommends curl                                                                          2.6s
 => [ 5/14] RUN apt-get install -y --no-install-recommends wget                                                                          1.4s
 => [ 6/14] RUN apt-get install -y --no-install-recommends git                                                                           1.4s
 => [ 7/14] RUN python3 -m pip install --upgrade pip                                                                                     3.3s
 => [ 8/14] RUN pip install fairscale==0.4.13                                                                                            5.8s
 => [ 9/14] RUN pip install numpy==1.24.4                                                                                                1.4s
 => [10/14] RUN pip install immutabledict==4.1.0                                                                                         1.5s
 => [11/14] RUN pip install sentencepiece==0.1.99                                                                                        1.7s
 => [12/14] COPY . /workspace/gemma/                                                                                                     0.1s
 => [13/14] WORKDIR /workspace/gemma/                                                                                                    0.0s
 => [14/14] RUN pip install -e .                                                                                                         2.2s
 => exporting to image                                                                                                                   0.3s
 => => exporting layers                                                                                                                  0.3s
 => => writing image sha256:f256970b444877dc3e1bde548f82915bc2a8965542ef6e6e5c60c8f0497dfca1                                             0.0s
 => => naming to docker.io/library/gemma_xla:markusheimerl                                                                               0.0s
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ docker run -t --rm --shm-size 4gb -e PJRT_DEVICE=TPU -v ${CKPT_PATH}:/tmp/ckpt ${DOCKER_URI} python scripts/run_xla.py --ckpt=/tmp/ckpt --variant="${VARIANT}" --quant
usage: run_xla.py [-h] --ckpt CKPT [--variant {2b,7b}] [--output_len OUTPUT_LEN] [--seed SEED] [--quant] [--prompt PROMPT]
run_xla.py: error: argument --variant: invalid choice: '' (choose from '2b', '7b')
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ echo $VARIANT

markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ VARIANT=2b
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ CKPT_PATH=/home/markusheimerl/gemma_ckpt/
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ DOCKER_URI=gemma_xla:${USER}
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ docker run -t --rm --shm-size 4gb -e PJRT_DEVICE=TPU -v ${CKPT_PATH}:/tmp/ckpt ${DOCKER_URI} python scripts/run_xla.py --ckpt=/tmp/ckpt --variant="${VARIANT}" --quant
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/site-packages/torch_xla/__init__.py", line 142, in _prepare_to_exit
Error in atexit._run_exitfuncs:
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/site-packages/torch_xla/__init__.py", line 142, in _prepare_to_exit
  File "/usr/local/lib/python3.8/site-packages/torch_xla/__init__.py", line 142, in _prepare_to_exit
Error in atexit._run_exitfuncs:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/site-packages/torch_xla/__init__.py", line 142, in _prepare_to_exit
    _XLAC._prepare_to_exit()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:17 : Check failed: !was_initialized 
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        torch_xla::runtime::GetComputationClient()



        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyVectorcall_Call

        Py_FinalizeEx
        Py_Exit


        PyRun_SimpleStringFlags

        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
ComputationClient already initialized
    _XLAC._prepare_to_exit()
    _XLAC._prepare_to_exit()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:17 : Check failed: !was_initialized 
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        torch_xla::runtime::GetComputationClient()



        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyVectorcall_Call

        Py_FinalizeEx
        Py_Exit


        PyRun_SimpleStringFlags

        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
ComputationClient already initializedRuntimeError: torch_xla/csrc/runtime/runtime.cc:17 : Check failed: !was_initialized 
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        torch_xla::runtime::GetComputationClient()



        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyVectorcall_Call

        Py_FinalizeEx
        Py_Exit


        PyRun_SimpleStringFlags

        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
ComputationClient already initialized

    _XLAC._prepare_to_exit()
RuntimeError: torch_xla/csrc/runtime/runtime.cc:17 : Check failed: !was_initialized 
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        torch_xla::runtime::GetComputationClient()



        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyVectorcall_Call

        Py_FinalizeEx
        Py_Exit


        PyRun_SimpleStringFlags

        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
ComputationClient already initialized
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 58, in _run_thread_per_device
    initializer_fn(local_rank, local_world_size)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 117, in initialize_multiprocess
    devices = xm.get_xla_supported_devices()
  File "/usr/local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 99, in get_xla_supported_devices
    xla_devices = _DEVICES.value
  File "/usr/local/lib/python3.8/site-packages/torch_xla/utils/utils.py", line 29, in value
    self._value = self._gen_fn()
  File "/usr/local/lib/python3.8/site-packages/torch_xla/core/xla_model.py", line 20, in <lambda>
    _DEVICES = xu.LazyProperty(lambda: torch_xla._XLAC._xla_get_devices())
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "scripts/run_xla.py", line 259, in <module>
    main(args)
  File "scripts/run_xla.py", line 231, in main
    xmp.spawn(
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 200, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/runtime.py", line 82, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 160, in run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 161, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/local/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/local/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: INTERNAL: Failed to get global TPU topology.
markusheimerl@t1v-n-a16d1e4e-w-0:~/gemma_pytorch$ 

not found weight file

I built the image according to the Dockerfile, then ran the container. I get the error IsADirectoryError: [Errno 21] Is a directory: '/tmp/ckpt', which suggests there is no weight file in the directory '/tmp/ckpt'. Isn't a weight file supposed to be generated when building the image? How can I solve this problem?

Inconsistencies in Reported Dimensions and Configuration Files

In Table 1 of the Gemma Technical Report, the feedforward hidden dims are listed as 32768 and 49152 for the 2B and 7B models, respectively. However, these figures do not align with the numbers provided in the configuration files for the 7B and 2B models. This discrepancy makes me wonder whether I am comparing the wrong figures, whether there is an error in the report, or whether the experiments were conducted using different configuration files. Should the numbers in the technical report require revision, the reported total number of parameters would also need to be updated accordingly.

Error when running Gemma inference on GPU

When I run

docker run -t --rm \
    --gpus all \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run.py \
    --device=cuda \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    --prompt="${PROMPT}"

It returns the error:
docker: Error response from daemon: could not select device driver "" with capabilities: [[gpu]].

while if I run on CPU with command:

docker run -t --rm \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run.py \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    --prompt="${PROMPT}"

It works out OK.

--quant always returns True

Regardless of what value is passed in, the --quant argument will always evaluate to True, because Python casts any non-empty string to True, even if you set it to "false":

$ docker run -t --rm     --gpus all     -v ${CKPT_PATH}:/tmp/ckpt     ${DOCKER_URI}     python scripts/run.py     --device=cuda     --ckpt=/tmp/ckpt  --quant="false" --variant="${VARIANT}"  --prompt="${PROMPT}"                                                                                                                  
True

>>> bool("foo")
True
>>> bool("false")
True
>>> bool("0")
True

The solution is to have a portable strtobool utility, with a contract that accepts something like "0", "1", "true", "false". Bash is stringly typed and Python is duck typed, so this kind of issue comes up.
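
For example, a minimal sketch of such a flag (illustrative only, not the repository's actual scripts/run.py argument handling):

import argparse

def str2bool(value: str) -> bool:
    # Accept a small, explicit set of spellings instead of relying on bool(str).
    value = value.strip().lower()
    if value in ("1", "true", "yes"):
        return True
    if value in ("0", "false", "no"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")

parser = argparse.ArgumentParser()
# type=str2bool validates the string, so --quant=false really means False.
parser.add_argument("--quant", type=str2bool, nargs="?", const=True, default=False)
print(parser.parse_args(["--quant=false"]).quant)  # prints: False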

Are there reserved/unused tokens for developers?

Because a BPE vocabulary cannot be dynamically expanded after training, some BPE-tokenizer-based models such as Qwen reserve 2k extra unused tokens at the end of the vocabulary for developers to use as they see fit when fine-tuning.

Does Gemma have a list of internally unused tokens?

Sometimes model makers resize a vocab to a GPU-friendly multiple, which creates unused tokens, or they intentionally leave some unused tokens, as Qwen does.

Loss always becomes nan after a few fine-tuning steps, whether fp32 or fp16

The loss always becomes nan after fine-tuning for a few steps, whether in fp32 or fp16. Is this an instability issue or something else?

code:
https://github.com/yongzhuo/gemma-sft/blob/master/gemma_sft/ft_gemma/train.py

log:

{'loss': 5.6332, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.3049, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 153318.95, 'grad_norm': nan, 'learning_rate': 0.0002, 'epoch': 0.09}
{'loss': 5.7517, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.6231, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.0968, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.8938, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 6.1305, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.9105, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.4063, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.2371, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.4602, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.5166, 'lr': 0.0002, 'epoch': 0.09}
{'loss': 5.0093, 'lr': 0.0002, 'epoch': 0.09}
{'loss': nan, 'lr': 0.0002, 'epoch': 0.09}

Can't disable sampling

Currently, the generate() method doesn't seem to allow disabling sampling. The forward() method in the Sampler class performs greedy search if the temperatures argument is None, but GemmaForCausalLM's generate() method doesn't allow setting the temperature argument to None because of this line -> https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L508. Also, setting the temperature to 0 fails with the following error: RuntimeError: probability tensor contains either inf, nan or element < 0.
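
As a workaround (a sketch, not an official API), sampling can be made effectively greedy by setting top_k=1, since only the argmax token then has non-zero probability; a temperature still has to be supplied. The argument names below follow the repository's generate() signature but should be double-checked against your version; `model` and `device` are assumed to be set up already.

result = model.generate(
    "The meaning of life is",
    device,
    output_len=64,
    temperature=1.0,  # still required, but irrelevant once top_k=1
    top_p=1.0,
    top_k=1,          # keep only the most likely token -> effectively greedy
)
print(result)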

`torch.load` without `weights_only` parameter is unsafe

This is found via https://github.com/pytorch-labs/torchfix/

torch.load without weights_only parameter is unsafe. Explicitly set weights_only to False only if you trust the data you load and full pickle functionality is needed, otherwise set weights_only=True.

gemma/model.py:562:13

--- /home/sdym/repos/google/gemma_pytorch/gemma/model.py
+++ /home/sdym/repos/google/gemma_pytorch/gemma/model.py
@@ -557,9 +557,9 @@
         # If a string was provided as input, return a string as output.
         return results[0] if is_str_prompt else results
 
     def load_weights(self, model_path: str):
         self.load_state_dict(
-            torch.load(model_path, mmap=True)['model_state_dict'],
+            torch.load(model_path, mmap=True, weights_only=True)['model_state_dict'],
             strict=False,
         )

gemma/model_xla.py:517:22

--- /home/sdym/repos/google/gemma_pytorch/gemma/model_xla.py
+++ /home/sdym/repos/google/gemma_pytorch/gemma/model_xla.py
@@ -512,11 +512,11 @@
             top_ks=top_ks,
         )
         return next_tokens
 
     def load_weights(self, model_path: str):
-        checkpoint = torch.load(model_path)
+        checkpoint = torch.load(model_path, weights_only=True)
         model_state_dict = checkpoint['model_state_dict']
 
         num_attn_heads = self.config.num_attention_heads
         num_kv_heads = self.config.num_key_value_heads
         head_dim = self.config.head_dim

Output with higher max_length is repetition of base text

While generating text with a specified max_length, the generated text keeps repeating several times until the output spans the full max_length. An example of this uses the following code:

import keras_nlp

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
single_prompt_result = gemma_lm.generate("Keras is a", max_length=4096)
print(single_prompt_result)

As you can observe, the sentence keeps repeating to span the max_length, while it should ideally stop once it has produced the base text.

The code was run on Kaggle with the "gemma_2b_en" model on a P100 GPU. To recreate the issue, run the code above.
