Noise-aware Speech Enhancement using Diffusion Probabilistic Model

This repository contains the official PyTorch implementation of our paper "Noise-aware Speech Enhancement using Diffusion Probabilistic Model" (arXiv:2307.08029).

Our code is based on the prior work SGMSE+.

Installation

  • Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
  • Install the package dependencies via pip install -r requirements.txt.
  • If using W&B logging (default):
    • Set up a wandb.ai account
    • Log in via wandb login before running our code.
  • If not using W&B logging:
    • Pass the option --no_wandb to train.py.
    • Your logs will be stored as local TensorBoard logs. Run tensorboard --logdir logs/ to see them.

Pretrained checkpoints

Usage:

  • For resuming training, you can use the --resume_from_checkpoint option of train.py.
  • For evaluating these checkpoints, use the --ckpt option of enhancement.py (see section Evaluation below).

Training

To train a model, run train.py. A minimal example with default settings:

python train.py --base_dir <your_base_dir> --inject_type <inject_type> --pretrain_class_model <pretrained_beats>

Here, your_base_dir is the path to a folder containing the subdirectories train/ and valid/ (and optionally test/). Each of these subdirectories must itself contain two subdirectories, clean/ and noisy/, holding files with the same filenames. Training currently supports only .wav files. inject_type must be one of "addition", "concat", or "cross-attention". pretrained_beats is the path to the pre-trained BEATs model.
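The directory layout described above can be checked before training starts. The following is a minimal sketch, not part of this repository: the helper name check_base_dir is hypothetical, and it only inspects filenames (it does not open or validate the audio itself).

```python
from pathlib import Path


def check_base_dir(base_dir):
    """Return a list of human-readable problems found in the dataset layout.

    Hypothetical helper for illustration; it is not shipped with this repository.
    """
    problems = []
    base = Path(base_dir)
    # train/ and valid/ are required; test/ is optional and only checked if present.
    for split in ("train", "valid", "test"):
        split_dir = base / split
        if not split_dir.is_dir():
            if split != "test":
                problems.append(f"missing required directory: {split}/")
            continue
        names = {}
        for kind in ("clean", "noisy"):
            kind_dir = split_dir / kind
            if kind_dir.is_dir():
                # Only .wav files are considered, matching the training requirement.
                names[kind] = {p.name for p in kind_dir.glob("*.wav")}
            else:
                problems.append(f"missing directory: {split}/{kind}/")
                names[kind] = None
        if names["clean"] is not None and names["noisy"] is not None:
            # Symmetric difference: files present in only one of the two folders.
            for name in sorted(names["clean"] ^ names["noisy"]):
                problems.append(
                    f"{split}/: {name} is not present in both clean/ and noisy/"
                )
    return problems
```

An empty list means the layout matches the description; otherwise each entry names one missing directory or unpaired file.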

The full command is also included in train.sh. To see all available training options, run python train.py --help.
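For launching training from another Python script, the command can be assembled and validated up front. This is a rough sketch under stated assumptions: build_train_command is a hypothetical helper, --no_wandb is assumed to be a bare flag, and --resume_from_checkpoint is assumed to take a checkpoint path. Check python train.py --help for the actual behavior.

```python
import sys

# Allowed values for --inject_type, as listed in this README.
INJECT_TYPES = ("addition", "concat", "cross-attention")


def build_train_command(base_dir, inject_type, pretrained_beats,
                        no_wandb=False, resume_from_checkpoint=None):
    """Build the argument list for train.py (hypothetical helper)."""
    if inject_type not in INJECT_TYPES:
        raise ValueError(
            f"inject_type must be one of {INJECT_TYPES}, got {inject_type!r}"
        )
    cmd = [
        sys.executable, "train.py",
        "--base_dir", str(base_dir),
        "--inject_type", inject_type,
        "--pretrain_class_model", str(pretrained_beats),
    ]
    if no_wandb:
        # Assumption: --no_wandb is a flag that takes no value.
        cmd.append("--no_wandb")
    if resume_from_checkpoint is not None:
        # Assumption: --resume_from_checkpoint expects a checkpoint path.
        cmd += ["--resume_from_checkpoint", str(resume_from_checkpoint)]
    return cmd
```

The returned list can be passed to subprocess.run(cmd, check=True) from the repository root.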

Evaluation

To evaluate on a test set, run

python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint>

to generate the enhanced .wav files, and subsequently run

python calc_metrics.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir>

to calculate and output the instrumental metrics.

Both scripts should receive the same --test_dir and --enhanced_dir parameters. The --ckpt parameter of enhancement.py should be the path to a trained model checkpoint, as stored by the logger in logs/.

You may refer to our full commands included in enhancement.sh and calc_metrics.sh.
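Since both scripts must receive identical --test_dir and --enhanced_dir values, one way to avoid mismatches is to build both commands from a single set of arguments. The sketch below is illustrative only: the function names are hypothetical, and it assumes both scripts are run from the repository root.

```python
import subprocess
import sys


def build_eval_commands(test_dir, enhanced_dir, ckpt):
    """Build the enhancement and metrics commands with identical directory arguments."""
    shared = ["--test_dir", str(test_dir), "--enhanced_dir", str(enhanced_dir)]
    enhance = [sys.executable, "enhancement.py", *shared, "--ckpt", str(ckpt)]
    metrics = [sys.executable, "calc_metrics.py", *shared]
    return enhance, metrics


def run_evaluation(test_dir, enhanced_dir, ckpt):
    """Run enhancement first, then metrics; raise if either step fails."""
    for cmd in build_eval_commands(test_dir, enhanced_dir, ckpt):
        # check=True raises CalledProcessError on a non-zero exit code,
        # so metrics are never computed on a failed enhancement run.
        subprocess.run(cmd, check=True)
```

Calling run_evaluation with the three paths runs the two steps in order.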

Citations

If you use our research or code in your publication, please cite our paper:

@article{hu2023noise,
  title={Noise-aware Speech Enhancement using Diffusion Probabilistic Model},
  author={Hu, Yuchen and Chen, Chen and Li, Ruizhe and Zhu, Qiushi and Chng, Eng Siong},
  journal={arXiv preprint arXiv:2307.08029},
  year={2023}
}
