
TUNet - Official Implementation

TUNet: A Block-online Bandwidth Extension Model based on Transformers and Self-supervised Pretraining


License and citation

This code is available for academic research only. If you use our software, please cite as below. For commercial applications, please contact [email protected].

Copyright © 2021 FPT Software, Inc. All rights reserved.

@misc{nguyen2021tunet,
      title={TUNet: A Block-online Bandwidth Extension Model based on Transformers and Self-supervised Pretraining}, 
      author={Viet-Anh Nguyen and Anh H. T. Nguyen and Andy W. H. Khong},
      year={2021},
      eprint={2110.13492},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

1. Results

Our model achieves significant gains over the baselines. Here, we include the predicted mean opinion score (MOS) computed with Microsoft's DNSMOS Azure service. Please refer to our paper for more benchmarks.

Model        DNSMOS
-----------  ------
Input        3.0951
TFiLM-UNet   3.1026
WSRGlow      3.2053
NU-Wave      3.2760
TUNet        3.3896

We also provide several audio samples in audio_samples for comparison. As the spectrogram visualizations show, the high frequencies generated by our model are more accurate than those of the baselines.
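To make such a spectrogram comparison yourself, the following is a minimal sketch, assuming librosa and matplotlib are installed; the two file names are placeholders for any pair of samples in audio_samples.

    import numpy as np
    import librosa
    import librosa.display
    import matplotlib.pyplot as plt

    paths = ['audio_samples/input.wav', 'audio_samples/tunet.wav']  # placeholder names
    fig, axes = plt.subplots(1, len(paths), figsize=(12, 4), sharey=True)
    for ax, path in zip(axes, paths):
        y, sr = librosa.load(path, sr=None)  # keep the native sample rate
        D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
        librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', ax=ax)
        ax.set_title(path)
    plt.tight_layout()
    plt.show()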

2. Installation

Setup

Clone the repo

$ git clone https://github.com/NXTProduct/TUNet.git
$ cd TUNet

Install dependencies

  • Our implementation requires the libsndfile and libsamplerate libraries for the Python packages soundfile and samplerate, respectively. On Ubuntu, they can be easily installed using apt-get:

    $ apt-get update && apt-get install libsndfile1-dev libsamplerate0-dev
    
  • Create a Python 3.8 environment. Conda is recommended:

    $ conda create -n tunet python=3.8
    $ conda activate tunet
    
  • Install the requirements:

    $ pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
    

Note: the argument -f https://download.pytorch.org/whl/cu113/torch_stable.html is provided to install torch==1.10.0+cu113 (PyTorch 1.10, CUDA 11.3) as pinned in requirements.txt. Choose a CUDA version appropriate for your GPU and change or remove the argument according to the PyTorch documentation.
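For instance, a CPU-only install could look like this (a sketch, assuming the torch==1.10.0+cu113 pin in requirements.txt is changed to torch==1.10.0+cpu):

    $ pip install -r requirements.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html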

3. Data preparation

In our paper, we conduct experiments on the VCTK and VIVOS datasets. You may use either one or both.

  • Download and extract the datasets:

    $ mkdir -p data/vctk data/vivos
    $ wget http://www.udialogue.org/download/VCTK-Corpus.tar.gz -O data/vctk/VCTK-Corpus.tar.gz
    $ wget https://ailab.hcmus.edu.vn/assets/vivos.tar.gz -O data/vivos/vivos.tar.gz
    $ tar -zxvf data/vctk/VCTK-Corpus.tar.gz -C data/vctk/ --strip-components=1
    $ tar -zxvf data/vivos/vivos.tar.gz -C data/vivos/ --strip-components=1
    

    After extracting the datasets, your ./data directory should look like this:

    .
    |--data
        |--vctk
            |--wav48
                |--p225
                    |--p225_001.wav
                    ...
            |--train.txt   
            |--test.txt
        |--vivos
            |--train
                |--waves
                    |--VIVOSSPK01
                        |--VIVOSSPK01_R001.wav
                        ...                
            |--test
                |--waves
                    |--VIVOSDEV01
                        |--VIVOSDEV01_R001.wav
                        ...      
            |--train.txt   
            |--test.txt
    
  • In order to load the datasets, text files containing the training and testing audio paths are required. We have prepared train.txt and test.txt files in the ./data/vctk and ./data/vivos directories.

4. Run the code

Configuration

config.py is the most important file. It contains all configurations related to experiment setups, datasets, models, training, testing, etc. Although the config file is documented thoroughly, we recommend reading our paper to fully understand each parameter.

Training

  • Adjust training hyperparameters in config.py

    Note: batch_size in this implementation differs from the batch size in our paper. Specifically, batch size in the paper refers to the number of frames per batch, whereas in this repo, batch_size is the number of audio files per batch. The DataLoader loads batches of audio files and chunks them into frames on the fly. Since audio duration varies, the number of frames per batch fluctuates around 12 * batch_size (see the sketch at the end of this subsection).

  • Run main.py:

    $ python main.py --mode train
    
  • Each run creates a version in ./lightning_logs, where the model checkpoint and hyperparameters are saved. To continue training from one of these versions, set the --version argument of the above command to the desired version number. For example:

    # resume from version 5
    $ python main.py --mode train --version 5
    
  • To monitor the training curves and inspect model output visualizations, run TensorBoard:

    $ tensorboard --logdir=./lightning_logs --bind_all
    

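The following is a minimal sketch of the on-the-fly chunking described in the note above; the frame length and stride values are placeholders, see config.py for the values actually used.

    import numpy as np

    def chunk(audio, frame_len=8192, stride=8192):
        # Split one audio signal into fixed-length frames, dropping the tail.
        n = max((len(audio) - frame_len) // stride + 1, 0)
        if n == 0:
            return np.empty((0, frame_len), dtype=audio.dtype)
        return np.stack([audio[i * stride: i * stride + frame_len] for i in range(n)])

    # e.g. a 4-second clip at 16 kHz yields 64000 // 8192 = 7 frames, so a batch of
    # batch_size files contributes a variable total on the order of 12 * batch_size frames.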

Evaluation

  • Modify config.py to change evaluation setup if necessary.
  • Run main.py with a version number to be evaluated:
    $ python main.py --mode eval --version 5
    
    This will print the mean and standard deviation of LSD, LSD-HF, and SI-SDR, respectively. During evaluation, several output samples are saved to CONFIG.LOG.sample_path for sanity checking.
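For reference, here is a minimal sketch of the LSD metric, assuming the standard definition; the exact implementation in this repo (FFT size, hop length, epsilon, and the high-frequency band used for LSD-HF) may differ.

    import numpy as np
    import librosa

    def lsd(ref, est, n_fft=2048, hop=512, eps=1e-10):
        # Power spectrograms of the reference and estimated signals.
        S_ref = np.abs(librosa.stft(ref, n_fft=n_fft, hop_length=hop)) ** 2
        S_est = np.abs(librosa.stft(est, n_fft=n_fft, hop_length=hop)) ** 2
        # RMS log-spectral difference per frame, averaged over frames.
        diff = np.log10(S_ref + eps) - np.log10(S_est + eps)
        return np.mean(np.sqrt(np.mean(diff ** 2, axis=0)))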

Configure a new dataset

Our implementation currently works with the VCTK and VIVOS datasets but can easily be extended to new ones.

  • Firstly, you need to prepare train.txt and test.txt. See ./data/vivos/train.txt and ./data/vivos/test.txt for examples.
  • Secondly, add a new dictionary to CONFIG.DATA.data_dir:
    {
        'root': 'path/to/data/directory',
        'train': 'path/to/train.txt',
        'test': 'path/to/test.txt'
    }
    
    Important: make sure that each line in train.txt and test.txt, when joined with 'root', forms a valid path to its corresponding audio file.
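    A minimal sketch for checking this requirement (the paths below are the placeholders from the dictionary above):

    import os

    root = 'path/to/data/directory'
    for split in ('path/to/train.txt', 'path/to/test.txt'):
        with open(split) as f:
            missing = [line.strip() for line in f
                       if not os.path.isfile(os.path.join(root, line.strip()))]
        print(f'{split}: {len(missing)} missing files')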

5. Audio generation

  • In order to generate output audio, either put your input samples into ./test_samples or modify CONFIG.TEST.in_dir to point to your input directory.

  • Run main.py:

    $ python main.py --mode test --version 5
    

    The generated audio files are saved to CONFIG.TEST.out_dir.

    Note: checkpoint version_5 was trained for only a few epochs for demonstration purposes. Since the code has been refactored, the checkpoint used in the paper can no longer be loaded. To run inference with our best model, please use the ONNX model instead.

    ONNX inferencing

    We provide ONNX inferencing scripts and the best ONNX model (converted from the best checkpoint) at lightning_logs/best_model.onnx.

    • Convert a checkpoint to an ONNX model:
      $ python main.py --mode onnx --version 5
      
      The converted ONNX model will be saved to lightning_logs/version_5/checkpoints.
    • Put test audio files in test_samples and run inference with the converted ONNX model (see inference_onnx.py for more details):
      $ python inference_onnx.py
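      For orientation, a minimal onnxruntime sketch is shown below; the input layout, frame length, and file name are assumptions, and inference_onnx.py is authoritative:

      import numpy as np
      import onnxruntime as ort
      import soundfile as sf

      session = ort.InferenceSession('lightning_logs/best_model.onnx')
      inp = session.get_inputs()[0]

      audio, sr = sf.read('test_samples/example.wav', dtype='float32')  # placeholder file
      frame_len = 8192                       # assumed frame size; check inference_onnx.py
      frame = audio[:frame_len][np.newaxis, np.newaxis, :]  # assumed (batch, channel, time)
      out = session.run(None, {inp.name: frame})[0]
      print(inp.name, '->', out.shape)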
      

6. Acknowledgement

We thank FPT Software for funding this work and providing GPU infrastructure. We also thank Microsoft for giving us access to the DNSMOS Azure service.
