
rho-loss

This is the code for the paper "Prioritized Training on Points that are Learnable, Worth Learning, and Not Yet Learned".

The code uses PyTorch Lightning, Hydra for config file management, and Weights & Biases for logging. The codebase is adapted from this great template.

Installing dependencies

Conda: conda env create -f my_environment.yaml

Poetry: poetry install

The repository also contains a singularity container definition file that can be built and used to run the experiments. See the singularity folder.

Tutorial

tutorial.ipynb contains the full training pipeline (irreducible loss model training and target model training) on CIFAR-10. This is the best place to start if you want to understand the code or reproduce our results.

Codebase

The codebase contains the functionality for all the experiments in the paper (and more 😜).

Irreducible loss model training

Start with run_irreducible.py (which then calls src/train_irreducible.py). The base config file is configs/irreducible_training.yaml.
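
For example, a minimal launch might look like this (a sketch: datamodule.data_dir is the same override used by the importance sampling command further down and may not be needed if your data is in the default location; any other value in configs/irreducible_training.yaml can be overridden with Hydra's key=value syntax):

python run_irreducible.py datamodule.data_dir=$DATA_DIR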

Target model training

Start with run.py (which then calls src/train.py). The base config file is configs/config.yaml. A key file is src/models/MultiModels.py: this is the LightningModule that handles the training loop, including batch selection.
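
Once the irreducible loss model is trained, a target-model run can be launched the same way (again a sketch; the config needs to point at the irreducible-loss checkpoint produced in the previous step, see configs/config.yaml for the relevant entry):

python run.py datamodule.data_dir=$DATA_DIR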

More about the code

The datamodules are implemented in src/datamodules/datamodules.py, the individual datasets in src/datamodules/dataset/sequence_datasets. If you want to add your own dataset, note that __getitem__() needs to return the tuple (index, input, target), where index is the index of the datapoint with respect to the overall dataset (this is required so that we can match the irreducible losses to the correct datapoints).
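
For illustration, here is a minimal sketch of a dataset with the required return signature (the class name and tensors are hypothetical, not part of this repo):

from torch.utils.data import Dataset

class MyIndexedDataset(Dataset):
    """Hypothetical example: returns (index, input, target) tuples."""

    def __init__(self, inputs, targets):
        self.inputs = inputs    # e.g. a tensor of shape (N, ...)
        self.targets = targets  # e.g. a tensor of shape (N,)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        # 'index' is the datapoint's position in the overall dataset, so the
        # precomputed irreducible losses can later be matched to the right points.
        return index, self.inputs[index], self.targets[index]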

All the selection methods mentioned in the paper (and more) are implemented in src/curricula/selection_methods.py.
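
For intuition, here is a simplified sketch of reducible-loss selection (illustrative only; the function and argument names are made up, and the actual implementation in src/curricula/selection_methods.py handles more cases and has a different interface):

import torch
import torch.nn.functional as F

def select_top_reducible_loss(logits, targets, irreducible_losses, num_selected):
    # Per-point loss of the model currently being trained, on the candidate batch.
    model_loss = F.cross_entropy(logits, targets, reduction="none")  # shape (B,)
    # Reducible loss = current training loss minus the precomputed irreducible
    # (holdout) loss for the same datapoints.
    reducible_loss = model_loss - irreducible_losses                 # shape (B,)
    # Keep the points with the highest reducible loss for the gradient step.
    return torch.topk(reducible_loss, num_selected).indices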

ALBERT fine-tuning

All ALBERT experiments are implemented in a separate branch, which is a bit less clean. Good luck :-)

Reproducibility

This repo can be used to reproduce all the experiments in the paper. Check out configs/experiment for some example experiment configs. The experiment files for the main results are listed below, with example launch commands after the list:

  • CIFAR-10: cifar10_resnet18_irred.yaml and cifar10_resnet18_main.yaml
  • CINIC-10: cinic10_resnet18_irred.yaml and cinic10_resnet18_main.yaml
  • CIFAR-100: cifar100_resnet18_irred.yaml and cifar100_resnet18_main.yaml
  • Clothing-1M: c1m_resnet18_irred.yaml and c1m_resnet50_main.yaml
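
For example, the CIFAR-10 result could be reproduced with something like the following two runs (a sketch; overrides such as the data directory are assumptions, and the main run's config must point at the checkpoint saved by the irreducible run, see the corresponding experiment yaml for the exact key):

python run_irreducible.py +experiment=cifar10_resnet18_irred.yaml datamodule.data_dir=$DATA_DIR
python run.py +experiment=cifar10_resnet18_main.yaml datamodule.data_dir=$DATA_DIR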

NLP datasets, on a separate branch:

  • CoLA:
    • Irreducible loss model training: python run_irreducible_nlp.py +experiment=nlp trainer.max_epochs=10 callbacks=val_loss datamodule.task_name=cola trainer.val_check_interval=0.05
    • Target model training: python run_nlp.py +experiment=nlp datamodule.task_name=cola trainer.max_epochs=100 irreducible_loss_generator.f="path/to/file" selection_method_nlp=reducible_loss_selection
  • SST2:
    • Irreducible loss model training: python run_irreducible_nlp.py +experiment=nlp trainer.max_epochs=10 callbacks=val_loss datamodule.task_name=sst2 trainer.val_check_interval=0.05
    • Target model training: python run_nlp.py +experiment=nlp trainer.max_epochs=15 datamodule.task_name=sst2 +trainer.val_check_interval=0.2 irreducible_loss_generator.f="path/to/file" selection_method_nlp=reducible_loss_selection

Importance sampling baseline

To run the importance sampling baseline on CINIC-10:

python3 run_simple.py datamodule.data_dir=$DATA_DIR +experiment=importance_sampling_baseline.yaml 
