S4 Torch

A PyTorch implementation of Structured State Space for Sequence Modeling (S4), based on the beautiful Annotated S4 blog post and JAX-based library by @srush and @siddk.

Installation

pip install git+https://github.com/TariqAHassan/s4torch

If you wish to perform development and/or use wavelet-based transforms, you will need to install the development requirements. This can be done with:

pip install -r dev_requirements.txt

Requires Python 3.9+.

Quick Start

The S4Model() provides a high-level implementation of the S4 model, as illustrated below.

import torch
from s4torch import S4Model

N = 32
d_input = 1
d_model = 128
n_classes = 10
n_blocks = 3
seq_len = 784

u = torch.randn(1, seq_len, d_input)

s4model = S4Model(
    d_input,
    d_model=d_model,
    d_output=n_classes,
    n_blocks=n_blocks,
    n=N,
    l_max=seq_len,
    collapse=True,  # average predictions over time prior to decoding
)
assert s4model(u).shape == (u.shape[0], n_classes)
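
To show what `collapse=True` does to shapes, here is a NumPy sketch of the data flow we assume S4Model follows. The assumed flow is a linear encoder from d_input to d_model, then a stack of sequence-to-sequence blocks, then a mean over the time axis, then a linear decoder to d_output. The S4 blocks are replaced by a hypothetical `stand_in_block`, and the weights are random. This is an illustration, not s4torch's code.

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(x, w, b):
    # Position-wise affine map applied to the last (feature) axis.
    return x @ w + b

def stand_in_block(x):
    # Hypothetical placeholder for an S4 block: (batch, L, d_model) -> (batch, L, d_model).
    return x + np.tanh(x)

def forward(u, d_model, d_output, n_blocks, collapse=True):
    d_input = u.shape[-1]
    w_enc, b_enc = rng.standard_normal((d_input, d_model)), np.zeros(d_model)
    w_dec, b_dec = rng.standard_normal((d_model, d_output)), np.zeros(d_output)
    x = linear(u, w_enc, b_enc)          # (batch, L, d_model)
    for _ in range(n_blocks):
        x = stand_in_block(x)            # shape unchanged
    if collapse:
        x = x.mean(axis=1)               # average over time: (batch, d_model)
    return linear(x, w_dec, b_dec)       # (batch, d_output), or (batch, L, d_output) if not collapsed

u = rng.standard_normal((1, 784, 1))
print(forward(u, 128, 10, 3).shape)                  # (1, 10)
print(forward(u, 128, 10, 3, collapse=False).shape)  # (1, 784, 10)
```

In this sketch, averaging before decoding means the decoder sees one d_model vector per sequence. That single vector is what a sequence-level classification head needs.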

Training

Models can be trained using the command line interface (CLI) provided by train.py.
CLI documentation can be obtained by running python train.py --help.

Notes:

  • development requirements must be installed prior to training. This can be accomplished by running pip install -r dev_requirements.txt.
  • average pooling after each S4 block is used in some of the training sessions described below, whereas the original S4 implementation only uses average pooling prior to decoding. The primary motivation for this additional pooling was to reduce memory usage; at least in the case of Sequential MNIST, it does not appear to reduce accuracy. These additional pooling layers can be disabled by setting --pooling=None, or by simply omitting the --pooling flag.
  • specifying --batch_size=-1 will result in the batch size being auto-scaled.
  • all experiments were performed on a machine with 8 CPU cores, 30 GB of RAM and a single NVIDIA® Tesla® V100 GPU with 16 GB of VRAM.
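
Per-block pooling saves memory by shrinking the sequence that later blocks have to process. Below is a minimal NumPy sketch. It assumes that --pooling=avg_2 means non-overlapping average pooling with a window (and stride) of 2 along the time axis after every block, and that an odd trailing step is dropped. s4torch's actual pooling layer may differ in these details.

```python
import numpy as np

def avg_pool_time(x, k=2):
    # x: (batch, L, d). Average non-overlapping windows of k consecutive time steps.
    batch, length, d = x.shape
    usable = (length // k) * k          # drop any trailing remainder (floor-mode pooling)
    return x[:, :usable].reshape(batch, length // k, k, d).mean(axis=2)

def lengths_after_blocks(seq_len, n_blocks, k=2):
    # Sequence length entering the first block, then after each pooled block.
    lengths = [seq_len]
    for _ in range(n_blocks):
        lengths.append(lengths[-1] // k)
    return lengths

x = np.arange(2 * 6 * 3, dtype=float).reshape(2, 6, 3)
print(avg_pool_time(x).shape)           # (2, 3, 3)
print(lengths_after_blocks(784, 6))     # [784, 392, 196, 98, 49, 24, 12]
```

Under this assumption, a six-block Sequential MNIST model processes 784 steps only in its first block. Each later block works on half as many steps as the one before.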

Sequential MNIST

python train.py \
  --dataset=smnist \
  --batch_size=16 \
  --max_epochs=100 \
  --lr=1e-2 \
  --n_blocks=6 \
  --d_model=128 \
  --norm_type=layer

Validation Accuracy: 98.6% after 5 epochs, 99.3% after 9 epochs (best)
Speed: ~10.5 batches/second

python train.py \
  --dataset=smnist \
  --batch_size=16 \
  --pooling=avg_2 \
  --max_epochs=100 \
  --lr=1e-2 \
  --n_blocks=6 \
  --d_model=128 \
  --norm_type=layer

Validation Accuracy: 98.4% after 5 epochs, 99.3% after 10 epochs (best)
Speed: ~11.5 batches/second
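
Sequential MNIST feeds each 28×28 image to the model one pixel at a time. This is where seq_len=784 and d_input=1 in the Quick Start come from. The sketch below shows that reshaping, with a random array standing in for a real MNIST image. Dataset loading and any normalization s4torch applies are not shown.

```python
import numpy as np

def image_to_sequence(image):
    # (28, 28) grayscale image -> (784, 1): raster-scan order, one pixel per time step.
    return image.reshape(-1, 1)

image = np.random.default_rng(0).random((28, 28))
seq = image_to_sequence(image)
print(seq.shape)  # (784, 1)
```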

Permuted MNIST

python train.py \
  --dataset=pmnist \
  --batch_size=16 \
  --pooling=avg_2 \
  --max_epochs=100 \
  --lr=1e-2 \
  --n_blocks=6 \
  --d_model=128 \
  --norm_type=layer

Validation Accuracy: 94.0% after 5 epochs, 96.2% after 18 epochs (best)
Speed: ~11.5 batches/second
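
Permuted MNIST uses the same 784-step sequences but shuffles the pixel order with a single fixed permutation shared by every image. This removes the local 2-D structure that a raster scan preserves. In the sketch below, the seed is arbitrary and is not the one s4torch uses.

```python
import numpy as np

rng = np.random.default_rng(42)           # arbitrary seed, for illustration only
PERMUTATION = rng.permutation(28 * 28)    # drawn once, then reused for every image

def image_to_permuted_sequence(image, perm=PERMUTATION):
    # Flatten in raster order, then reorder the time steps with the fixed permutation.
    return image.reshape(-1)[perm].reshape(-1, 1)

seq = image_to_permuted_sequence(np.arange(784).reshape(28, 28))
print(seq.shape)  # (784, 1)
```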

Sequential CIFAR10

python train.py \
  --dataset=scifar10 \
  --batch_size=32 \
  --max_epochs=200 \
  --lr=1e-2 \
  --n_blocks=6 \
  --pooling=avg_2 \
  --d_model=1024 \
  --weight_decay=0.01 \
  --p_dropout=0.25 \
  --patience=20

Validation Accuracy: 75.0% after 8 epochs, 79.3% after 15 epochs (best)
Speed: ~1.6 batches/second
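
Sequential CIFAR10 is the colour analogue of Sequential MNIST. Each 32×32 RGB image becomes a sequence of 1,024 pixels with three channels per step, so the model would see d_input=3. The sketch below assumes channels-first (3, 32, 32) input, as produced by torchvision's ToTensor. The transform s4torch actually applies may differ.

```python
import numpy as np

def cifar_to_sequence(image_chw):
    # (3, 32, 32) -> (32, 32, 3) -> (1024, 3): raster-scan pixels, RGB values as features.
    return np.transpose(image_chw, (1, 2, 0)).reshape(-1, 3)

image = np.random.default_rng(0).random((3, 32, 32))
seq = cifar_to_sequence(image)
print(seq.shape)  # (1024, 3)
```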

Speech Commands

python train.py \
  --dataset=speech_commands10 \
  --batch_size=-1 \
  --max_epochs=150 \
  --lr=1e-2 \
  --n_blocks=6 \
  --pooling=avg_2 \
  --d_model=128 \
  --weight_decay=0.0 \
  --norm_type=batch \
  --norm_strategy=post \
  --p_dropout=0.1 \
  --patience=10

Validation Accuracy: 93.2% after 5 epochs, 95.8% after 13 epochs (best)
Speed: ~2.1 batches/second

Notes:

  • the speech_commands10 dataset uses a subset of 10 speech commands, as in the original S4 implementation. To train on all speech commands, use the speech_commands dataset instead.
  • Batch normalization appears to work best with a "post" normalization strategy, whereas a "pre" normalization strategy appears to work best with layer normalization.
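
The NumPy sketch below illustrates the two normalization strategies using their usual definitions. A "pre" strategy normalizes the sublayer's input, and a "post" strategy normalizes the residual sum. For brevity the sketch uses layer normalization; batch normalization would instead compute statistics per feature across the batch and time axes. `sublayer` is a hypothetical stand-in. We assume, without having confirmed it, that s4torch's --norm_strategy follows these definitions.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each time step over the feature axis (no learned scale/shift here).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sublayer(x):
    # Hypothetical stand-in for the S4 layer, activation and linear parts of a block.
    return np.tanh(x)

def pre_norm_block(x, norm=layer_norm):
    # "pre": normalize the sublayer's input; the residual path stays unnormalized.
    return x + sublayer(norm(x))

def post_norm_block(x, norm=layer_norm):
    # "post": add the residual first, then normalize the sum.
    return norm(x + sublayer(x))

rng = np.random.default_rng(0)
x = 3.0 * rng.standard_normal((2, 5, 8)) + 1.0   # (batch, L, d_model)
pre, post = pre_norm_block(x), post_norm_block(x)
print(pre.shape == post.shape == x.shape)        # True
```

The practical difference is that post-norm outputs are always normalized. With pre-norm, the residual stream passes between blocks without ever being normalized.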

Raw Waveform

python train.py \
  --dataset=nsynth_short \
  --batch_size=-1 \
  --val_prop=0.01 \
  --max_epochs=150 \
  --limit_train_batches=0.025 \
  --lr=1e-2 \
  --n_blocks=4 \
  --pooling=avg_2 \
  --d_model=128 \
  --weight_decay=0.0 \
  --norm_type=batch \
  --norm_strategy=post \
  --p_dropout=0.1 \
  --precision=16 \
  --accumulate_grad=4 \
  --patience=10

Validation Accuracy: 39.6% after 5 epochs, 54.1% after 17 epochs (best)
Speed: ~1.6 batches/second

Notes:

  • The model is tasked with classifying waveforms by the musical instrument that generated them (10 classes).
  • The nsynth_short dataset contains waveforms truncated after two seconds, whereas the nsynth dataset contains the full four-second waveforms.
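
NSynth audio is sampled at 16 kHz, so even the truncated clips are very long sequences compared with Sequential MNIST's 784 steps. The sketch below shows the truncation, assuming no resampling. A silent placeholder array stands in for a real note, and the names are illustrative rather than s4torch's.

```python
import numpy as np

SAMPLE_RATE = 16_000  # NSynth sample rate in Hz

def truncate(waveform, seconds, sample_rate=SAMPLE_RATE):
    # Keep only the first `seconds` of audio.
    return waveform[: int(seconds * sample_rate)]

note = np.zeros(4 * SAMPLE_RATE)   # placeholder for a full four-second note
print(len(note))                   # 64000
print(len(truncate(note, 2)))      # 32000
```

If --pooling=avg_2 halves the sequence after each of the four blocks (an assumption), the 32,000-step input would be down to 2,000 steps at the end of the stack.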

Continuous Wavelet Transform (|CWT(x)|)

python train.py \
  --dataset=nsynth_short \
  --batch_size=-1 \
  --val_prop=0.01 \
  --max_epochs=150 \
  --limit_train_batches=0.025 \
  --lr=1e-2 \
  --n_blocks=6 \
  --pooling=avg_2 \
  --d_model=100 \
  --weight_decay=0.0 \
  --norm_type=batch \
  --norm_strategy=post \
  --p_dropout=0.1 \
  --precision=16 \
  --accumulate_grad=1 \
  --wavelet_tform=True \
  --patience=10

Validation Accuracy: 52.7% after 5 epochs, 69.4% after 72 epochs (best)
Speed: ~1.3 batches/second

Notes:

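  • This experiment uses the magnitude of the CWT (with a Morlet wavelet) as the input representation. Compared with the raw waveform experiment above, this raises the best validation accuracy substantially, from 54.1% to 69.4% (roughly 15 percentage points), although the two runs also differ in n_blocks, d_model and accumulate_grad.
  • This requires pycwt, which is installed with the development requirements (see Installation above).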
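
For intuition about the |CWT(x)| representation, here is a self-contained NumPy sketch of a Morlet CWT magnitude. It does not use pycwt. The scale grid, ω0 = 6 and the normalization are illustrative choices, not the settings s4torch uses. Each scale becomes one input feature, so a waveform of length T becomes a (T, n_scales) sequence.

```python
import numpy as np

def morlet(t, w0=6.0):
    # Complex Morlet wavelet (the small admissibility correction term is omitted).
    return np.pi ** -0.25 * np.exp(1j * w0 * t) * np.exp(-0.5 * t ** 2)

def cwt_magnitude(x, scales, w0=6.0):
    # Returns |CWT(x)| with shape (len(scales), len(x)); each kernel must be shorter than x.
    out = np.empty((len(scales), len(x)))
    for i, s in enumerate(scales):
        t = np.arange(-4 * s, 4 * s + 1) / s   # wavelet support of about +/- 4 (in units of s)
        kernel = morlet(t, w0) / np.sqrt(s)
        out[i] = np.abs(np.convolve(x, kernel, mode="same"))
    return out

sample_rate = 16_000
n = np.arange(4096)
x = np.sin(2 * np.pi * 500 * n / sample_rate)   # 500 Hz tone: period of 32 samples
scales = np.geomspace(4, 64, 12)
mag = cwt_magnitude(x, scales)
features = mag.T                                # (time, n_scales): one feature per scale
print(features.shape)                           # (4096, 12)

# Away from the edges, the most energetic scale should lie near w0 * 32 / (2 * pi), about 30.6.
best = scales[np.argmax(mag[:, 1000:3000].mean(axis=1))]
```

A pure tone concentrates its energy in a narrow band of scales. Taking the magnitude discards the wavelet's oscillating phase, which leaves a smooth, non-negative time-frequency input for the model.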

Components

Layer

The S4Layer() implements the core logic of S4.

import torch
from s4torch.layer import S4Layer

N = 32
d_model = 128
seq_len = 784

u = torch.randn(1, seq_len, d_model)

s4layer = S4Layer(d_model, n=N, l_max=seq_len)
assert s4layer(u).shape == u.shape
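
The core idea behind S4, as described in the Annotated S4 post, is that a discretized linear state space model can be run either as a recurrence or as a single causal convolution with kernel K = (C B̄, C Ā B̄, ..., C Ā^(L-1) B̄). The NumPy sketch below checks that the two views agree. It uses the HiPPO-LegS matrix and bilinear discretization from that post, omits the D (skip) term, and computes the kernel naively. It is not the efficient structured (DPLR) kernel computation that S4 relies on, and it is not s4torch's implementation.

```python
import numpy as np

def make_hippo(n):
    # HiPPO-LegS matrix, as constructed in the Annotated S4 post.
    p = np.sqrt(1 + 2 * np.arange(n))
    a = p[:, None] * p[None, :]
    a = np.tril(a) - np.diag(np.arange(n))
    return -a

def discretize(a, b, step):
    # Bilinear discretization of x'(t) = A x(t) + B u(t) with step size `step`.
    eye = np.eye(a.shape[0])
    bl = np.linalg.inv(eye - (step / 2.0) * a)
    a_bar = bl @ (eye + (step / 2.0) * a)
    b_bar = (bl * step) @ b
    return a_bar, b_bar

def ssm_kernel(a_bar, b_bar, c, length):
    # K_k = C A_bar^k B_bar for k = 0 .. length-1, by repeated multiplication.
    kernel, x = [], b_bar
    for _ in range(length):
        kernel.append((c @ x).item())
        x = a_bar @ x
    return np.array(kernel)

def run_recurrent(a_bar, b_bar, c, u):
    # x_k = A_bar x_{k-1} + B_bar u_k,  y_k = C x_k
    x = np.zeros((a_bar.shape[0], 1))
    ys = []
    for u_k in u:
        x = a_bar @ x + b_bar * u_k
        ys.append((c @ x).item())
    return np.array(ys)

def run_convolution(kernel, u):
    # Causal convolution: y_k = sum_{j=0..k} K_j u_{k-j}
    return np.convolve(u, kernel)[: len(u)]

rng = np.random.default_rng(0)
n, length = 8, 64
a = make_hippo(n)
b = np.sqrt(1 + 2 * np.arange(n))[:, None]
c = rng.standard_normal((1, n))
a_bar, b_bar = discretize(a, b, step=1.0 / length)
u = rng.standard_normal(length)

y_rec = run_recurrent(a_bar, b_bar, c, u)
y_conv = run_convolution(ssm_kernel(a_bar, b_bar, c, length), u)
print(np.allclose(y_rec, y_conv))  # True
```

The convolutional view is what makes training parallel over the sequence. The recurrent view gives constant-cost steps at inference time.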

Block

The S4Block() embeds S4Layer() in a commonplace processing "pipeline", with an activation function, dropout, linear layer, skip connection and layer normalization. (S4Model(), above, is composed of these blocks.)

import torch
from s4torch.block import S4Block

N = 32
d_model = 128
seq_len = 784

u = torch.randn(1, seq_len, d_model)

s4block = S4Block(d_model, n=N, l_max=seq_len)
assert s4block(u).shape == u.shape

References

The S4 model was developed by Albert Gu, Karan Goel, and Christopher Ré. If you find the S4 model useful, please cite their impressive paper:

@misc{gu2021efficiently,
    title={Efficiently Modeling Long Sequences with Structured State Spaces}, 
    author={Gu, Albert and Goel, Karan and R{\'e}, Christopher},
    year={2021},
    eprint={2111.00396},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Also consider checking out their fantastic repository at github.com/HazyResearch/state-spaces.
