
Gemma-EasyLM

This document outlines the integration of the Gemma model into the EasyLM framework, including instructions for training, converting the model format, and serving the model with Gradio.

Training: Integrating HF Flax Weights into EasyLM

Step 1: Consolidate Flax Weights from Hugging Face

You can skip this step by downloading https://huggingface.co/beomi/gemma-ko-7b/resolve/flax-init/flax_model.msgpack instead.

First, consolidate the sharded Flax model weights available at Hugging Face - Gemma 7B (google/gemma-7b).

Use the following example code to accomplish this:

from transformers import FlaxGemmaForCausalLM

# Load the sharded Flax weights from the Hub and re-save them as a single shard.
model = FlaxGemmaForCausalLM.from_pretrained("google/gemma-7b")
model.save_pretrained("./flax-concatted", max_shard_size="99GB")

This script generates a flax-concatted/flax_model.msgpack file. We will utilize this .msgpack file during the training process.
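
If you want to verify the consolidation, the single shard can be loaded back with Flax's serialization utilities. A minimal sanity-check sketch (the path follows the step above):

from flax import serialization

# Restore the raw parameter pytree from the consolidated shard.
with open("./flax-concatted/flax_model.msgpack", "rb") as f:
    params = serialization.msgpack_restore(f.read())

# The top-level keys should be the Gemma module's parameter groups.
print(list(params.keys()))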

Step 2: Upload the .msgpack File to Google Cloud Storage (GCS)

Execute the following command to upload the generated .msgpack file to your GCS repository:

gsutil cp ./flax-concatted/flax_model.msgpack gs://YOUR_GCS_REPO_NAME

Step 3: Modify the train.sh Script

Adjust the paths for load_checkpoint, train_dataset.json_dataset.path, and logger.output_dir within the train.sh script to match your setup.

The provided example train.sh script assumes training will be conducted on a TPUv4-64 pod slice.
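
For reference, the three flags inside train.sh look roughly like the lines below. This is a sketch: the values are placeholders, and the flax_params:: prefix assumes EasyLM's usual checkpoint-type syntax for loading raw Flax .msgpack weights; check your copy of train.sh for the exact spelling.

--load_checkpoint='flax_params::gs://YOUR_GCS_REPO_NAME/flax_model.msgpack'
--train_dataset.json_dataset.path='gs://YOUR_GCS_REPO_NAME/train_data.jsonl'
--logger.output_dir='gs://YOUR_GCS_REPO_NAME/checkpoints'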

Step 4: Initiate Training

Execute the training script to start the training process:

./train.sh

Conversion: From EasyLM to Hugging Face Format

Step 1: Retrieve the streaming_train_state File

Download the streaming_train_state file from your GCS repository using the following command:

gsutil cp gs://YOUR_GCS_REPO_NAME/.../streaming_train_state_80000 .

Note: The file name will either be streaming_train_state or streaming_train_state_STEPNO.

Step 2: Update the .stream File Path

In convert_easylm_stream_to_hf_safetensors.py, update the path so it points at the streaming checkpoint you just downloaded:

# Modify this line to point at your downloaded checkpoint
_, param = StreamingCheckpointer.load_trainstate_checkpoint(
    load_from='trainstate_params::/home/latheledusjp/streaming_train_state_80000'
)
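
For reference, StreamingCheckpointer lives in EasyLM's checkpoint module, so the script is expected to import it along the lines of from EasyLM.checkpoint import StreamingCheckpointer (assuming this fork keeps upstream EasyLM's module layout).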

Step 3: Execute the Conversion Script

Run the conversion script to convert the EasyLM model format to Hugging Face's format:

python convert_easylm_stream_to_hf_safetensors.py

Step 4: Verify the Output Files

Check the generated output files in the ./gemma-ko-8.5b-dev directory.

A Flax version of the weights is written to the ./flax-gemma-ko-8b folder.
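
As a quick check that the conversion succeeded, the safetensors output should load directly with transformers. A minimal sketch, assuming the tokenizer files are present in the output directory (copy them from the base model if they are not):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the converted checkpoint and tokenizer from the output directory.
model = AutoModelForCausalLM.from_pretrained("./gemma-ko-8.5b-dev")
tokenizer = AutoTokenizer.from_pretrained("./gemma-ko-8.5b-dev")

# Generate a few tokens as a smoke test.
inputs = tokenizer("안녕하세요", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))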

Serving the Model with Gradio

To serve the model using Gradio, follow these steps:

cd EasyLM/models/gemma
pip install -r serving_requirements.txt
./serve_test.sh

Original EasyLM Reference

If you found EasyLM useful in your research or applications, please cite using the following BibTeX:

@software{geng2023easylm,
  author = {Geng, Xinyang},
  title = {EasyLM: A Simple And Scalable Training Framework for Large Language Models},
  month = March,
  year = 2023,
  url = {https://github.com/young-geng/EasyLM}
}

Credits

  • The LLaMA implementation is from JAX_llama
  • The JAX/Flax GPT-J and RoBERTa implementations are from transformers
  • Most of the JAX utilities are from mlxu
  • The codebase is heavily inspired by JAXSeq


Known Issues

Gemma Approx GELU Issue

Hello Junbum, thank you for sharing these great models. I was trying to fine-tune the 2B and 7B models you shared...!
I am fine-tuning with a library called EasyDeL and ran into a GeLU-related issue, so I am leaving a note here.

The earlier Gemma implementation in transformers used the exact GeLU, but this turned out to be a bug: because of bfloat16 error, the approximated GeLU (gelu_pytorch_tanh) should actually be used. See huggingface/transformers#29402.

Looking at the code in this repo, however, it seems the models you trained and shared were tuned with the exact GeLU.
In that case, recent transformers versions automatically switch to gelu_pytorch_tanh whenever hidden_activation=gelu is missing from config.json (https://github.com/huggingface/transformers/blob/3b8e2932ce743008f63585aae1e1b8b30dc8b3ac/src/transformers/models/gemma/modeling_gemma.py#L175),
so it seems this field should be added to the config.json of the two shared models.

Also, if the continual pretraining was done with the exact GeLU, I wonder whether the numerical discrepancy might cause problems...
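
To pin the activation explicitly, the field can be set through transformers. A minimal sketch, assuming you pick the value that matches how the checkpoint was actually trained (the output directory name here is just an example):

from transformers import GemmaConfig

# Load the shared model's config and set the activation explicitly.
config = GemmaConfig.from_pretrained("beomi/gemma-ko-7b")

# "gelu" preserves the exact GeLU reportedly used for this repo's training;
# upstream Gemma is meant to use "gelu_pytorch_tanh".
config.hidden_activation = "gelu"

config.save_pretrained("./gemma-ko-7b-config-fixed")  # example output path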
