mcw519 / puresound
Make the sound you hear pure and clean by deep learning.
Hi Wu,
How do I train the TSE model? I am using this command:
CUDA_VISIBLE_DEVICES=2 python main.py conf/libri2mix_max_2spk_clean_16k_1c.yaml --backend cuda
Even with the batch size set to 1, I still get the following error. Is there any way to fix it?
init loss function: sisnr
Initialized a multi-task model, sharing a same speech encoder.
Multi-task training has two loss function.
Current task label: 1
---------------Verbose logging---------------
Current training mode is: False
Total params: 6375442
Lookahead(samples): 16
Receptive Fields(samples): infinite
Current training mode is: True
---------------Verbose logging---------------
0%| | 0/1737 [00:00<?, ?it/s]epoch: 0, iter: 1, batch_loss: 44.908958435058594, signal_loss: 39.039031982421875, class_loss: 117.39849853515625
0%| | 1/1737 [00:01<46:59, 1.62s/it]epoch: 0, iter: 2, batch_loss: 29.754398345947266, signal_loss: 23.969154357910156, class_loss: 115.7048568725586
0%| | 1/1737 [00:02<1:17:31, 2.68s/it]
Traceback (most recent call last):
File "main.py", line 466, in <module>
main(config)
File "main.py", line 154, in main
trainer.train()
File "/home/yangjie/PureSound/egs/tse/puresound/task/base.py", line 383, in train
loss = self.train_one_epoch(current_epoch=epoch)
File "/home/yangjie/PureSound/egs/tse/puresound/task/tse.py", line 597, in train_one_epoch
loss.backward()
File "/home/yangjie/miniconda3/envs/pyannote/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/yangjie/miniconda3/envs/pyannote/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA out of memory. Tried to allocate 570.00 MiB (GPU 0; 23.70 GiB total capacity; 20.34 GiB already allocated; 380.81 MiB free; 21.50 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
My GPU is an RTX 3090 Ti.
My config:
DATASET:
type: TSE
sample_rate: 16000
max_length: 3 # input snipit length
train: /home/yangjie/PureSound/egs/tse/train-100-clean # train set folder
dev: /home/yangjie/PureSound/egs/tse/dev-clean # dev set folder
eval: /home/yangjie/PureSound/egs/tse/test-clean # eval set folder for tensorboard logging
noise_folder: # if not None, applies noise inject
rir_folder: # if not None, applies RIR
rir_mode: # image/direct/early
vol_perturbed: # Tuple
speed_perturbed: False
perturb_frequency_response: False
single_spk_prob: 0.
inactive_training: 0.
enroll_rule: fixed_length
enroll_augment: False
MODEL:
type: tse_skim_v0_causal # model-name
LOSS:
sig_loss: sisnr
sig_threshold: #-30
alpha: 0.05
cls_loss: ge2e
cls_loss_other:
embed_dim: 192 # for aamsoftmax
n_class: 251 # for aamsoftmax
margin: 0.2 # for aamsoftmax
scale: 30 # for aamsoftmax
OPTIMIZER:
gradiend_clip: 10
lr: 0.0001
num_epochs_decay: 0
lr_scheduler: Plateau
mode: min
patience: 5
gamma: 0.5
beta1: 0.9
beta2: 0.999
weight_decay: 0.
multi_rate: False
TRAIN:
num_epochs: 200
resume_epoch:
contrastive_learning: True
p_spks: 12
p_utts: 4
repeat: 3
batch_size: 1
multi_gpu: False
num_workers: 2
model_average:
use_tensorboard: True
model_save_dir: models # model save folder
log_dir: logs # logging save folder
The original libri2mix_max_2spk_clean_16k_1c.yaml is missing some fields, which causes a crash. I added them by comparing it with TSE.yaml:
diff --git a/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml b/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml
index 6b8b089..8de0aae 100644
--- a/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml
+++ b/egs/tse/conf/libri2mix_max_2spk_clean_16k_1c.yaml
@@ -2,14 +2,15 @@ DATASET:
type: TSE
sample_rate: 16000
max_length: 3 # input snipit length
- train: train # train set folder
- dev: dev # dev set folder
- eval: eval # eval set folder for tensorboard logging
+ train: /home/yangjie/PureSound/egs/tse/train-100-clean # train set folder
+ dev: /home/yangjie/PureSound/egs/tse/dev-clean # dev set folder
+ eval: /home/yangjie/PureSound/egs/tse/test-clean # eval set folder for tensorboard logging
noise_folder: # if not None, applies noise inject
rir_folder: # if not None, applies RIR
rir_mode: # image/direct/early
vol_perturbed: # Tuple
speed_perturbed: False
+ perturb_frequency_response: False
single_spk_prob: 0.
inactive_training: 0.
enroll_rule: fixed_length
@@ -23,6 +24,7 @@ LOSS:
sig_threshold: #-30
alpha: 0.05
cls_loss: ge2e
+ cls_loss_other:
embed_dim: 192 # for aamsoftmax
n_class: 251 # for aamsoftmax
margin: 0.2 # for aamsoftmax
@@ -39,6 +41,7 @@ OPTIMIZER:
beta1: 0.9
beta2: 0.999
weight_decay: 0.
+ multi_rate: False
TRAIN:
num_epochs: 200
@@ -47,9 +50,9 @@ TRAIN:
p_spks: 12
p_utts: 4
repeat: 3
- batch_size: 12
- multi_gpu: True
- num_workers: 10
+ batch_size: 1
+ multi_gpu: False
+ num_workers: 2
model_average:
use_tensorboard: True
model_save_dir: models # model save folder
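Since a config that lacks a field the code reads leads to a crash, one way to catch this before launching training is to compare the key sets of the two YAML files. The helper below is hypothetical and not part of puresound: it flattens nested keys into dotted paths and lists those present in a reference config (such as TSE.yaml) but absent from a candidate config. It assumes PyYAML for loading. An empty value such as `cls_loss_other:` loads as `None`, so the key still counts as present, which is what the training code needs.

```python
from typing import Any, Dict, List


def flatten_keys(cfg: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return every key as a dotted path, e.g. 'LOSS.cls_loss_other'."""
    keys: List[str] = []
    for key, value in cfg.items():
        path = f"{prefix}{key}"
        keys.append(path)
        if isinstance(value, dict):
            # descend into nested sections such as DATASET or TRAIN
            keys.extend(flatten_keys(value, path + "."))
    return keys


def missing_keys(reference: Dict[str, Any], candidate: Dict[str, Any]) -> List[str]:
    """Keys present in `reference` but absent from `candidate`, in reference order."""
    present = set(flatten_keys(candidate))
    return [k for k in flatten_keys(reference) if k not in present]


def load_yaml(path: str) -> Dict[str, Any]:
    """Load a YAML config with PyYAML (imported lazily so the helpers above stay dependency-free)."""
    import yaml

    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
```

For example, `missing_keys(load_yaml("conf/TSE.yaml"), load_yaml("conf/libri2mix_max_2spk_clean_16k_1c.yaml"))` would list the fields to add; the `conf/TSE.yaml` path is an assumption about where that file lives.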
There is something strange here: no matter what I set batch_size to, the progress bar always reports 1737 iterations per epoch (e.g. 1/1737). As I understand it, changing the batch size should change that count. I also printed hparam["TRAIN"]["batch_size"] in the init_dataloader function, and the setting does take effect.
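For a plain PyTorch DataLoader, the number of iterations per epoch is ceil(N / batch_size), so a count that stays at 1737 suggests the batch composition is decided somewhere else. One possibility, which is an assumption not verified against the puresound source, is that with `contrastive_learning: True` and a GE2E classification loss the loader groups utterances by speaker using `p_spks` and `p_utts`. In that case `batch_size` would not be consulted, and each step could hold p_spks × p_utts = 12 × 4 = 48 utterances, which would also help explain the out-of-memory error at `batch_size: 1`. The hypothetical sampler below illustrates the idea; it yields lists of indices, the form that `DataLoader(batch_sampler=...)` accepts.

```python
import math
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence


def plain_loader_iters(num_items: int, batch_size: int) -> int:
    """Iterations per epoch of a plain DataLoader with drop_last=False."""
    return math.ceil(num_items / batch_size)


class SpeakerGroupedBatchSampler:
    """Hypothetical GE2E-style sampler: p_spks speakers x p_utts utterances per batch.

    batch_size is never consulted; the number of batches depends only on the speakers.
    """

    def __init__(self, speakers: Sequence[str], p_spks: int, p_utts: int, seed: int = 0):
        self.by_spk: Dict[str, List[int]] = defaultdict(list)
        for idx, spk in enumerate(speakers):
            self.by_spk[spk].append(idx)
        # only speakers with at least p_utts utterances can fill a slot
        self.eligible = [s for s, ids in self.by_spk.items() if len(ids) >= p_utts]
        self.p_spks = p_spks
        self.p_utts = p_utts
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return len(self.eligible) // self.p_spks

    def __iter__(self) -> Iterator[List[int]]:
        spks = list(self.eligible)
        self.rng.shuffle(spks)
        for b in range(len(self)):
            batch: List[int] = []
            # p_spks distinct speakers, p_utts distinct utterances from each
            for spk in spks[b * self.p_spks:(b + 1) * self.p_spks]:
                batch.extend(self.rng.sample(self.by_spk[spk], self.p_utts))
            yield batch
```

If the real loader works like this, lowering `p_spks` or `p_utts`, rather than `batch_size`, would be the setting that shrinks each training step. Note that PyTorch's DataLoader treats `batch_sampler` as mutually exclusive with `batch_size`.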
Thanks!
Inference results: ns_data.zip
Package Version
------------------------- ------------
absl-py 1.4.0
aiohttp 3.8.4
aiosignal 1.3.1
alembic 1.9.4
antlr4-python3-runtime 4.8
appdirs 1.4.4
asteroid 0.6.0
asteroid-filterbanks 0.4.0
asttokens 2.2.1
async-timeout 4.0.2
attrdict 2.0.1
attrs 22.2.0
audioread 3.0.0
backcall 0.2.0
backports.cached-property 1.0.2
brotlipy 0.7.0
brouhaha 0.9.0
cached-property 1.5.2
cachetools 5.3.0
certifi 2022.12.7
cffi 1.15.1
charset-normalizer 2.0.4
click 8.1.3
cmaes 0.9.1
colorama 0.4.6
coloredlogs 15.0.1
colorlog 6.7.0
comm 0.1.2
commonmark 0.9.1
conda 23.1.0
conda-package-handling 2.0.2
conda_package_streaming 0.7.0
contourpy 1.0.7
cryptography 38.0.4
cycler 0.11.0
Cython 0.29.34
debugpy 1.6.6
decorator 5.1.1
DeepFilterDataLoader 0.4.0
DeepFilterLib 0.4.0
deepfilternet 0.4.0
docopt 0.6.2
einops 0.3.2
et-xmlfile 1.1.0
exceptiongroup 1.1.1
executing 1.2.0
filelock 3.9.0
flatbuffers 23.1.21
flit_core 3.6.0
fonttools 4.38.0
frozenlist 1.3.3
fsspec 2023.1.0
future 0.18.3
google-auth 2.16.1
google-auth-oauthlib 0.4.6
greenlet 2.0.2
grpcio 1.51.3
hmmlearn 0.2.8
huggingface-hub 0.12.1
humanfriendly 10.0
Hydra 2.5
hydra-core 1.1.0
HyperPyYAML 1.1.0
icecream 2.1.3
idna 3.4
importlib-metadata 6.0.0
importlib-resources 5.12.0
iniconfig 2.0.0
ipyhton 0.1
ipykernel 6.21.2
ipython 8.10.0
jedi 0.18.2
joblib 1.2.0
julius 0.2.7
jupyter_client 8.0.3
jupyter_core 5.2.0
kiwisolver 1.4.4
librosa 0.9.2
llvmlite 0.39.1
loguru 0.7.0
Mako 1.2.4
Markdown 3.4.1
MarkupSafe 2.1.2
matplotlib 3.7.0
matplotlib-inline 0.1.6
MinDAEC 0.0.2
mir-eval 0.7
mkl-fft 1.3.1
mkl-random 1.2.2
mkl-service 2.4.0
mpmath 1.2.1
multidict 6.0.4
multiprocessing-logging 0.3.4
natsort 8.3.1
nest-asyncio 1.5.6
networkx 2.8.8
numba 0.56.4
numpy 1.23.5
oauthlib 3.2.2
omegaconf 2.1.2
onnx 1.13.1
onnx-simplifier 0.4.17
onnxruntime 1.14.0
openpyxl 3.1.2
openyxl 0.1
optuna 3.1.0
packaging 23.0
pandas 1.5.3
parso 0.8.3
pb-bss-eval 0.0.2
pesq 0.0.4
pexpect 4.8.0
pickleshare 0.7.5
Pillow 9.3.0
pip 22.3.1
pipdeptree 2.5.0
platformdirs 3.0.0
pluggy 1.0.0
pooch 1.6.0
primePy 1.3
prompt-toolkit 3.0.37
protobuf 3.20.3
psutil 5.9.4
ptyprocess 0.7.0
pure-eval 0.2.2
pyannote.audio 2.1.1
pyannote.core 4.5
pyannote.database 4.1.3
pyannote.metrics 3.2.1
pyannote.pipeline 2.3
pyasn1 0.4.8
pyasn1-modules 0.2.8
pybind11 2.10.4
pycosat 0.6.4
pycparser 2.21
pyDeprecate 0.3.2
Pygments 2.14.0
pyOpenSSL 22.0.0
pyparsing 3.0.9
pyreadline 2.1
pyroomacoustics 0.7.3
PySocks 1.7.1
pystoi 0.3.3
pytest 7.3.1
python-dateutil 2.8.2
pytorch-lightning 1.6.5
pytorch-metric-learning 1.7.3
pytorch-ranger 0.1.1
pytz 2022.7.1
PyYAML 6.0
pyzmq 25.0.0
requests 2.28.1
requests-oauthlib 1.3.1
resampy 0.4.2
rich 12.6.0
rsa 4.9
ruamel.yaml 0.17.21
ruamel.yaml.clib 0.2.7
scikit-learn 1.2.1
scipy 1.10.1
semver 2.13.0
sentencepiece 0.1.97
setuptools 65.6.3
shellingham 1.5.0.post1
simplejson 3.18.3
singledispatchmethod 1.0
six 1.16.0
sortedcontainers 2.4.0
SoundFile 0.10.3.post1
speechbrain 0.5.13
SQLAlchemy 2.0.4
stack-data 0.6.2
sympy 1.11.1
tabulate 0.9.0
tensorboard 2.12.0
tensorboard-data-server 0.7.0
tensorboard-plugin-wit 1.8.1
threadpoolctl 3.1.0
toml 0.10.2
tomli 2.0.1
toolz 0.12.0
torch 1.11.0
torch-audiomentations 0.11.0
torch-optimizer 0.1.0
torch-pitch-shift 1.2.2
torch-stoi 0.1.2
torch-tb-profiler 0.4.1
torchaudio 0.11.0
torchinfo 1.7.2
torchmetrics 0.7.3
torchsummary 1.5.1
torchvision 0.12.0
tornado 6.2
tqdm 4.64.1
traitlets 5.9.0
typer 0.7.0
typing_extensions 4.4.0
urllib3 1.26.14
wcwidth 0.2.6
Werkzeug 2.2.3
wheel 0.38.4
yarl 1.8.2
zipp 3.14.0
zstandard 0.19.0
Hi Wu
When speed_perturbed is set to True in the config, it crashes. I traced the problem to _align_waveform: this function does not actually align the data, and there may be a logic error. After the following modification it no longer crashes. Thanks!
from typing import Tuple

import torch
import torch.nn.functional as F


def _align_waveform(
    self, enh_wav: torch.Tensor, ref_wav: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Assume last axis is the time."""
    enh_wav_l = enh_wav.shape[-1]
    ref_wav_l = ref_wav.shape[-1]
    if enh_wav_l != ref_wav_l:
        if ref_wav_l < enh_wav_l:
            # align from last: left-pad the reference with zeros
            pad_num = enh_wav_l - ref_wav_l
            ref_wav = F.pad(ref_wav, (pad_num, 0))
        else:
            # align from begin: truncate the reference to the enhanced length
            ref_wav = ref_wav[..., :enh_wav_l]  # changed here
    return enh_wav, ref_wav
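The sketch below restates the patched function in NumPy so its behavior can be checked without a model. It assumes, without having checked puresound's augmentation code, that speed perturbation resamples the waveform so its length scales by about 1/speed, which is how the enhanced and reference signals end up with different lengths. `perturbed_length` and `align_waveform` are illustrative names, not puresound functions.

```python
import numpy as np


def perturbed_length(num_samples: int, speed: float) -> int:
    """Approximate length after resampling-based speed perturbation (assumption)."""
    return int(round(num_samples / speed))


def align_waveform(enh_wav: np.ndarray, ref_wav: np.ndarray):
    """NumPy restatement of the patched _align_waveform; time is the last axis."""
    enh_len, ref_len = enh_wav.shape[-1], ref_wav.shape[-1]
    if ref_len < enh_len:
        # left-pad the reference with zeros so both signals end together
        pad = [(0, 0)] * (ref_wav.ndim - 1) + [(enh_len - ref_len, 0)]
        ref_wav = np.pad(ref_wav, pad)
    elif ref_len > enh_len:
        # keep only the first enh_len samples of the reference
        ref_wav = ref_wav[..., :enh_len]
    return enh_wav, ref_wav
```

For a 3 s snippet at 16 kHz (48000 samples, matching `max_length: 3` and `sample_rate: 16000` in the config), a speed factor of 1.1 gives 43636 samples and 0.9 gives 53333; after alignment both references are 48000 samples long, the same as the enhanced signal.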
Hi
I want to test your TSE model, but I could not get tse/demo/demo_app.py to run on Windows.
So I read the code under the tse folder carefully and wrote a simple test script. Streaming and non-streaming inference give me different results. Is that expected? Is my test code correct? Thanks!
import torch
from utils import DemoSpeakerNet, DemoTseNet
import soundfile
from model import init_model

if __name__ == '__main__':
    ckpt = torch.load("onnx_tse/skim_causal_460_wNoise_IS_tsdr.ckpt", map_location="cpu")
    enroll_wav_file = 'onnx_tse/s1/61-70968-0003_5105-28240-0000.wav'
    mix_wav_file = 'onnx_tse/mix/61-70968-0003_5105-28240-0000.wav'
    # read as float32 and add a batch axis: shape (1, num_samples)
    enroll_wav = torch.from_numpy(soundfile.read(enroll_wav_file, dtype="float32")[0].reshape(1, -1))
    mix_wav = torch.from_numpy(soundfile.read(mix_wav_file, dtype="float32")[0].reshape(1, -1))

    # streaming path: speaker embedding from the enrollment, then chunk-wise extraction
    speaker_net = DemoSpeakerNet()
    speaker_net.load_state_dict(ckpt["state_dict"], strict=False)
    speaker_net.eval()
    tse_net = DemoTseNet()
    tse_net.load_state_dict(ckpt["state_dict"], strict=False)
    tse_net.eval()
    tse_net.masker.init_status()  # initialize the masker's streaming status
    speaker_embedding = speaker_net.get_speaker_embedding(enroll_wav)
    stream_enh_wav = tse_net.streaming_inference_chunk(mix_wav, speaker_embedding)
    soundfile.write('onnx_tse/enh_stream.wav', stream_enh_wav, 16000)

    # offline (non-streaming) path on the same inputs
    model = init_model('tse_skim_v0_causal', verbose=False)
    model.load_state_dict(ckpt["state_dict"], strict=False)  # ignore loss's weight
    model.eval()
    enh_wav = model.inference(mix_wav, enroll_wav)
    enh_wav = enh_wav.detach().cpu().numpy().reshape(-1)
    soundfile.write('onnx_tse/enh.wav', enh_wav, 16000)
    print('Hello world')
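To judge whether the two outputs differ only slightly or substantially, one option is to score one against the other. The helper below is an illustrative sketch, not part of the repository: it trims both signals to the shorter length (the streaming path may return a different number of samples) and reports the maximum absolute difference and the SI-SNR, the same metric the training config uses as `sig_loss`. It expects 1-D float arrays, such as `stream_enh_wav` (converted to NumPy if it is a tensor) and `enh_wav` from the script above.

```python
import numpy as np


def si_snr(est: np.ndarray, ref: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SNR in dB of `est` measured against `ref` (both 1-D)."""
    est = est - est.mean()
    ref = ref - ref.mean()
    # project the estimate onto the reference to get the target component
    target = (np.dot(est, ref) / (np.dot(ref, ref) + eps)) * ref
    noise = est - target
    return float(10 * np.log10((np.dot(target, target) + eps) / (np.dot(noise, noise) + eps)))


def compare_outputs(stream_wav, offline_wav) -> dict:
    """Trim to the common length, then report max |difference| and SI-SNR."""
    n = min(len(stream_wav), len(offline_wav))
    s = np.asarray(stream_wav[:n], dtype=np.float64)
    o = np.asarray(offline_wav[:n], dtype=np.float64)
    return {"max_abs_diff": float(np.max(np.abs(s - o))), "si_snr_db": si_snr(s, o)}
```

A very high SI-SNR (tens of dB) points to small numerical differences between the two paths, while a low value suggests they are not computing the same thing. If the streaming output is delayed (the training log above reports `Lookahead(samples): 16` for this model type), shifting one signal before comparing may be necessary.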
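Hello,
What is the difference between the two models skim_causal_460_wNoise_IS_tsdr.ckpt and libri2mix_max_2spk_clean_16k_1c.ckpt? I inspected both with Netron and found that their parameter structures are identical; only the loss functions differ somewhat. Does skim_causal_460_wNoise_IS_tsdr.ckpt perform better? Was it trained on more data?
Thanks!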
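The Netron observation can also be checked in code by comparing the parameter names and shapes stored in the two checkpoints. The helper below is hypothetical and works on plain `{name: shape}` dictionaries, so it needs no model code. Keys present in only one checkpoint (for example, loss-specific weights, which the test script above skips with `strict=False`) are reported separately from shape mismatches.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dicts(a: Dict[str, Shape], b: Dict[str, Shape]) -> Dict[str, List[str]]:
    """Compare two {parameter_name: shape} maps by name and by shape."""
    only_a = sorted(set(a) - set(b))
    only_b = sorted(set(b) - set(a))
    # parameters present in both but stored with different shapes
    shape_mismatch = sorted(k for k in set(a) & set(b) if tuple(a[k]) != tuple(b[k]))
    return {"only_in_a": only_a, "only_in_b": only_b, "shape_mismatch": shape_mismatch}
```

The map for one checkpoint can be built with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu")["state_dict"].items()}`, assuming every value in `state_dict` is a tensor. Matching names and shapes only show that the architectures agree; the trained weights, and therefore the data and loss used to produce them, can still differ.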
Hi
Thank you very much for open-sourcing this project. It is a great project and very helpful to me, since I only recently started working on TSE and there are relatively few TSE projects and demos on GitHub.
I first read the source code carefully, then ran tse/demo_app with skim_causal_460_wNoise_IS_tsdr.ckpt; the results were not very satisfactory.
I also used the same model for offline testing. Comparing on the synthetic data from https://github.com/eeskimez/pse-samples, I found that the results were worse than those of pdcattunet.
Based on your current TSE results, is there any way to improve them? Do you have any suggestions?
Thanks!