Git Product home page Git Product logo

passt's Introduction

PaSST: Efficient Training of Audio Transformers with Patchout

This is the implementation for Efficient Training of Audio Transformers with Patchout

Patchout significantly reduces the training time and GPU memory requirements to train transformers on audio spectrograms, while improving their performance.

Patchout works by dropping out some of the input patches during training. In either a unstructured way (randomly, similar to dropout), or entire time-frames or frequency bins of the extracted patches (similar to SpecAugment), which corresponds to rows/columns in step 3 of the figure below.

PaSST architecture

Inference or Embeddings pre-trained models

If you only want to use the embeddings generated by the pretrained models, use your own fine-tuning framework, or you need it only for inference, you can find a stripped down version of this repo here. The package follows HEAR 2021 NeurIPS Challenge API, and can be installed:

pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=hear21passt' 

This repo is a complete framework for training the models and fine-tuning pre-trained models on Audioset on downstream tasks.

Getting the logits from the pretrained models

from hear21passt.base import get_basic_model,get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.
print(model.net) # the transformer network.

# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
    # audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k
    # example audio_wave of batch=3 and 10 seconds
    audio = torch.ones((3, 32000 * 10))*0.5
    audio_wave = audio.cuda()
    logits=model(audio_wave) 

Getting a pre-trained model for fine tuning

from hear21passt.base import get_basic_model,get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.

# optional replace the transformer with one that has the required number of classes i.e. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476",  n_classes=50)
print(model.net) # the transformer network.


# now model contains mel + the transformer pre-trained model ready to be fine tuned.
# It's still expecting input of the shape [batch, seconds*32000] sampling rate is 32k

model.train()
model = model.cuda()

Setting up the experiments environment

This repo uses forked versions of sacred for configuration and logging, and pytorch-lightning for training.

For setting up Mamba is recommended and faster then conda:

conda install mamba -n base -c conda-forge

Now you can import the environment from environment.yml

mamba env create -f environment.yml

Now you have an environment named ba3l. Now install the forked versions of sacred and pl-lightning and ba3l.

# dependencies
conda activate ba3l
pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=ba3l'
pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=pytorch-lightning'
pip install -e 'git+https://github.com/kkoutini/[email protected]#egg=sacred' 

In order to check the environment we used in our runs, please check the environment.yml and pip_list.txt files. Which were exported using:

conda env export --no-builds | grep -v "prefix" > environment.yml
pip list > pip_list.txt

Getting started

Each dataset has an experiment file such as ex_audioset.py and ex_openmic.py and a dataset folder with a readme file. In general, you can prob the experiment file for help:

python ex_audioset.py help

you can override any of the configuration using the sacred syntax. In order to see the available options either use omniboard or use:

 python ex_audioset.py print_config

There are many pre-defined configuration options in config_updates.py. These include different models, setups etc... You can list these configurations with:

python ex_audioset.py print_named_configs

The overall configurations looks like this:

  ...
  seed = 542198583                  # the random seed for this experiment
  slurm_job_id = ''
  speed_test_batch_size = 100
  swa = True
  swa_epoch_start = 50
  swa_freq = 5
  use_mixup = True
  warm_up_len = 5
  weight_decay = 0.0001
  basedataset:
    base_dir = 'audioset_hdf5s/'     # base directory of the dataset, change it or make a link
    eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
    wavmix = 1
    ....
    roll_conf:
      axis = 1
      shift = None
      shift_range = 50
  datasets:
    test:
      batch_size = 20
      dataset = {CMD!}'/basedataset.get_test_set'
      num_workers = 16
      validate = True
    training:
      batch_size = 12
      dataset = {CMD!}'/basedataset.get_full_training_set'
      num_workers = 16
      sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
      shuffle = None
      train = True
  models:
    mel:
      freqm = 48
      timem = 192
      hopsize = 320
      htk = False
      n_fft = 1024
      n_mels = 128
      norm = 1
      sr = 32000
      ...
    net:
      arch = 'passt_s_swa_p16_128_ap476'
      fstride = 10
      in_channels = 1
      input_fdim = 128
      input_tdim = 998
      n_classes = 527
      s_patchout_f = 4
      s_patchout_t = 40
      tstride = 10
      u_patchout = 0
      ...
  trainer:
    accelerator = None
    accumulate_grad_batches = 1
    amp_backend = 'native'
    amp_level = 'O2'
    auto_lr_find = False
    auto_scale_batch_size = False
    ...

There are many things that can be updated from the command line. In short:

  • All the configuration options under trainer are pytorch lightning trainer api. For example, to turn off cuda benchmarking add trainer.benchmark=False to the command line.
  • models.net are the PaSST (or the chosen NN) options.
  • models.mel are the preprocessing options (mel spectrograms).

Training on Audioset

Download and prepare the dataset as explained in the audioset page The base PaSST model can be trained for example like this:

python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"

For example using only unstructured patchout of 400:

python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384  models.net.u_patchout=400  models.net.s_patchout_f=0 models.net.s_patchout_t=0 -p -m mongodb_server:27000:audioset21_balanced -c "Unstructured PaSST base"

Multi-gpu training can be enabled by setting the environment variable DDP, for example with 2 gpus:

 DDP=2 python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"

Pre-trained models

Please check the releases page, to download pre-trained models. In general, you can get a pretrained model on Audioset using

from models.passt import get_model
model  = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1,
                   fstride=10, tstride=10,input_fdim=128, input_tdim=998,
                   u_patchout=0, s_patchout_t=40, s_patchout_f=4)

this will get automatically download pretrained PaSST on audioset with with mAP of 0.476. the model was trained with s_patchout_t=40, s_patchout_f=4 but you can change these to better fit your task/ computational needs.

There are several pretrained models availble with different strides (overlap) and with/without using SWA: passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470. For example, In passt_s_swa_p16_s16_128_ap473: p16 mean patch size is 16x16, s16 means no overlap (stride=16), 128 mel bands, ap473 refers to the performance of this model on Audioset mAP=0.479.

In general, you can get a this pretrained model using:

from models.passt import get_model
passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)

Using the framework, you can evaluate this model using:

python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473

Ensemble of these models are provided as well: A large ensemble giving mAP=.4956

python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_many

An ensemble of 2 models with stride=14 and stride=16 giving mAP=.4858

python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_s16_14

As well as other ensembles ensemble_4, ensemble_5

Contact

The repo will be updated, in the mean time if you have any questions or problems feel free to open an issue on GitHub, or contact the authors directly.

passt's People

Contributors

faroit avatar kkoutini avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.