Semi-supervised source localization in reverberant environments with deep generative modeling

Official implementation of Semi-supervised source localization in reverberant environments with deep generative modeling (Bianco et al. 2021) (paper). In this machine learning-based approach to acoustic source localization, a variational autoencoder (VAE) is trained to generate the phase of the relative transfer function (RTF) between two microphones. The VAE is trained in parallel with a classifier network that estimates the direction of arrival (DOA) of an acoustic source. Both models are trained on labeled and unlabeled RTF-phase sequences generated from speech in reverberant environments.
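For readers new to RTF features, here is a minimal sketch of extracting an RTF-phase sequence from a two-microphone recording. It assumes the RTF of microphone 2 relative to microphone 1 is estimated per STFT frame as the cross-spectrum divided by the reference auto-spectrum; the function name, frame length, and hop size are illustrative and are not taken from this repository.

```python
import numpy as np


def rtf_phase(x1, x2, frame_len=512, hop=256, eps=1e-12):
    """Per-frame RTF phase of mic 2 relative to mic 1 (illustrative sketch).

    Assumption: RTF(f) = X2(f) * conj(X1(f)) / (|X1(f)|^2 + eps), computed
    independently in each Hann-windowed frame; only the phase is returned.
    """
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    window = np.hanning(frame_len)
    n_frames = 1 + (len(x1) - frame_len) // hop
    phases = []
    for m in range(n_frames):
        seg = slice(m * hop, m * hop + frame_len)
        X1 = np.fft.rfft(window * x1[seg])
        X2 = np.fft.rfft(window * x2[seg])
        # The denominator is real and positive, so it leaves the phase unchanged;
        # it is kept to make the RTF definition explicit.
        rtf = X2 * np.conj(X1) / (np.abs(X1) ** 2 + eps)
        phases.append(np.angle(rtf))
    # Shape: (n_frames, frame_len // 2 + 1), values in [-pi, pi]
    return np.array(phases)
```

Each row holds one frame's RTF phase across frequency bins, so consecutive rows form an RTF-phase sequence of the kind the VAE models.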

This deep generative semi-supervised approach compares favorably with fully-supervised and conventional signal processing-based source localization approaches when large but sparsely labeled datasets are available. In addition, the trained VAE can generate RTF-phase sequences conditioned on a source DOA.
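Sparse labeling can be simulated by keeping labels for only a small random fraction of the training examples and treating the rest as unlabeled. The sketch below does this with the standard library; split_labels and the 5% fraction in the example are hypothetical choices for illustration, not the settings used in the paper.

```python
import random


def split_labels(n_samples, labeled_fraction, seed=0):
    """Partition sample indices into a small labeled set and a larger unlabeled set.

    Hypothetical helper: the fraction and seeding scheme are illustrative only.
    """
    if not 0.0 <= labeled_fraction <= 1.0:
        raise ValueError("labeled_fraction must be in [0, 1]")
    rng = random.Random(seed)  # local RNG so the split is reproducible
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_labeled = round(labeled_fraction * n_samples)
    # Labeled indices keep their DOA labels; the rest are used without labels.
    return sorted(indices[:n_labeled]), sorted(indices[n_labeled:])


labeled, unlabeled = split_labels(1000, 0.05)
print(len(labeled), len(unlabeled))  # 50 950
```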

As part of this study, a new room acoustics dataset was collected in a classroom at the Technical University of Denmark (DTU) (Fernandez-Grande et al. 2021) (dataset). This dataset includes off-grid and off-range source locations to test model generalization. To obtain reverberant speech data, speech segments from the LibriSpeech development corpus were convolved with the recorded room impulse responses (IRs). More details of the process are given in our paper.
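The convolution step can be sketched as follows, assuming the dry speech and the IRs share a sample rate and that there is one IR per microphone. reverberate is a hypothetical helper; any resampling, level normalization, or trimming used to build the released data is described in the paper and is not reproduced here.

```python
import numpy as np


def reverberate(speech, rirs):
    """Convolve dry speech with one room IR per microphone (hypothetical helper).

    speech: 1-D array of dry speech samples.
    rirs:   array of shape (n_mics, ir_len), one IR per microphone,
            assumed to be at the same sample rate as the speech.
    Returns an array of shape (n_mics, len(speech) + ir_len - 1).
    """
    speech = np.asarray(speech, dtype=float)
    rirs = np.atleast_2d(np.asarray(rirs, dtype=float))
    # Full linear convolution for each microphone channel
    return np.stack([np.convolve(speech, h) for h in rirs])
```

For IRs that are thousands of samples long, scipy.signal.fftconvolve gives the same result much faster than np.convolve.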

The neural networks and variational inference were implemented with PyTorch and the Pyro probabilistic programming library.
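Assuming VAE-SSL follows the M2 semi-supervised VAE of Kingma et al. (2014), which is taken here as an assumption rather than a description of the code, the objective combines a labeled ELBO, an unlabeled ELBO that marginalizes the DOA label with the classifier, and a weighted classification term. Here x is an RTF-phase sequence, y a DOA label, z the latent variable, D_l and D_u the labeled and unlabeled sets, and alpha the classification weight:

```latex
% Labeled pair (x, y): negative ELBO
\mathcal{L}(x, y) = -\mathbb{E}_{q_\phi(z \mid x, y)}\left[\log p_\theta(x \mid y, z) + \log p(y) + \log p(z) - \log q_\phi(z \mid x, y)\right]

% Unlabeled x: marginalize y with the classifier q_\phi(y \mid x), minus its entropy
\mathcal{U}(x) = \sum_{y} q_\phi(y \mid x)\, \mathcal{L}(x, y) - \mathcal{H}\left[q_\phi(y \mid x)\right]

% Total objective to minimize
\mathcal{J}^{\alpha} = \sum_{(x, y) \in \mathcal{D}_l} \mathcal{L}(x, y) + \sum_{x \in \mathcal{D}_u} \mathcal{U}(x) - \alpha \sum_{(x, y) \in \mathcal{D}_l} \log q_\phi(y \mid x)
```

In Pyro, objectives of this form are commonly optimized with SVI, with the sum over y handled by enumeration (TraceEnum_ELBO); whether this repository does exactly that is an assumption.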

Requirements

This code was developed in a virtual environment managed by Anaconda. The requirements.yml file contains packages managed by conda and pip, which can be installed using

conda env create --file requirements.yml

The default_paths.json file contains the paths to the datasets and models used for training and evaluation. These paths can be changed either in the .json file or at the command line.
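One common way to combine a JSON file of defaults with command-line overrides is sketched below using only json and argparse. load_config is hypothetical, and so are the key names in the example; check default_paths.json for the actual keys and the training scripts for the actual flags.

```python
import argparse
import json


def load_config(json_path, argv=None):
    """Read defaults from a JSON file and let command-line flags override them.

    Each key "some_key" becomes a flag "--some-key"; argparse maps the flag
    back to the underscore name. Values given on the command line arrive as strings.
    """
    with open(json_path) as f:
        defaults = json.load(f)
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        parser.add_argument("--" + key.replace("_", "-"), default=value)
    args = parser.parse_args(argv)  # argv=None reads sys.argv
    return vars(args)
```

With a file containing {"path_data": "data", "path_save": "models"}, passing ["--path-save", "out"] would keep path_data and replace path_save.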

Datasets

This distribution is configured to use reverberant speech obtained with the DTU dataset IRs. The datasets (processed training and validation data) are available for download here. Each file contains the raw waveforms of twenty 2-3 second clips of reverberant speech for each DOA. Per the default paths, these files should be placed in a folder named 'data' in the main repository directory.

Training

Three pretrained models for VAE-SSL and the fully-supervised CNN are provided in this repository. To train your own VAE-SSL model in the virtual environment created from requirements.yml, run

python vaessl_train.py --cuda-id <your cuda ID>\
                        --path-save <your path for saving the trained model>

While you can technically train VAE-SSL on a CPU (this is set as default), it is prohibitively slow to do so.

Code is also provided for training and evaluating a fully-supervised convolutional neural network (CNN). This CNN has the same architecture as the VAE-SSL classifier.

To train your own CNN model, run

python cnn_train.py --cuda-id <your cuda ID>\
                    --path-save <your path for saving the trained model>

Evaluation

The trained models can be evaluated with

python vaessl_eval.py --cuda-id <your cuda ID>\
                       --path-save <path to saved model>
python cnn_eval.py --cuda-id <your cuda ID>\
                   --path-save <path to saved model>

Calling the script without any flags will evaluate using the pretrained models and datasets.
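As an illustration of how DOA estimates are commonly scored, the sketch below computes the mean absolute error and the accuracy within a tolerance. doa_metrics and tol_deg are hypothetical names, and the evaluation scripts may compute different or additional statistics.

```python
def doa_metrics(true_deg, pred_deg, tol_deg=0.0):
    """Mean absolute DOA error (degrees) and fraction of estimates within tol_deg.

    Angles are compared without wrap-around, which assumes DOAs lie in a
    bounded range such as 0 to 180 degrees for a two-microphone array.
    """
    if len(true_deg) != len(pred_deg) or len(true_deg) == 0:
        raise ValueError("need two non-empty sequences of equal length")
    errors = [abs(t - p) for t, p in zip(true_deg, pred_deg)]
    mae = sum(errors) / len(errors)
    accuracy = sum(e <= tol_deg for e in errors) / len(errors)
    return mae, accuracy


print(doa_metrics([0, 45, 90, 135], [0, 50, 90, 120]))  # (5.0, 0.5)
```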

Conventional DOA estimation

This repository includes an implementation of conventional DOA estimation algorithms based on the Pyroomacoustics library. In the paper, the ML-based approaches are compared with SRP-PHAT and MUSIC DOA estimation. To obtain source localization results for these methods, use

python conventional_doa.py --algo <the name of the algorithm, 'SRP' or 'MUSIC'>
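The repository relies on Pyroomacoustics for SRP-PHAT and MUSIC. To show the idea behind SRP-PHAT in the two-microphone case, where it reduces to a steered GCC-PHAT, here is a self-contained NumPy sketch. It is not the Pyroomacoustics implementation; the far-field model, the angle convention (measured from the microphone axis), and the function name are assumptions, and MUSIC is not covered.

```python
import numpy as np


def srp_phat_two_mic(x1, x2, fs, mic_dist, angles_deg, c=343.0):
    """Return (best angle, steered power) for a two-mic SRP-PHAT scan (sketch).

    Far-field assumption: a source at angle theta reaches mic 2
    tau = mic_dist * cos(theta) / c seconds after mic 1.
    """
    n = len(x1) + len(x2)  # zero-pad so the cross-spectrum is not circularly aliased
    X1 = np.fft.rfft(x1, n)
    X2 = np.fft.rfft(x2, n)
    cross = X2 * np.conj(X1)
    cross /= np.abs(cross) + 1e-12  # PHAT weighting: keep only the phase
    freqs = np.fft.rfftfreq(n, d=1.0 / fs)
    thetas = np.deg2rad(np.asarray(angles_deg, dtype=float))
    taus = mic_dist * np.cos(thetas) / c
    # Compensate each candidate delay and sum coherently over frequency
    steering = np.exp(2j * np.pi * np.outer(taus, freqs))
    power = np.real(steering @ cross)
    best = float(np.asarray(angles_deg)[np.argmax(power)])
    return best, power
```

For example, with fs = 16000, mic_dist = 0.1715 and x2 delayed by 4 samples relative to x1, the peak falls near 60 degrees, because 0.1715 * cos(60 degrees) / 343 s corresponds to exactly 4 samples.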

Visualization

This repository also includes a Jupyter notebook for visualizing the reconstruction and conditional generation of RTF-phase sequences using VAE-SSL. The notebook contains the code used to generate Figures 4-8 in the paper.

Attribution

If you use this code in your research, please cite it using the following BibTeX:

@article{bianco2021semi,
  title={Semi-supervised source localization in reverberant environments with deep generative modeling},
  author={Bianco, Michael J and Gannot, Sharon and Fernandez-Grande, Efren and Gerstoft, Peter},
  journal={IEEE Access},
  year={2021},
  publisher={IEEE}
}
