
MADA: Meta-Adaptive Optimizers through hyper-gradient Descent

Authors: Kaan Ozkara, Can Karakus, Parameswaran Raman, Mingyi Hong, Shoham Sabach, Branislav Kveton, Volkan Cevher

This repository contains the code to run the experiments in our paper [MADA: Meta-Adaptive Optimizers through hyper-gradient Descent](https://arxiv.org/abs/2401.08893). The GPT training code is based on nanoGPT by Andrej Karpathy (https://github.com/karpathy/nanoGPT). The meta-optimizer implementation is inspired by gradient-descent-the-ultimate-optimizer (https://github.com/kach/gradient-descent-the-ultimate-optimizer).

./config contains the configuration files that control the parameters used in the code.

./results contains some of the results referenced in the project's internal Quip document.

./gdtuo.py implements the meta-optimizer via hypergradient descent (see the sketch after this file list).

./model.py contains a generic GPT-2 style model implementation from nanoGPT.

The ./plot*.py scripts plot the results stored in ./results.

train.py, train_ddp.py, toy.py, and toy2.py are the scripts for running experiments.

train_ddp.py is the most recent training script and has from-scratch support for DDP and gradient accumulation.
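For readers unfamiliar with hypergradient descent, the sketch below shows the classic variant that adapts a single learning rate by differentiating the loss with respect to it (Baydin et al., 2018). It is illustrative only and assumes PyTorch; it is not the gdtuo.py API, which additionally differentiates through the optimizer's other hyperparameters.

```python
# Minimal sketch of hypergradient descent on SGD's learning rate
# (illustrative; not the gdtuo.py API). Since theta_t = theta_{t-1}
# - alpha * g_{t-1}, the hypergradient of the loss w.r.t. alpha is
# -g_t . g_{t-1}, so alpha is nudged along the inner product of
# consecutive gradients.
import torch

def hypergradient_sgd(params, loss_fn, alpha=0.01, beta=1e-4, steps=100):
    prev_grads = None
    for _ in range(steps):
        grads = torch.autograd.grad(loss_fn(params), params)
        if prev_grads is not None:
            # dL/dalpha = -g_t . g_{t-1}  =>  alpha += beta * g_t . g_{t-1}
            h = sum((g * pg).sum() for g, pg in zip(grads, prev_grads))
            alpha = alpha + beta * h.item()
        with torch.no_grad():
            for p, g in zip(params, grads):
                p.sub_(alpha * g)
        prev_grads = [g.detach() for g in grads]
    return params
```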

Example run:

python train_ddp.py config/train_gpt2_small.py --dtype='float32' --beta1=0.9 --beta2=0.95 --beta3=0.0 --rho=0.6 --c=1.0 --gamma=1.0

The arguments above set the initial values of the optimizer parameters. Additional nanoGPT variables (e.g., logging and gradient-accumulation settings) can be passed the same way if needed. At the moment, changing the hypergradient hyperparameters (such as the hypergradient learning rate) or the DDP world size requires editing the code. The output directory for log files is set to an FSx path and also needs to be changed inside the code. There are two types of logging: during training, the optimizer parameters, training loss, and validation loss are logged every log_iter iterations; in addition, a summary is logged at the end of the run.
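As a rough intuition for what these coefficients do, MADA parameterizes a space of adaptive optimizers and learns the interpolation coefficients by hypergradient descent. The sketch below shows one such interpolation, with rho mixing an Adam-style and a Yogi-style second-moment update; it is a simplified illustration under our reading of the paper, not the exact update rule in train_ddp.py.

```python
# Illustrative single-tensor step of an interpolated adaptive optimizer
# (simplified; not the exact MADA update). rho in [0, 1] mixes the Adam
# and Yogi second-moment estimators; in the real code the interpolation
# coefficients are themselves updated by hypergradient descent.
import torch

def interpolated_step(p, g, m, v, lr=3e-4, beta1=0.9, beta2=0.95,
                      rho=0.6, eps=1e-8):
    m.mul_(beta1).add_(g, alpha=1 - beta1)                     # first moment
    v_adam = beta2 * v + (1 - beta2) * g * g                   # Adam: EMA of g^2
    v_yogi = v - (1 - beta2) * torch.sign(v - g * g) * g * g   # Yogi update
    v.copy_(rho * v_adam + (1 - rho) * v_yogi)                 # interpolate
    p.sub_(lr * m / (v.sqrt() + eps))
    return p, m, v
```

The remaining coefficients (beta3, c, gamma) control other interpolations in the parameterized optimizer family; see the paper for the full update rule.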

Citation

Please consider citing our paper if you use our code:

@misc{ozkara2024mada,
      title={MADA: Meta-Adaptive Optimizers through hyper-gradient Descent}, 
      author={Kaan Ozkara and Can Karakus and Parameswaran Raman and Mingyi Hong and Shoham Sabach and Branislav Kveton and Volkan Cevher},
      year={2024},
      eprint={2401.08893},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Security

See CONTRIBUTING for more information.

License

This project is licensed under the Apache-2.0 License.
