
Kangaroo

 Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting

Paper: arXiv:2404.18911



Drawing inspiration from early exiting, we propose Kangaroo, a novel self-speculative decoding framework that uses a fixed shallow sub-network as the self-draft model, with the remaining layers serving as the larger target model. We train a lightweight and efficient adapter module on top of the sub-network to bridge the representation gap between the sub-network and the full model. The adapter consists of only one multi-head attention layer and two normalization layers; surprisingly, this simple design proves both efficient and powerful. To further reduce the inference latency of the self-draft model, we introduce a second early-exiting mechanism during draft-token generation, which avoids spending unnecessary computation on more difficult tokens.
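
As a rough illustration (not the authors' exact code), the adapter described above, one multi-head attention layer framed by two normalization layers, might be sketched in PyTorch as follows. All names, dimensions, and the residual connection are assumptions for illustration:

```python
import torch
import torch.nn as nn

class AdapterSketch(nn.Module):
    """Hypothetical sketch of the Kangaroo adapter: one multi-head
    attention block plus two LayerNorms, bridging the shallow
    sub-network's features toward the full model's representation.
    Dimensions and structure are illustrative, not from the repo."""

    def __init__(self, hidden_size=4096, num_heads=32):
        super().__init__()
        self.ln_in = nn.LayerNorm(hidden_size)    # first normalization layer
        self.attn = nn.MultiheadAttention(hidden_size, num_heads,
                                          batch_first=True)
        self.ln_out = nn.LayerNorm(hidden_size)   # second normalization layer

    def forward(self, hidden_states):
        # Self-attention over the sub-network's hidden states.
        x = self.ln_in(hidden_states)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        # Residual connection is an assumption of this sketch.
        return self.ln_out(hidden_states + attn_out)
```

Because the adapter preserves the hidden size, its output can be fed straight into the target model's remaining layers and LM head.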

TODO List

  • inference code & checkpoints of Kangaroo.
  • code for training Kangaroo.
  • tree verification.
  • bsz > 1 and decoding with sampling.
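
For context on the verification mentioned above: speculative decoding stays lossless because every drafted token is checked against the target model. A minimal pure-Python sketch of sequential greedy verification (tree verification generalizes this), where `target_next` is a hypothetical callable returning the target model's greedy next token:

```python
def speculative_step(target_next, context, drafted):
    """Accept drafted tokens only while they match what the target
    model would have produced greedily; then emit one token from the
    target model itself, so the output is identical to vanilla
    greedy decoding (losslessness)."""
    accepted = []
    for tok in drafted:
        expected = target_next(context + accepted)
        if tok != expected:
            break  # first mismatch: discard the rest of the draft
        accepted.append(tok)
    # The target model always contributes (at least) one token.
    accepted.append(target_next(context + accepted))
    return accepted

# Toy usage with a fake "model" whose next token is the context length:
target_next = lambda ctx: len(ctx)
print(speculative_step(target_next, [0, 1], [2, 3, 99]))  # → [2, 3, 4]
```

In a real implementation the target model scores all drafted positions in a single forward pass; the loop above only mirrors the acceptance logic.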

Training

We follow the training procedures of Medusa and EAGLE.

  1. Data preprocessing:
cd data
python allocation.py --outdir /home/ma-user/work/Data/
  2. Training:
python start_train.py

Inference

## Vicuna-7B as an example

## Vanilla decoding
CUDA_VISIBLE_DEVICES=0 python -m evaluation.inference_baseline --model-path "/cache/CKPT/vicuna-7b-v1.3" --model-id "vicuna-7b-v1.3-vanilla-float16-temp-0.0" --bench-name "Kangaroo" --temperature 0.0 --dtype "float16"

## Kangaroo
CUDA_VISIBLE_DEVICES=0 python -m evaluation.inference_kangaroo --adapter-path "/cache/CKPT/kangaroo-vicuna-7b-v1.3" --exitlayer 2 --model-path "/cache/CKPT/vicuna-7b-v1.3" --threshold 0.6 --steps 6 --model-id "vicuna-7b-v1.3-kangaroo-thres-0.6-steps-6-float16" --bench-name "Kangaroo" --dtype "float16"
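
Conceptually, the second early exit (controlled by the --threshold and --steps flags above) stops the self-draft model once its token confidence drops below a threshold or a maximum number of draft steps is reached. A minimal pure-Python sketch, where `draft_next` is a hypothetical stand-in for the adapter-augmented sub-network returning (token, confidence):

```python
def draft_tokens(draft_next, context, threshold=0.6, max_steps=6):
    """Draft tokens until confidence falls below `threshold` or
    `max_steps` tokens have been drafted (mirrors --threshold/--steps).
    Defaults match the command line above."""
    drafted = []
    for _ in range(max_steps):
        token, conf = draft_next(context + drafted)
        if conf < threshold:
            break  # early exit: token too uncertain, defer to the big model
        drafted.append(token)
    return drafted

# Toy usage with a fake draft model whose confidence decays each step:
def fake_draft(ctx):
    return len(ctx), 1.0 - 0.15 * len(ctx)

print(draft_tokens(fake_draft, [0, 1]))  # → [2]
```

Skipping draft steps on low-confidence (i.e. difficult) tokens is what keeps the drafting cost low without affecting correctness, since rejected positions fall back to the target model anyway.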

To get detailed speed statistics, run python evaluation/speed.py.

The corresponding Hugging Face checkpoints of Kangaroo can be downloaded from the Kangaroo Google Drive link.

Citation

@article{liu2024kangaroo,
  title={Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting},
  author={Liu, Fangcheng and Tang, Yehui and Liu, Zhenhua and Ni, Yunsheng and Han, Kai and Wang, Yunhe},
  journal={arXiv preprint arXiv:2404.18911},
  year={2024}
}

Acknowledgements

We acknowledge the authors of Medusa and EAGLE, whose work this project builds upon.

License

License: MIT


kangaroo's Issues

In line 263 of train.py, predict = model(inputs_embeds=data["hidden_states_early"]

Hello author! Thank you for your excellent work!

In line 263 of train.py, predict = model(inputs_embeds=data["hidden_states_early"], attention_mask=data["attention_mask"]) raises the error "model has no attribute inputs_embeds", because the class AdapterModel in adapter.py does not define a forward(), only forward_early_stop().
Could you add a forward()?

Training procedure of Kangaroo.

Hi Equationliu and all authors,
Thanks for your brilliant work in proposing this promising "Kangaroo" method.
I am curious about the training details of Kangaroo. Could you please provide the training code?
Thanks in advance.

a question

Hi author, during the training phase, does saving the intermediate hidden states in ckpt format require a large amount of disk space? On average, a single training sample takes about 10 MB, so a complete training dataset may occupy several terabytes. If the device's storage is insufficient, do you have any suggestions? Thank you very much!

Encountering NaN output at a specific batch ID every run, and no change observed upon adjusting the learning rate

Detailed Description

I downloaded this code repository from GitHub and attempted to train my model with it. Unfortunately, the loss becomes NaN at the same batch ID on every run, regardless of how I change the input data or the initialization state.

Additionally, I tried adjusting the learning rate to address this, but curiously there was no observable change: the loss remained exactly the same. I have confirmed that the learning-rate changes are correctly applied in the code, but the problem persists.

Request for Help

I would like to understand other possible causes of this issue and any recommended debugging strategies or solutions. If other developers have encountered similar problems and found fixes, I would greatly appreciate it if you could share them.

Thank you for your time and assistance!

Kangaroo when bsz is greater than 1.

Hello, I would like to ask how Kangaroo works when the batch size (bsz) is greater than 1, and which parts of the code need to be modified. Thank you!
