
Residual Shuffle-Exchange Networks: Official TensorFlow Implementation

This repository contains the official TensorFlow implementation of the following paper:

Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences

by Andis Draguns, Emīls Ozoliņš, Agris Šostaks, Matīss Apinis, Kārlis Freivalds

[AAAI] [arXiv] [BibTeX]

Abstract: Attention is a commonly used mechanism in sequence processing, but it is of O(n²) complexity which prevents its application to long sequences. The recently introduced neural Shuffle-Exchange network offers a computation-efficient alternative, enabling the modelling of long-range dependencies in O(n log n) time. The model, however, is quite complex, involving a sophisticated gating mechanism derived from the Gated Recurrent Unit.

In this paper, we present a simple and lightweight variant of the Shuffle-Exchange network, which is based on a residual network employing GELU and Layer Normalization. The proposed architecture not only scales to longer sequences but also converges faster and provides better accuracy. It surpasses the Shuffle-Exchange network on the LAMBADA language modelling task and achieves state-of-the-art performance on the MusicNet dataset for music transcription while being efficient in the number of parameters.

We show how to combine the improved Shuffle-Exchange network with convolutional layers, establishing it as a useful building block in long sequence processing applications.

Introduction

Residual Shuffle-Exchange networks are a simpler and faster replacement for the recently proposed Neural Shuffle-Exchange network architecture. They have O(n log n) complexity and can process sequences of up to 2 million symbols, lengths at which standard methods such as attention mechanisms fail. The Residual Shuffle-Exchange network can serve as a useful building block for long sequence processing applications.
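The calculation below makes the gap between O(n log n) and O(n²) concrete by comparing the two growth terms for n = 2^21 (about 2.1 million symbols). It ignores constant factors, so it only illustrates asymptotic scaling. It says nothing about actual runtimes of either model. The function name growth_terms is illustrative.

```python
import math

def growth_terms(n):
    """Return (n * log2(n), n**2, their ratio) for a sequence of length n."""
    n_log_n = n * math.log2(n)
    n_squared = n ** 2
    return n_log_n, n_squared, n_squared / n_log_n

n = 2 ** 21  # 2,097,152 symbols, roughly 2 million
n_log_n, n_squared, ratio = growth_terms(n)
print(f"n log2 n = {n_log_n:,.0f}")  # n log2 n = 44,040,192
print(f"n^2      = {n_squared:,}")   # n^2      = 4,398,046,511,104
print(f"ratio    = {ratio:,.0f}")    # ratio    = 99,864
```

At this length the quadratic term is roughly 100,000 times larger. That ratio is simply n / log2 n.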


Preview of results

Our paper describes Residual Shuffle-Exchange networks in detail and provides full results on long binary addition, long binary multiplication, sorting tasks, the LAMBADA question answering task and multi-instrument musical note recognition using the MusicNet dataset.

Here are the results on the MusicNet transcription task, which consists of identifying the musical notes performed in audio waveforms (freely licensed classical music recordings):

Model                                Learnable parameters (M)   Average precision score (%)
cgRNN                                2.36                       53.0
Deep Real Network                    10.0                       69.8
Deep Complex Network                 8.8                        72.9
Complex Transformer                  11.61                      74.22
Translation-invariant net            unknown                    77.3
Residual Shuffle-Exchange network    3.06                       78.02

Note: Our model achieves state-of-the-art performance while remaining efficient in the number of parameters. It works directly on the audio waveform. The previous state-of-the-art models, by contrast, used specialised architectures operating on complex-number representations of the Fourier-transformed waveform.

Here are the accuracy results on the LAMBADA question answering task of predicting a target word in its broader context (on average 4.6 sentences picked from novels):

Model                                Learnable parameters (M)   Test accuracy (%)
Random word from passage             -                          1.6
Gated-Attention Reader               unknown                    49.0
Neural Shuffle-Exchange network      33                         52.28
Residual Shuffle-Exchange network    11                         54.34
Universal Transformer                152                        56.0
Human performance                    -                          86.0
GPT-3                                175000                     86.4

Note: Our model runs faster than the Neural Shuffle-Exchange network. With the same amount of GPU memory, it can be evaluated on sequences 4 times longer than that model can handle and 128 times longer than the Universal Transformer can.

What are Residual Shuffle-Exchange networks?

Residual Shuffle-Exchange networks are a lightweight variant of Shuffle-Exchange networks. Shuffle-Exchange networks are continuous, differentiable neural networks with a regular layered structure of alternating Switch and Shuffle layers.

The Switch Layer divides the input into adjacent pairs of values and applies a Residual Switch Unit to each pair. The Residual Switch Unit is a learnable 2-to-2 function built with GELU and Layer Normalization. It produces two outputs from its two inputs.

Figure: a Residual Switch Unit, which replaces the Switch Unit of Shuffle-Exchange networks.
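The following is a minimal NumPy sketch of a Residual Switch Unit and a Switch Layer. It reflects our reading of the description and makes several assumptions:

- The two inputs are concatenated, layer-normalized, and passed through a linear layer, GELU, and a second linear layer.
- The result is added back to the input through a sigmoid-gated skip connection.
- The hidden size is twice the concatenated width.
- The skip gate is initialized to σ(s) = 0.9.
- The residual branch is scaled by 0.25.

All of these are illustrative assumptions, so check the paper and model code for the exact formulation. ResidualSwitchUnit and switch_layer are hypothetical names. NumPy stands in for TensorFlow 1.14 so the sketch runs standalone.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-5):
    # Normalize over the last axis (no learnable gain/bias in this sketch)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

class ResidualSwitchUnit:
    """Hypothetical 2-to-2 unit: LayerNorm -> Linear -> GELU -> Linear, plus a gated skip."""

    def __init__(self, m, hidden_factor=2, residual_weight=0.9, branch_scale=0.25, seed=0):
        rng = np.random.default_rng(seed)
        d = 2 * m                  # width of the two concatenated inputs
        h = hidden_factor * d      # assumed hidden width
        self.w1 = rng.normal(0.0, 1.0 / np.sqrt(d), size=(d, h))
        self.b1 = np.zeros(h)
        self.w2 = rng.normal(0.0, 1.0 / np.sqrt(h), size=(h, d))
        self.b2 = np.zeros(d)
        # Per-feature skip gate s, initialized so that sigmoid(s) == residual_weight
        self.s = np.full(d, np.log(residual_weight / (1.0 - residual_weight)))
        self.branch_scale = branch_scale

    def __call__(self, a, b):
        i = np.concatenate([a, b], axis=-1)
        hidden = gelu(layer_norm(i) @ self.w1 + self.b1)
        gate = 1.0 / (1.0 + np.exp(-self.s))
        out = gate * i + self.branch_scale * (hidden @ self.w2 + self.b2)
        m = a.shape[-1]
        return out[..., :m], out[..., m:]  # split back into two outputs

def switch_layer(x, unit):
    """Apply the unit to adjacent pairs of rows: x has shape (n, m) with n even."""
    if x.shape[0] % 2:
        raise ValueError("number of inputs must be even")
    pairs = x.reshape(-1, 2, x.shape[-1])          # pairs[p] = (x[2p], x[2p + 1])
    o1, o2 = unit(pairs[:, 0, :], pairs[:, 1, :])
    return np.stack([o1, o2], axis=1).reshape(x.shape)

x = np.random.default_rng(1).normal(size=(8, 4))   # 8 inputs, 4 features each
y = switch_layer(x, ResidualSwitchUnit(m=4))
print(y.shape)  # (8, 4)
```

Setting branch_scale=0 reduces the unit to the gated identity 0.9·x. This gives a quick sanity check that pairs are split and reassembled in the right order.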

Each Switch Layer is followed by a Shuffle Layer, which permutes its inputs according to a perfect-shuffle permutation. This is the way a deck of cards is shuffled by splitting it into halves and interleaving them. In terms of element indices, the permutation is a cyclic left rotation of the index bits in the first part of the network. In the second part it is the inverse, a cyclic right rotation.
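The bit-rotation view of the perfect shuffle can be checked with a few lines of Python. The sketch assumes the convention that the element at index i moves to position rotl(i) for sequences whose length is a power of two. The helper names are illustrative, not taken from the repository.

```python
def rotate_left(i, k):
    """Cyclically rotate the k-bit index i one position to the left."""
    return ((i << 1) & ((1 << k) - 1)) | (i >> (k - 1))

def rotate_right(i, k):
    """Cyclically rotate the k-bit index i one position to the right."""
    return (i >> 1) | ((i & 1) << (k - 1))

def _permute(xs, index_map):
    n = len(xs)
    k = n.bit_length() - 1
    if n < 2 or n != 1 << k:
        raise ValueError("length must be a power of two and at least 2")
    out = [None] * n
    for i, x in enumerate(xs):
        out[index_map(i, k)] = x  # element at position i moves to index_map(i, k)
    return out

def shuffle(xs):
    """Perfect shuffle: split into halves and interleave them."""
    return _permute(xs, rotate_left)

def inverse_shuffle(xs):
    """Inverse perfect shuffle (cyclic right rotation of index bits)."""
    return _permute(xs, rotate_right)

cards = list(range(8))
print(shuffle(cards))                   # [0, 4, 1, 5, 2, 6, 3, 7]
print(inverse_shuffle(shuffle(cards)))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

The output [0, 4, 1, 5, ...] is exactly the two halves [0, 1, 2, 3] and [4, 5, 6, 7] interleaved, as with the deck of cards.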

The Residual Shuffle-Exchange network is organized in blocks by alternating these two kinds of layers in the pattern of the Beneš network. Such a network can represent a wide class of functions including any permutation of the input values.

Figure: a complete Residual Shuffle-Exchange network model consisting of two blocks with 8 inputs.
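The following toy sketch shows how the layers could be arranged. It assumes the layer schedule described for the original Neural Shuffle-Exchange network; check the model code for the exact arrangement used here. Under that schedule, for n = 2^k inputs:

- Each Beneš block has k - 1 Switch+Shuffle layers followed by k - 1 Switch+Inverse-Shuffle layers.
- One final Switch layer follows the last block.

The switch is a plain Python callable standing in for the learnable Residual Switch Unit. shuffle_exchange is a hypothetical name.

```python
def shuffle(xs):
    """Perfect shuffle: interleave the two halves (cyclic left rotation of index bits)."""
    half = len(xs) // 2
    return [x for pair in zip(xs[:half], xs[half:]) for x in pair]

def inverse_shuffle(xs):
    """Undo the perfect shuffle: even positions first, then odd positions."""
    return xs[0::2] + xs[1::2]

def switch_layer(xs, switch):
    """Apply a 2-to-2 switch to each adjacent pair of values."""
    out = []
    for a, b in zip(xs[0::2], xs[1::2]):
        out.extend(switch(a, b))
    return out

def shuffle_exchange(xs, n_blocks, switch):
    """Assumed schedule: per Beneš block, k-1 Switch+Shuffle layers, then
    k-1 Switch+InverseShuffle layers; a final Switch layer after the last block."""
    n = len(xs)
    k = n.bit_length() - 1
    if n < 2 or n != 1 << k:
        raise ValueError("input length must be a power of two and at least 2")
    xs = list(xs)
    for _ in range(n_blocks):
        for _ in range(k - 1):
            xs = shuffle(switch_layer(xs, switch))
        for _ in range(k - 1):
            xs = inverse_shuffle(switch_layer(xs, switch))
    return switch_layer(xs, switch)

calls = []

def identity_switch(a, b):
    calls.append((a, b))  # count how often the switch is applied
    return a, b

print(shuffle_exchange(list("abcdefgh"), n_blocks=2, switch=identity_switch))
# ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
print(len(calls))  # 36 = 2 blocks * 4 layers * 4 pairs + 4 in the final layer
```

With an identity switch, the shuffles and inverse shuffles cancel out, so the input comes back unchanged. Counting calls gives n/2 switch applications in each of the 2(k - 1) layers per block, plus n/2 in the final layer. That is O(n log n) work for a fixed number of blocks.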

Running the experiments

Running the experiments requires installing the dependencies and meeting the following system requirements.

System requirements

  • Python 3.6 or higher.
  • TensorFlow 1.14.0.
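Below is a small, hypothetical check script for these requirements. meets_requirements and installed_tensorflow_version are illustrative names, not part of the repository. As far as we know, TensorFlow 1.14.0 was only published for Python versions up to 3.7. In practice, that probably limits the usable range to Python 3.6 or 3.7.

```python
import sys

def meets_requirements(python_version, tf_version):
    """Check the stated requirements: Python 3.6 or higher and TensorFlow 1.14.0."""
    return tuple(python_version[:2]) >= (3, 6) and tf_version == "1.14.0"

def installed_tensorflow_version():
    """Return the installed TensorFlow version string, or None if it cannot be imported."""
    try:
        import tensorflow as tf
    except ImportError:
        return None
    return tf.__version__

if __name__ == "__main__":
    tf_version = installed_tensorflow_version()
    ok = meets_requirements(sys.version_info, tf_version)
    print(f"Python {sys.version.split()[0]}, TensorFlow {tf_version}: {'OK' if ok else 'mismatch'}")
```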

Training

To start training the Residual Shuffle-Exchange network, run the terminal command:

python3 trainer.py

By default, it trains on the music transcription task. To choose the sequence processing task on which to train the Residual Shuffle-Exchange network, edit config.py, which contains the hyperparameters and other settings.

Music transcription

For the MusicNet music transcription task, make sure that the corresponding settings in config.py are uncommented:

"""Recommended settings for MusicNet"""
# task = "musicnet"
# n_Benes_blocks = 2  # depth of the model
...

To train the model on the MusicNet dataset, first download and parse the dataset by running:

python3 musicnet_data/get_musicnet.py
python3 musicnet_data/parse_file.py

This may take a while and can require more than 40 GB of RAM. If you run out of RAM, you can instead download musicnet.npz from Kaggle and place it in the musicnet_data directory.
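Before a long training run, you may want to confirm that musicnet.npz is in place without loading its arrays into memory. The sketch below assumes only that the file is a standard NumPy .npz archive in musicnet_data. Listing the archive's member names does not read the arrays themselves. list_musicnet_recordings is a hypothetical helper. The meaning of the stored names depends on how the file was produced. Reading object arrays from such a file would additionally require allow_pickle=True.

```python
from pathlib import Path

import numpy as np

def list_musicnet_recordings(data_dir="musicnet_data", filename="musicnet.npz"):
    """Return the sorted array names stored in the .npz archive without loading the arrays."""
    path = Path(data_dir) / filename
    if not path.is_file():
        raise FileNotFoundError(f"{path} not found; download or generate it first")
    with np.load(path) as archive:  # NpzFile reads members lazily, on access
        return sorted(archive.files)
```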

If you have enough RAM to load the entire dataset (possibly more than 128 GB), set musicnet_subset to False for faster training. Increasing musicnet_window_size requires more RAM and slows down training but improves accuracy.

To use a pretrained model for music transcription, place the contents of trained_model_m8192F1 in the out_dir directory specified in the config.py file.

To test the trained model for the MusicNet task on the test set, run tester.py. To transcribe a custom wav file to MIDI, place the file in the musicnet_data directory and run:

python3 transcribe.py yourwavfile.wav
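To transcribe several recordings, the command above could be run once per file. This sketch makes three assumptions:

- It is run from the repository root.
- transcribe.py accepts a bare file name from musicnet_data, as in the example above.
- The current Python interpreter (sys.executable) may be used in place of python3.

transcribe_commands and transcribe_all are hypothetical helpers.

```python
import subprocess
import sys
from pathlib import Path

def transcribe_commands(data_dir="musicnet_data"):
    """Build one `transcribe.py <file>` command per .wav file in data_dir, in sorted order."""
    return [[sys.executable, "transcribe.py", wav.name]
            for wav in sorted(Path(data_dir).glob("*.wav"))]

def transcribe_all(data_dir="musicnet_data"):
    """Run the commands one by one; stop at the first failure."""
    for cmd in transcribe_commands(data_dir):
        subprocess.run(cmd, check=True)
```

Calling transcribe_all() then processes every .wav file in musicnet_data in alphabetical order.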

LAMBADA task

For the LAMBADA question answering task, uncomment the corresponding settings in config.py:

"""Recommended settings for lambada"""
# task = "lambada"
# n_input = lambada_vocab_size
...

To download the LAMBADA dataset, see the original publication by Paperno et al.

To download the pre-trained fastText 1M English word embeddings, see the downloads section of the fastText library website. Extract them to the directory given by the base_folder variable under “Embedding configuration” in config.py:

"""Embedding configuration"""
use_pre_trained_embedding = False
base_folder = "/host-dir/embeddings/"
embedding_file = base_folder + "fast_word_embedding.vec"
emb_vector_file = base_folder + "emb_vectors.bin"
emb_word_dictionary = base_folder + "word_dict.bin"
...

To enable the pre-trained embeddings, set the use_pre_trained_embedding variable in config.py to True.
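fastText .vec files are plain text. The first line is a header giving the number of words and the vector dimension. Each following line holds one word and its values. A minimal loader for that format is sketched below, for example for inspecting fast_word_embedding.vec. parse_vec and load_vec are hypothetical helpers. They make no assumption about the layout of the emb_vectors.bin and word_dict.bin files the repository produces.

```python
import numpy as np

def parse_vec(lines, limit=None):
    """Parse fastText .vec text: a 'num_words dim' header, then 'word v1 ... v_dim' per line."""
    lines = iter(lines)
    _, dim = map(int, next(lines).split())
    vectors = {}
    for line in lines:
        if limit is not None and len(vectors) >= limit:
            break
        tokens = line.rstrip().split(" ")
        if len(tokens) != dim + 1:
            continue  # skip malformed lines
        vectors[tokens[0]] = np.asarray([float(t) for t in tokens[1:]], dtype=np.float32)
    return dim, vectors

def load_vec(path, limit=None):
    """Load up to `limit` vectors from a .vec file."""
    with open(path, encoding="utf-8", newline="\n", errors="ignore") as f:
        return parse_vec(f, limit)
```

Passing a limit keeps memory use small when you only need to inspect the first entries of a 1M-word file.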

Windows

If you are running Windows, edit config.py before you start training the Residual Shuffle-Exchange network. Convert the directory-related variables to Windows file path format as follows:

...
"""Local storage (checkpoints, etc)"""
...
out_dir = ".\\host-dir\\gpu" + gpu_instance
model_file = out_dir + "\\varWeights.ckpt"
image_path = out_dir + "\\images"
...
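Another option is to build the same variables with os.path.join, which inserts the path separator of the operating system the code runs on. This sketch assumes that config.py builds these paths as plain strings and that gpu_instance is a string. The value "0" below is hypothetical; config.py defines its own gpu_instance. The layout mirrors the Windows values shown above.

```python
import os

gpu_instance = "0"  # hypothetical value; config.py defines its own gpu_instance

# Local storage (checkpoints, etc), portable across operating systems
out_dir = os.path.join(".", "host-dir", "gpu" + gpu_instance)
model_file = os.path.join(out_dir, "varWeights.ckpt")
image_path = os.path.join(out_dir, "images")
```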

If you are doing music transcription on Windows, change the directory-related variables in the MusicNet-related files in the same way.

Citing Residual Shuffle-Exchange networks

If you use Residual Shuffle-Exchange networks, please use the following BibTeX entry when citing the paper:

@inproceedings{draguns2021residual,
  title={Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences},
  author={Draguns, Andis and Ozoli{\c{n}}{\v{s}}, Em{\=\i}ls and {\v{S}}ostaks, Agris and Apinis, Mat{\=\i}ss and Freivalds, Karlis},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={8},
  pages={7245--7253},
  year={2021}
}

Contact information

For help or issues using Residual Shuffle-Exchange networks, please submit a GitHub issue.

For personal communication related to Residual Shuffle-Exchange networks, please contact Kārlis Freivalds ([email protected]).
