
Wave-U-Net-for-Speech-Enhancement

A PyTorch implementation of Wave-U-Net, adapted for speech enhancement.

Dependencies

# Make sure the /bin directory of CUDA is added to the PATH environment variable
# Make CUPTI (shipped with CUDA) visible by appending to the LD_LIBRARY_PATH environment variable
export PATH="/usr/local/cuda-10.0/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-10.0/lib64:$LD_LIBRARY_PATH"

# Install Anaconda, using the Tsinghua mirror and Python 3.6.5 as an example
wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-5.2.0-Linux-x86_64.sh
chmod a+x Anaconda3-5.2.0-Linux-x86_64.sh
./Anaconda3-5.2.0-Linux-x86_64.sh # Press f to page through the license; the default install location is ~/anaconda, and the installer will offer to modify the PATH variable

# Create a virtual environment
conda create -n wave-u-net python=3
conda activate wave-u-net

# Install dependencies
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch  # PyTorch 1.2.0 has been tested
conda install tensorflow-gpu  # Only needed for TensorBoard
conda install matplotlib
pip install tqdm librosa pystoi pesq

# Clone
git clone https://github.com/haoxiangsnr/Wave-U-Net-for-Speech-Enhancement.git
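
Optionally, before starting a long training run, verify that this environment can see the GPU (the expected output is an assumption based on the tested versions above):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
# e.g. prints: 1.2.0 True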

Usage

There are two entry files for the current project:

  • Entry file for training models: train.py
  • Entry file for enhancing noisy speech: enhancement.py

Training

Use train.py to train the model. It accepts three command-line parameters:

  • -h, display help information
  • -C, --config, specify the configuration file required for training
  • -R, --resume, continue training from the checkpoint of the last saved model

Syntax: python train.py [-h] -C CONFIG [-R]

E.g.:

python train.py -C config/train/train.json
# The configuration file used to train the model is "config/train/train.json"
# Use all GPUs for training

python train.py -C config/train/train.json -R
# The configuration file used to train the model is "config/train/train.json"
# Use all GPUs to continue training from the last saved model checkpoint

CUDA_VISIBLE_DEVICES=1,2 python train.py -C config/train/train.json
# The configuration file used to train the model is "config/train/train.json"
# Use GPUs 1 and 2 for training

CUDA_VISIBLE_DEVICES=-1 python train.py -C config/train/train.json
# The configuration file used to train the model is "config/train/train.json"
# Use CPU for training

Supplement:

  • Generally, the configuration files needed for training are placed in the config/train directory
  • See the "Parameter Description" section for the parameters in the training configuration file
  • The filename of the configuration file is used as the experiment name.

Enhancement

Use enhancement.py to enhance noisy speech. It accepts the following parameters:

  • -h, --help, display help information
  • -C, --config, specify the configuration file, which describes the model, the dataset to enhance, and other custom arguments used for enhancement
  • -D, --device, the index of the GPU used for enhancement; -1 means use the CPU
  • -O, --output_dir, the directory in which to store the enhanced speech; you need to ensure that this directory exists in advance
  • -M, --model_checkpoint_path, the path of the model checkpoint; the extension of the checkpoint file is .tar or .pth

Syntax: python enhancement.py [-h] -C CONFIG [-D DEVICE] -O OUTPUT_DIR -M MODEL_CHECKPOINT_PATH

E.g.:

python enhancement.py -C config/enhancement/unet_basic.json -D 0 -O enhanced -M /media/imucs/DataDisk/haoxiang/Experiment/Wave-U-Net-for-Speech-Enhancement/smooth_l1_loss/checkpoints/model_0020.pth
# The configuration file used for enhancement is "config/enhancement/unet_basic.json"; it specifies the model and dataset information required for enhancement
# Use the GPU with index 0
# The output directory is "enhanced/"; it must be created in advance
# Specify the path of the model checkpoint

python enhancement.py -C config/enhancement/unet_basic.json -D -1 -O enhanced -M /media/imucs/DataDisk/haoxiang/Experiment/Wave-U-Net-for-Speech-Enhancement/smooth_l1_loss/checkpoints/model_0020.tar
# Use CPU for enhancement

Supplement:

  • Generally, the configuration files needed for enhancement are placed in the config/enhancement/ directory
  • See the "Parameter Description" section for the parameters in the enhancement configuration file.

Visualization

All log information generated during training is stored in the config["root_dir"]/<config_filename>/ directory. For example, if the configuration file for training is config/train/sample_16384.json and the value of the root_dir parameter in sample_16384.json is /home/UNet/, then the logs generated during this experiment are stored in the /home/UNet/sample_16384/ directory. That directory contains the following:

  • logs/ directory: stores TensorBoard data, including loss curves, waveforms, and speech files
  • checkpoints/ directory: stores all model checkpoints, from which you can resume training or run speech enhancement
  • config.json file: backup of the training configuration file
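
Putting it together, the experiment directory for the example above would look like this:

/home/UNet/sample_16384/
├── logs/           # TensorBoard data: loss curves, waveforms, speech files
├── checkpoints/    # all saved model checkpoints
└── config.json     # backup of the training configuration file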

During training, you can use tensorboard to start a web server that visualizes the log data in the relevant directory:

tensorboard --logdir config["root_dir"]/<config_filename>/

# You can use --port to specify the port of the TensorBoard server
tensorboard --logdir config["root_dir"]/<config_filename>/ --port <port>

# For example, if the "root_dir" parameter in the configuration file is "/home/happy/Experiments", the configuration file name is "train_config.json", and you change the default port to 6000, the following command can be used:
tensorboard --logdir /home/happy/Experiments/train_config --port 6000

Directory description

During training, multiple directories are used, each with a different purpose:

  • Main directory: the directory where the current README.md is located, storing all source code
  • Training directory: the config["root_dir"] directory in the training configuration file, which stores all experiment logs and model checkpoints of the current project
  • Experiment directory: config["root_dir"]/<experiment name>/ directory, which stores the log information of a certain experiment

Parameter Description

Training

config/train/<config_filename>.json

The log information generated during the training process will be stored in config["root_dir"]/<config_filename>/.

{
    "seed": 0, // Random seeds to ensure experiment repeatability
    "description": "...",  // Experiment description, will be displayed in Tensorboard later
    "root_dir": "~/Experiments/Wave-U-Net", // Directory for storing experiment results
    "cudnn_deterministic": false,
    "trainer": { // For training process
        "module": "trainer.trainer", // Which trainer
        "main": "Trainer", // The concrete class of the trainer model
        "epochs": 1200, // Upper limit of training
        "save_checkpoint_interval": 10, // Save model breakpoint interval
        "validation":{
        "interval": 10, // validation interval
         "find_max": true, // When find_max is true, if the calculated metric is the known maximum value, it will cache another copy of the current round of model checkpoint.
        "custon": {
            "visualize_audio_limit": 20, // The interval of visual audio during validation. The reason for setting this parameter is that visual audio is slow
            "visualize_waveform_limit": 20, // The interval of the visualization waveform during validation. The reason for setting this parameter is because the visualization waveform is slow
            "visualize_spectrogram_limit": 20, // Verify the interval of the visualization spectrogram. This parameter is set because the visualization spectrum is slow
            "sample_length": 16384 // See train dataset
            } 
        }
    },
    "model": {
        "module": "model.unet_basic", // Model files used for training
        "main": "Model", // Concrete class of training model
        "args": {} // Parameters passed to the model class
    },
    "loss_function": {
        "module": "model.loss", // Model file of loss function
        "main": "mse_loss", // Concrete class of loss function
        "args": {} // Parameters passed to the model class
    },
    "optimizer": {
        "lr": 0.001,
        "beta1": 0.9,
        "beat2": 0.009
    },
    "train_dataset": {
        "module": "dataset.waveform_dataset", // Store the training set model file
        "main": "Dataset", // Concrete class of training dataset
        "args": { // The parameters passed to the training set class, see the specific training set class for details
            "dataset": "~/Datasets/SEGAN_Dataset/train_dataset.txt",
            "limit": null,
            "offset": 0,
            "sample_length": 16384,
            "mode":"train"
        }
    },
    "validation_dataset": {
        "module": "dataset.waveform_dataset",
        "main": "Dataset",
        "args": {
            "dataset": "~/Datasets/SEGAN_Dataset/test_dataset.txt",
            "limit": 400,
            "offset": 0,
            "mode":"validation"
        }
    },
    "train_dataloader": {
        "batch_size": 120,
        "num_workers": 40, // How many threads to start to preprocess the data
        "shuffle": true,
        "pin_memory":true
    }
}
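
The "module"/"main" pairs above allow the trainer, model, loss function, and datasets to be constructed by name at runtime. A minimal sketch of how such a config-driven lookup typically works (the helper name initialize_config is illustrative, not necessarily the repo's actual function):

import importlib

def initialize_config(config):
    """Instantiate config["main"] from the module named by config["module"]."""
    module = importlib.import_module(config["module"])  # e.g. "model.unet_basic"
    obj = getattr(module, config["main"])               # e.g. the "Model" class
    return obj(**config.get("args", {}))                # forward the "args" dict

# e.g. model = initialize_config(configuration["model"])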

Enhancement

config/enhancement/*.json

{
    "model": {
        "module": "model.unet_basic",  // Store the model file
        "main": "UNet",  // The specific model class in the file
        "args": {}  // Parameters passed to the model class
    },
    "dataset": {
        "module": "dataset.waveform_dataset_enhancement",  // Store the enhancement dataset file
        "main": "WaveformDataset",  // Concrete class of enhacnement dataset
        "args": {  // The parameters passed to the dataset class, see the specific enhancement dataset class for details
            "dataset": "/home/imucs/tmp/UNet_and_Inpainting/data.txt",
            "limit": 400,
            "offset": 0,
            "sample_length": 16384
        }
    },
    "custom": {
        "sample_length": 16384
    }
}

For enhancement, the *.txt file should list only the paths of the noisy speech files, one path per line, similar to this:

# enhancement_*.txt

/home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Clean.wav
/home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_Inpainting_200.wav
/home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_Inpainting_270.wav
/home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_UNet.wav
/home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Mixture.wav
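
A minimal sketch of how such a file list could be consumed, assuming 16 kHz audio (consistent with a sample_length of 16384) and librosa for loading; this is illustrative, not the repo's actual dataset class:

import os
import librosa

def iter_noisy_files(list_path, sr=16000, offset=0, limit=None):
    """Yield (basename, waveform) pairs from a text file of audio paths."""
    with open(list_path) as f:
        paths = [line.strip() for line in f if line.strip()]
    end = offset + limit if limit is not None else None
    for path in paths[offset:end]:
        waveform, _ = librosa.load(path, sr=sr, mono=True)  # load and resample
        yield os.path.splitext(os.path.basename(path))[0], waveform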

TODO

  • Use full-length speech for validation
  • Enhancement script


Issues

test.py file

Dear,

Thanks for your repo and implementation, but I can't find test.py.

Training input length

Great work! Thanks for sharing!
May I ask, for each training sample, how long are the input and output? For example, in the original paper, "Wave-U-Net: A Multi-Scale Neural Network for End-to-End Audio Source Separation", the authors use an audio segment of about 7 s to obtain a 0.7 s separated source.
What about your data?

Error when running train.py

The amount of parameters is 10.132802 million.
============== 1 epoch ==============
[0 seconds] Begin training...
Traceback (most recent call last):
  File "train.py", line 64, in <module>
    main(configuration, resume=args.resume)
  File "train.py", line 51, in main
    trainer.train().cuda()
  File "/root/autodl-tmp/trainer/base_trainer.py", line 194, in train
    self._train_epoch(epoch)
  File "/root/autodl-tmp/trainer/trainer.py", line 30, in _train_epoch
    for i, (mixture, clean, name) in enumerate(self.train_data_loader):
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 64, in default_collate
    return default_collate([torch.as_tensor(b) for b in batch])
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 56, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [1, 16384] at entry 0 and [1, 11293] at entry 3

There are two problems:

1. RuntimeError: Caught RuntimeError in DataLoader worker process 0.
2. RuntimeError: stack expects each tensor to be equal size, but got [1, 16384] at entry 0 and [1, 11293] at entry 3
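
The second error points at the root cause: at least one clip in the batch is shorter than sample_length (16384 samples), so default_collate cannot stack the tensors. One possible workaround, not from this repo, is to right-pad every waveform in a custom collate_fn before stacking:

import torch
import torch.nn.functional as F

def pad_collate(batch, sample_length=16384):
    """Right-pad each (mixture, clean, name) item to sample_length samples."""
    mixtures, cleans, names = zip(*batch)
    def pad(t):
        return F.pad(t, (0, sample_length - t.shape[-1]))  # zero-pad the last dim
    return (torch.stack([pad(m) for m in mixtures]),
            torch.stack([pad(c) for c in cleans]),
            list(names))

# e.g. DataLoader(train_dataset, batch_size=120, collate_fn=pad_collate)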

training SNR

Is this the baseline used in the UNetGAN paper? Is the model trained with the same SNRs as the original paper, or, like UNetGAN, with -15, -10, -5, and 0 dB?

Assertion Error: len(mixture) == len(clean) == len(enhanced)

When I try to run the code, it fails on this assertion:

assert len(mixture) == len(clean) == len(enhanced)

I printed the length of each and found that the lengths of enhanced and clean are equal, but the length of the mixture is greater than both.

I hope you can help me as soon as possible.
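
If the mixture really is longer than the clean and enhanced signals (for example because of padding added during enhancement), one possible workaround, not from this repo, is to trim all three signals to a common length before the comparison:

# Hypothetical fix: trim to the shortest signal before asserting.
min_len = min(len(mixture), len(clean), len(enhanced))
mixture, clean, enhanced = mixture[:min_len], clean[:min_len], enhanced[:min_len]
assert len(mixture) == len(clean) == len(enhanced)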
