puresound's People

Contributors

mcw519

puresound's Issues

How to train TSE?

Hi Wu,
How do I train the TSE model? I use this command:
CUDA_VISIBLE_DEVICES=2 python main.py conf/libri2mix_max_2spk_clean_16k_1c.yaml --backend cuda
Even with the batch size set to 1, I still get the following error. Is there any way to solve it?


init loss function: sisnr
Initialized a multi-task model, sharing a same speech encoder.
Multi-task training has two loss function.
Current task label: 1
---------------Verbose logging---------------
Current training mode is: False
Total params: 6375442
Lookahead(samples): 16
Receptive Fields(samples): infinite
Current training mode is: True
---------------Verbose logging---------------
  0%|          | 0/1737 [00:00<?, ?it/s]
epoch: 0, iter: 1, batch_loss: 44.908958435058594, signal_loss: 39.039031982421875, class_loss: 117.39849853515625
  0%|          | 1/1737 [00:01<46:59,  1.62s/it]
epoch: 0, iter: 2, batch_loss: 29.754398345947266, signal_loss: 23.969154357910156, class_loss: 115.7048568725586
  0%|          | 1/1737 [00:02<1:17:31,  2.68s/it]
Traceback (most recent call last):
  File "main.py", line 466, in <module>
    main(config)
  File "main.py", line 154, in main
    trainer.train()
  File "/home/yangjie/PureSound/egs/tse/puresound/task/base.py", line 383, in train
    loss = self.train_one_epoch(current_epoch=epoch)
  File "/home/yangjie/PureSound/egs/tse/puresound/task/tse.py", line 597, in train_one_epoch
    loss.backward()
  File "/home/yangjie/miniconda3/envs/pyannote/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/yangjie/miniconda3/envs/pyannote/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA out of memory. Tried to allocate 570.00 MiB (GPU 0; 23.70 GiB total capacity; 20.34 GiB already allocated; 380.81 MiB free; 21.50 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My GPU is an RTX 3090 Ti.
My config:

DATASET:
  type: TSE
  sample_rate: 16000
  max_length: 3 # input snipit length
  train: /home/yangjie/PureSound/egs/tse/train-100-clean # train set folder
  dev: /home/yangjie/PureSound/egs/tse/dev-clean # dev set folder
  eval: /home/yangjie/PureSound/egs/tse/test-clean # eval set folder for tensorboard logging
  noise_folder: # if not None, applies noise inject
  rir_folder: # if not None, applies RIR
  rir_mode: # image/direct/early
  vol_perturbed: # Tuple
  speed_perturbed: False
  perturb_frequency_response: False
  single_spk_prob: 0.
  inactive_training: 0.
  enroll_rule: fixed_length
  enroll_augment: False

MODEL:
  type: tse_skim_v0_causal # model-name
  
LOSS:
  sig_loss: sisnr
  sig_threshold: #-30
  alpha: 0.05
  cls_loss: ge2e
  cls_loss_other:
  embed_dim: 192 # for aamsoftmax
  n_class: 251 # for aamsoftmax
  margin: 0.2 # for aamsoftmax
  scale: 30 # for aamsoftmax

OPTIMIZER:
  gradiend_clip: 10
  lr: 0.0001
  num_epochs_decay: 0
  lr_scheduler: Plateau
  mode: min
  patience: 5
  gamma: 0.5
  beta1: 0.9
  beta2: 0.999
  weight_decay: 0.
  multi_rate: False

TRAIN:
  num_epochs: 200
  resume_epoch:
  contrastive_learning: True
  p_spks: 12
  p_utts: 4
  repeat: 3
  batch_size: 1
  multi_gpu: False
  num_workers: 2
  model_average:
  use_tensorboard: True
  model_save_dir: models # model save folder
  log_dir: logs # logging save folder

The original libri2mix_max_2spk_clean_16k_1c.yaml is missing some fields, which causes a crash. I added them by comparing it against TSE.yaml:

diff --git a/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml b/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml
index 6b8b089..8de0aae 100644
--- a/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml
+++ b/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml
@@ -2,14 +2,15 @@ DATASET:
   type: TSE
   sample_rate: 16000
   max_length: 3 # input snipit length
-  train: train # train set folder
-  dev: dev # dev set folder
-  eval: eval # eval set folder for tensorboard logging
+  train: /home/yangjie/PureSound/egs/tse/train-100-clean # train set folder
+  dev: /home/yangjie/PureSound/egs/tse/dev-clean # dev set folder
+  eval: /home/yangjie/PureSound/egs/tse/test-clean # eval set folder for tensorboard logging
   noise_folder: # if not None, applies noise inject
   rir_folder: # if not None, applies RIR
   rir_mode: # image/direct/early
   vol_perturbed: # Tuple
   speed_perturbed: False
+  perturb_frequency_response: False
   single_spk_prob: 0.
   inactive_training: 0.
   enroll_rule: fixed_length
@@ -23,6 +24,7 @@ LOSS:
   sig_threshold: #-30
   alpha: 0.05
   cls_loss: ge2e
+  cls_loss_other:
   embed_dim: 192 # for aamsoftmax
   n_class: 251 # for aamsoftmax
   margin: 0.2 # for aamsoftmax
@@ -39,6 +41,7 @@ OPTIMIZER:
   beta1: 0.9
   beta2: 0.999
   weight_decay: 0.
+  multi_rate: False

 TRAIN:
   num_epochs: 200
@@ -47,9 +50,9 @@ TRAIN:
   p_spks: 12
   p_utts: 4
   repeat: 3
-  batch_size: 12
-  multi_gpu: True
-  num_workers: 10
+  batch_size: 1
+  multi_gpu: False
+  num_workers: 2
   model_average:
   use_tensorboard: True
   model_save_dir: models # model save folder
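Rather than patching every YAML file by hand, one option might be to merge each loaded config over a table of defaults before training starts. This is a hypothetical sketch: `DEFAULTS` and `fill_defaults` are not part of PureSound, and the default values simply mirror the three fields added in the diff above (an empty YAML value such as `cls_loss_other:` loads as `None`).

```python
from copy import deepcopy

# Hypothetical defaults: the keys added in the diff above, with the values used there.
DEFAULTS = {
    "DATASET": {"perturb_frequency_response": False},
    "LOSS": {"cls_loss_other": None},
    "OPTIMIZER": {"multi_rate": False},
}


def fill_defaults(hparam: dict, defaults: dict = DEFAULTS) -> dict:
    """Return a copy of hparam in which every missing key is taken from defaults.

    Keys already present (even ones explicitly set to None) are left untouched,
    and the input dict is not modified.
    """
    merged = deepcopy(hparam)
    for section, section_defaults in defaults.items():
        target = merged.setdefault(section, {})
        if target is None:  # an empty YAML section loads as None
            target = merged[section] = {}
        for key, value in section_defaults.items():
            target.setdefault(key, deepcopy(value))
    return merged


if __name__ == "__main__":
    # Sections lacking the three fields, as in the original libri2mix config.
    hparam = {"DATASET": {"type": "TSE"}, "LOSS": {"cls_loss": "ge2e"}, "OPTIMIZER": {"lr": 1e-4}}
    filled = fill_defaults(hparam)
    print(filled["DATASET"]["perturb_frequency_response"])  # False
    print(filled["OPTIMIZER"]["multi_rate"])  # False
```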

Something else seems strange: no matter what I set batch_size to, the progress bar always shows 1/1737. As I understand it, a different batch size should give a number of iterations other than 1737. I also printed hparam["TRAIN"]["batch_size"] in the init_dataloader function, and the setting does take effect.

Thanks!
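A possible explanation, not verified against the PureSound source: with contrastive_learning: True, the loader may build each batch from p_spks speakers × p_utts utterances (a P×K batch, as is common with GE2E-style losses). In that case batch_size would control neither the number of iterations nor the memory use, and p_spks: 12 with p_utts: 4 would mean 48 utterances per step, which could also account for the out-of-memory error at batch_size: 1. The hypothetical sampler below only illustrates that behaviour; its name and its use of `repeat` are assumptions.

```python
import random
from typing import Dict, Iterator, List


def pk_batches(
    utts_by_spk: Dict[str, List[str]],
    p_spks: int,
    p_utts: int,
    repeat: int = 1,
    seed: int = 0,
) -> Iterator[List[str]]:
    """Hypothetical P x K sampler: each batch has p_utts utterances from each of p_spks speakers.

    The number of batches depends only on the speaker list, p_spks and repeat;
    a separate batch_size setting plays no part.
    """
    rng = random.Random(seed)
    # Only speakers with at least p_utts utterances can fill their slots.
    speakers = [s for s, utts in utts_by_spk.items() if len(utts) >= p_utts]
    for _ in range(repeat):
        rng.shuffle(speakers)
        # Step through the shuffled speakers in groups of p_spks, dropping an incomplete last group.
        for start in range(0, len(speakers) - p_spks + 1, p_spks):
            batch: List[str] = []
            for spk in speakers[start:start + p_spks]:
                batch.extend(rng.sample(utts_by_spk[spk], p_utts))
            yield batch


if __name__ == "__main__":
    # Toy corpus: 30 speakers with 10 utterances each.
    corpus = {f"spk{i}": [f"spk{i}_utt{j}" for j in range(10)] for i in range(30)}
    batches = list(pk_batches(corpus, p_spks=12, p_utts=4, repeat=3))
    print(len(batches), len(batches[0]))  # 6 48
```

If the real loader works this way, lowering p_spks or p_utts (rather than batch_size) would be the setting to try for reducing GPU memory.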

Is it necessary to pin the PyTorch version? I get very poor results with all 4 NS models.

Inference results: [image not preserved]
ns_data.zip

Package                   Version
------------------------- ------------
absl-py                   1.4.0
aiohttp                   3.8.4
aiosignal                 1.3.1
alembic                   1.9.4
antlr4-python3-runtime    4.8
appdirs                   1.4.4
asteroid                  0.6.0
asteroid-filterbanks      0.4.0
asttokens                 2.2.1
async-timeout             4.0.2
attrdict                  2.0.1
attrs                     22.2.0
audioread                 3.0.0
backcall                  0.2.0
backports.cached-property 1.0.2
brotlipy                  0.7.0
brouhaha                  0.9.0
cached-property           1.5.2
cachetools                5.3.0
certifi                   2022.12.7
cffi                      1.15.1
charset-normalizer        2.0.4
click                     8.1.3
cmaes                     0.9.1
colorama                  0.4.6
coloredlogs               15.0.1
colorlog                  6.7.0
comm                      0.1.2
commonmark                0.9.1
conda                     23.1.0
conda-package-handling    2.0.2
conda_package_streaming   0.7.0
contourpy                 1.0.7
cryptography              38.0.4
cycler                    0.11.0
Cython                    0.29.34
debugpy                   1.6.6
decorator                 5.1.1
DeepFilterDataLoader      0.4.0
DeepFilterLib             0.4.0
deepfilternet             0.4.0
docopt                    0.6.2
einops                    0.3.2
et-xmlfile                1.1.0
exceptiongroup            1.1.1
executing                 1.2.0
filelock                  3.9.0
flatbuffers               23.1.21
flit_core                 3.6.0
fonttools                 4.38.0
frozenlist                1.3.3
fsspec                    2023.1.0
future                    0.18.3
google-auth               2.16.1
google-auth-oauthlib      0.4.6
greenlet                  2.0.2
grpcio                    1.51.3
hmmlearn                  0.2.8
huggingface-hub           0.12.1
humanfriendly             10.0
Hydra                     2.5
hydra-core                1.1.0
HyperPyYAML               1.1.0
icecream                  2.1.3
idna                      3.4
importlib-metadata        6.0.0
importlib-resources       5.12.0
iniconfig                 2.0.0
ipyhton                   0.1
ipykernel                 6.21.2
ipython                   8.10.0
jedi                      0.18.2
joblib                    1.2.0
julius                    0.2.7
jupyter_client            8.0.3
jupyter_core              5.2.0
kiwisolver                1.4.4
librosa                   0.9.2
llvmlite                  0.39.1
loguru                    0.7.0
Mako                      1.2.4
Markdown                  3.4.1
MarkupSafe                2.1.2
matplotlib                3.7.0
matplotlib-inline         0.1.6
MinDAEC                   0.0.2
mir-eval                  0.7
mkl-fft                   1.3.1
mkl-random                1.2.2
mkl-service               2.4.0
mpmath                    1.2.1
multidict                 6.0.4
multiprocessing-logging   0.3.4
natsort                   8.3.1
nest-asyncio              1.5.6
networkx                  2.8.8
numba                     0.56.4
numpy                     1.23.5
oauthlib                  3.2.2
omegaconf                 2.1.2
onnx                      1.13.1
onnx-simplifier           0.4.17
onnxruntime               1.14.0
openpyxl                  3.1.2
openyxl                   0.1
optuna                    3.1.0
packaging                 23.0
pandas                    1.5.3
parso                     0.8.3
pb-bss-eval               0.0.2
pesq                      0.0.4
pexpect                   4.8.0
pickleshare               0.7.5
Pillow                    9.3.0
pip                       22.3.1
pipdeptree                2.5.0
platformdirs              3.0.0
pluggy                    1.0.0
pooch                     1.6.0
primePy                   1.3
prompt-toolkit            3.0.37
protobuf                  3.20.3
psutil                    5.9.4
ptyprocess                0.7.0
pure-eval                 0.2.2
pyannote.audio            2.1.1
pyannote.core             4.5
pyannote.database         4.1.3
pyannote.metrics          3.2.1
pyannote.pipeline         2.3
pyasn1                    0.4.8
pyasn1-modules            0.2.8
pybind11                  2.10.4
pycosat                   0.6.4
pycparser                 2.21
pyDeprecate               0.3.2
Pygments                  2.14.0
pyOpenSSL                 22.0.0
pyparsing                 3.0.9
pyreadline                2.1
pyroomacoustics           0.7.3
PySocks                   1.7.1
pystoi                    0.3.3
pytest                    7.3.1
python-dateutil           2.8.2
pytorch-lightning         1.6.5
pytorch-metric-learning   1.7.3
pytorch-ranger            0.1.1
pytz                      2022.7.1
PyYAML                    6.0
pyzmq                     25.0.0
requests                  2.28.1
requests-oauthlib         1.3.1
resampy                   0.4.2
rich                      12.6.0
rsa                       4.9
ruamel.yaml               0.17.21
ruamel.yaml.clib          0.2.7
scikit-learn              1.2.1
scipy                     1.10.1
semver                    2.13.0
sentencepiece             0.1.97
setuptools                65.6.3
shellingham               1.5.0.post1
simplejson                3.18.3
singledispatchmethod      1.0
six                       1.16.0
sortedcontainers          2.4.0
SoundFile                 0.10.3.post1
speechbrain               0.5.13
SQLAlchemy                2.0.4
stack-data                0.6.2
sympy                     1.11.1
tabulate                  0.9.0
tensorboard               2.12.0
tensorboard-data-server   0.7.0
tensorboard-plugin-wit    1.8.1
threadpoolctl             3.1.0
toml                      0.10.2
tomli                     2.0.1
toolz                     0.12.0
torch                     1.11.0
torch-audiomentations     0.11.0
torch-optimizer           0.1.0
torch-pitch-shift         1.2.2
torch-stoi                0.1.2
torch-tb-profiler         0.4.1
torchaudio                0.11.0
torchinfo                 1.7.2
torchmetrics              0.7.3
torchsummary              1.5.1
torchvision               0.12.0
tornado                   6.2
tqdm                      4.64.1
traitlets                 5.9.0
typer                     0.7.0
typing_extensions         4.4.0
urllib3                   1.26.14
wcwidth                   0.2.6
Werkzeug                  2.2.3
wheel                     0.38.4
yarl                      1.8.2
zipp                      3.14.0
zstandard                 0.19.0
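Since the full list is long, a short script reporting only the packages most likely to affect numerical results may be easier to compare across machines. This is a generic sketch using the standard library's `importlib.metadata` (Python 3.8+); the package selection is an assumption, not a list of versions PureSound is known to require.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed selection of packages worth reporting; adjust as needed.
RELEVANT = ("torch", "torchaudio", "numpy", "librosa", "SoundFile")


def report_versions(names: Iterable[str] = RELEVANT) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not installed."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name:<12} {ver or 'not installed'}")
```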

Possible bug in _align_waveform?

Hi Wu,
When speed_perturbed is set to True in the config, training crashes. I traced it to _align_waveform: this function does not actually align the two waveforms, and there may be a logic error. After the modification below, it no longer crashes. Thanks!

    def _align_waveform(
        self, enh_wav: torch.Tensor, ref_wav: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Assume last axis is the time."""
        enh_wav_l = enh_wav.shape[-1]
        ref_wav_l = ref_wav.shape[-1]
        if enh_wav_l != ref_wav_l:
            if ref_wav_l < enh_wav_l:
                # align from last
                pad_num = enh_wav_l - ref_wav_l
                ref_wav = F.pad(ref_wav, (pad_num, 0))
            else:
                # align from begin
                ref_wav = ref_wav[..., :enh_wav_l]  # changed here: truncate ref to the enhanced length
        return enh_wav, ref_wav
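As I read the patched function: a shorter reference is zero-padded at the front so both signals line up at the end, and a longer reference is truncated to the enhanced length. The framework-free sketch below mirrors that logic on plain Python lists so it can be checked without PyTorch; `align_reference` is a hypothetical name, not part of PureSound.

```python
from typing import List, Sequence, Tuple


def align_reference(enh: Sequence[float], ref: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Make ref the same length as enh, mirroring the patched _align_waveform.

    - ref shorter: left-pad with zeros (align the two signals at the end)
    - ref longer:  keep its first len(enh) samples (align at the beginning)
    """
    enh, ref = list(enh), list(ref)
    if len(ref) < len(enh):
        ref = [0.0] * (len(enh) - len(ref)) + ref
    elif len(ref) > len(enh):
        ref = ref[: len(enh)]
    # Both branches leave the lengths equal, so element-wise losses see matching shapes.
    return enh, ref


if __name__ == "__main__":
    print(align_reference([1.0, 2.0, 3.0], [5.0])[1])  # [0.0, 0.0, 5.0]
    print(align_reference([1.0, 2.0], [5.0, 6.0, 7.0])[1])  # [5.0, 6.0]
```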

About TSE inference

Hi,
I want to test your TSE model, but I could not get tse/demo/demo_app.py to run on Windows.
So I read the code under the tse folder carefully and wrote a simple test script. Streaming and non-streaming inference give me different results; is that expected? Is my test code correct? Thanks!

import soundfile
import torch

from model import init_model
from utils import DemoSpeakerNet, DemoTseNet

if __name__ == '__main__':
    torch.set_grad_enabled(False)  # inference only: do not build autograd graphs
    ckpt = torch.load("onnx_tse/skim_causal_460_wNoise_IS_tsdr.ckpt", map_location="cpu")
    enroll_wav_file = 'onnx_tse/s1/61-70968-0003_5105-28240-0000.wav'
    mix_wav_file = 'onnx_tse/mix/61-70968-0003_5105-28240-0000.wav'
    # Read mono float32 audio and add a batch axis: shape (1, num_samples)
    enroll_wav = torch.from_numpy(soundfile.read(enroll_wav_file, dtype="float32")[0].reshape(1, -1))
    mix_wav = torch.from_numpy(soundfile.read(mix_wav_file, dtype="float32")[0].reshape(1, -1))

    # Streaming path: speaker embedding first, then chunk-wise extraction.
    # strict=False silently skips checkpoint keys that do not match each sub-network.
    speaker_net = DemoSpeakerNet()
    speaker_net.load_state_dict(ckpt["state_dict"], strict=False)
    speaker_net.eval()
    tse_net = DemoTseNet()
    tse_net.load_state_dict(ckpt["state_dict"], strict=False)
    tse_net.eval()
    tse_net.masker.init_status()  # presumably resets the streaming state before a new utterance
    speaker_embedding = speaker_net.get_speaker_embedding(enroll_wav)
    stream_enh_wav = tse_net.streaming_inference_chunk(mix_wav, speaker_embedding)
    # soundfile expects a 1-D NumPy array; torch.as_tensor accepts a tensor or an array
    stream_enh_wav = torch.as_tensor(stream_enh_wav).detach().cpu().numpy().reshape(-1)
    soundfile.write('onnx_tse/enh_stream.wav', stream_enh_wav, 16000)

    # Offline (non-streaming) path on the whole utterance
    model = init_model('tse_skim_v0_causal', verbose=False)
    model.load_state_dict(ckpt["state_dict"], strict=False)  # ignore loss's weight
    model.eval()
    enh_wav = model.inference(mix_wav, enroll_wav)
    enh_wav = enh_wav.detach().cpu().numpy().reshape(-1)
    soundfile.write('onnx_tse/enh.wav', enh_wav, 16000)

    print('Hello world')

data.zip
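One way to tell whether the two outputs genuinely differ, rather than being offset in time, is to measure the SNR of one against the other over a small range of lags. The training log in the first issue reports Lookahead(samples): 16 for tse_skim_v0_causal, so the streaming output might plausibly be delayed by a few samples; that is an assumption to check, not a confirmed cause. `snr_db` and `best_lag_snr` are hypothetical helpers working on plain lists (for example, `soundfile.read(path)[0].tolist()`).

```python
import math
from typing import Sequence, Tuple


def snr_db(ref: Sequence[float], est: Sequence[float]) -> float:
    """SNR of est against ref in dB over their common length."""
    n = min(len(ref), len(est))
    if n == 0:
        return -math.inf  # no overlap to compare
    signal = sum(r * r for r in ref[:n])
    noise = sum((r - e) ** 2 for r, e in zip(ref[:n], est[:n]))
    if noise == 0.0:
        return math.inf  # identical over the overlap
    if signal == 0.0:
        return -math.inf
    return 10.0 * math.log10(signal / noise)


def best_lag_snr(offline: Sequence[float], stream: Sequence[float], max_lag: int = 64) -> Tuple[int, float]:
    """Shift stream earlier by 0..max_lag samples and return (lag, snr_db) for the best match.

    A positive lag means the streaming output is delayed by that many samples.
    """
    best = (0, -math.inf)
    for lag in range(max_lag + 1):
        score = snr_db(offline, stream[lag:])
        if score > best[1]:
            best = (lag, score)
    return best


if __name__ == "__main__":
    offline = [math.sin(0.05 * t) for t in range(2000)]
    stream = [0.0] * 16 + offline[:-16]  # simulated 16-sample delay
    print(best_lag_snr(offline, stream))  # (16, inf)
```

A very high SNR at a non-zero lag would point to a delay; a low SNR at every lag would suggest the streaming and offline paths really compute different outputs.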

How can TSE performance be further improved?

Hi,
Thank you very much for open-sourcing this project. It is a great project and very helpful to me: I only recently started working on TSE, and there are still relatively few TSE projects and demos on GitHub.

I first read the source code carefully and then ran tse/demo_app, testing with skim_causal_460_wNoise_IS_tsdr.ckpt; the results were not very satisfactory.

I also used the same model for offline testing. Compared on the synthetic data from https://github.com/eeskimez/pse-samples, it performed worse than pdcattunet.

Based on your current TSE results, do you see any way to improve performance? Do you have any suggestions?
Thanks!
