
MDS5210-23fall

Note:

⚠️

  • Since the code was run on a single-GPU computation node on a cluster, some parts of it can be skipped when running on your own server or on Kaggle
  • Some model paths should be modified to match your local environment

Process:

Base Tasks

4.1(A)

  • Run prepare_sft_dataset.py to generate two files: sft_train.json and sft_test.json
  • Train gpt2-medium with the settings below (a minimal launch sketch follows the note after this table)

    | Setting | Value |
    | --- | --- |
    | Model | gpt2-medium |
    | Training iterations | 160000 |
    | Batch size | 8 |
    | Optimizer | AdamW (with weight decay) |
    | Test interval | every 200 steps |
    | Test set size | 240 |
    | Other hyper-parameters | Here |
    | Training record | Here |
    | Train error | Here |
    | Test error | Here |

โ— Train 20000 (Train iteration / batch size) steps, and test in every 200 steps

4.1(B)

  • Run eval.py to evaluate the performance of vanilla gpt2-medium and SFT gpt2-medium; detailed results: Here

🚀 eval.py (modified from evaluate.py) uses the reward model OpenAssistant/reward-model-deberta-v3-large-v2 to evaluate performance instead of using an OpenAI API key
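A minimal sketch of how that reward model can score a (prompt, response) pair with the transformers library; the repo's eval.py may batch and aggregate scores differently, and the prompt/response strings below are made-up examples.

```python
# Sketch of scoring a (prompt, response) pair with the reward model; the
# repo's eval.py may batch and aggregate scores differently.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "OpenAssistant/reward-model-deberta-v3-large-v2"
tokenizer = AutoTokenizer.from_pretrained(name)
reward_model = AutoModelForSequenceClassification.from_pretrained(name).eval()

prompt = "Explain what supervised fine-tuning is."          # made-up example
response = "Supervised fine-tuning trains a pretrained LM on labeled prompt-response pairs."

with torch.no_grad():
    inputs = tokenizer(prompt, response, return_tensors="pt")
    score = reward_model(**inputs).logits[0].item()         # higher score = better response
print(score)
```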

4.1(C)

  • Insights: Summarize what you find based on the results or settings from 4.1(A) and 4.1(B).

Explorations

4.2(C)

  • Run train_sft.py. The LoRA code has already been integrated by the TA; to add LoRA to gpt2-medium, change cfg = get_configs("gpt2-medium") to cfg = get_configs("gpt2-medium/lora") in train_sft.py (see the LoRA sketch at the end of this subsection)
  • Comparison

    | Setting | Figure Link |
    | --- | --- |
    | gpt2-medium train error | Here |
    | gpt2-medium test error | Here |
    | gpt2-medium/lora train error | Here |
    | gpt2-medium/lora test error | Here |

โ— gpt2-medium and gpt2-medium/lora are trained based on same hyper-parameter settings, optimizer: AdamW (weight decay)

| LoRA Rank | Figure Link | Dialogue Quality |
| --- | --- | --- |
| Full parameters | Here | $0.85$ |
| LoRA rank = 1 | Here | $0.51$ |
| LoRA rank = 10 | Here | $0.45$ |
| LoRA rank = 100 | Here | $0.49$ |

โ— Different lora rank training are based on same hyper-parameter settings, optimizer: AdamW (weight decay)

4.2(A)

  • Run train_sft.py. Only the optimizer needs to be switched from AdamW to the alternatives (already provided in the fit() function in trainers.py); then compare the different optimizers (a selection sketch follows the notes below)
  • Comparison

    | Optimizer | Figure Link | GPU Memory |
    | --- | --- | --- |
    | SGD | Here | $1663992832$ bytes (≈1.55 GiB) |
    | SGD with momentum (momentum=0.9) | Here | $1877104128$ bytes (≈1.75 GiB) |
    | SGD with Nesterov (momentum=0.9) | Here | $1877104128$ bytes (≈1.75 GiB) |
    | AdamW ($\beta_1=0.9$, $\beta_2=0.95$) | Here | $2090215424$ bytes (≈1.95 GiB) |

โ— Models with different optimizers are trained with the same weight decay and hyper-parameter settings

โ— 4.2(A) experiments are conducted based on gpt2-medium/lora with rank $1$ (save time). In the report this should be specified explitcitly

