
fast_bss_eval's Introduction

fast_bss_eval


Do you have a zillion BSS audio files to process, and it is taking days? Is your simulation never-ending?

Fear no more! fast_bss_eval is here to help you!

fast_bss_eval is a fast implementation of the bss_eval metrics for evaluating blind source separation (BSS). Compared to other existing implementations, it has the following advantages.

  • seamlessly works with both numpy arrays and pytorch tensors
  • very fast
  • can be even faster with an iterative solver (add the use_cg_iter=10 option to the function call)
  • differentiable via pytorch
  • can run on GPU via pytorch

Author

Robin Scheibler

Quick Start

Install

# from pypi
pip install fast-bss-eval

# or from source
git clone https://github.com/fakufaku/fast_bss_eval
cd fast_bss_eval
pip install -e .

Use

Assuming you have multichannel signals for the estimated and reference sources stored in wav files named my_estimate_file.wav and my_reference_file.wav, respectively, you can quickly evaluate the bss_eval metrics as follows.

from scipy.io import wavfile
import fast_bss_eval

# read the files; we assume both have the same sampling rate
fs, ref = wavfile.read("my_reference_file.wav")
_, est = wavfile.read("my_estimate_file.wav")

# compute the metrics; wavfile.read returns arrays of shape
# (n_samples, n_channels), so transpose them to (n_channels, n_samples)
sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref.T, est.T)
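The example transposes with .T because scipy.io.wavfile.read returns an array of shape (n_samples, n_channels) for multichannel files. For mono files it returns a 1-D array of shape (n_samples,), which .T leaves unchanged. A small helper such as the hypothetical to_channels_first below (not part of fast_bss_eval) makes the (n_channels, n_samples) layout explicit for both cases; the conversion to float64 is a precaution, not a documented requirement of the package.

```python
import numpy as np


def to_channels_first(x):
    """Return audio as a float64 array of shape (n_channels, n_samples).

    scipy.io.wavfile.read returns shape (n_samples,) for mono files and
    (n_samples, n_channels) for multichannel files.
    """
    x = np.asarray(x, dtype=np.float64)  # e.g. int16 PCM -> float (precaution)
    if x.ndim == 1:
        return x[None, :]  # mono: add a leading channel axis
    if x.ndim == 2:
        return x.T  # multichannel: move the channel axis first
    raise ValueError(f"expected a 1-D or 2-D audio array, got shape {x.shape}")


# synthetic arrays standing in for wavfile.read output
mono = np.zeros(16000, dtype=np.int16)
stereo = np.zeros((16000, 2), dtype=np.int16)
print(to_channels_first(mono).shape)  # (1, 16000)
print(to_channels_first(stereo).shape)  # (2, 16000)
```

With such a helper, the call above would read `fast_bss_eval.bss_eval_sources(to_channels_first(ref), to_channels_first(est))`.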

Benchmark

This package is significantly faster than other packages that compute bss_eval metrics, such as mir_eval or sigsep/bsseval. We benchmarked it using numpy and torch, single and double precision floating point arithmetic (fp32/fp64), and either Gaussian elimination (solve) or conjugate gradient descent (CGD10).
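Both solvers in the benchmark address the same normal equations for the bss_eval distortion filter, whose matrix is built from autocorrelations of the references and is therefore Toeplitz (block-Toeplitz for several sources). The NumPy sketch below, for a single source, compares a direct solve with 10 plain, unpreconditioned conjugate gradient iterations; CGD10 presumably corresponds to use_cg_iter=10. The package itself uses a preconditioned block-Toeplitz variant (its traceback in the issues below mentions block_toeplitz_conjugate_gradient and BlockCirculantPreconditionerOperator), so all function names here are illustrative only.

```python
import numpy as np


def xcorr(a, b, n_lags):
    """Biased cross-correlation r[k] = sum_n a[n] * b[n + k] / N for k = 0..n_lags-1."""
    n = len(a)
    return np.array([np.dot(a[: n - k], b[k:]) for k in range(n_lags)]) / n


def toeplitz_from_acf(acf):
    """Symmetric Toeplitz matrix T[i, j] = acf[|i - j|]."""
    idx = np.arange(len(acf))
    return acf[np.abs(idx[:, None] - idx[None, :])]


def conjugate_gradient(A, b, n_iter=10, tol=1e-12):
    """At most n_iter CG iterations for A x = b, with A symmetric positive definite."""
    x = np.zeros_like(b)
    r = b - A @ x  # residual
    p = r.copy()  # search direction
    rs_old = r @ r
    for _ in range(n_iter):
        if np.sqrt(rs_old) < tol:  # converged: avoid dividing by ~0
            break
        Ap = A @ p
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


# the estimate is the reference through a short FIR filter plus a little noise
rng = np.random.default_rng(0)
n, filter_length = 4000, 16
ref = rng.standard_normal(n)
est = np.convolve(ref, [1.0, 0.5, 0.25])[:n] + 0.1 * rng.standard_normal(n)

A = toeplitz_from_acf(xcorr(ref, ref, filter_length))
b = xcorr(ref, est, filter_length)
h_solve = np.linalg.solve(A, b)  # direct solve (Gaussian elimination)
h_cg = conjugate_gradient(A, b, n_iter=10)  # 10 CG iterations
print(np.round(h_solve[:4], 2))
print(np.max(np.abs(h_solve - h_cg)))
```

Each CG iteration needs only a matrix-vector product, which for Toeplitz matrices can be computed with FFTs, while a dense direct solve costs O(L^3) for an L x L system. With well-conditioned matrices such as this white-noise example, a few iterations already match the direct solution closely.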

Citation

If you use this package in your own research, please cite our paper describing it.

@misc{scheibler_sdr_2021,
  title={SDR --- Medium Rare with Fast Computations},
  author={Robin Scheibler},
  year={2021},
  eprint={2110.06440},
  archivePrefix={arXiv},
  primaryClass={eess.AS}
}

License

2021 (c) Robin Scheibler, LINE Corporation

All of this code is released under MIT License with the exception of fast_bss_eval/torch/hungarian.py which is under 3-clause BSD License.

fast_bss_eval's People

Contributors

fakufaku


fast_bss_eval's Issues

How to evaluate SIR and SDR for mono wav file

Hello.

I have a question about how to evaluate SIR and SDR for mono wav files.

I have the following mono wav files.

  • Mixed voice and noise audio
  • Voice audio (ref.wav)
  • Noise audio
  • Inference file (est.wav)

The wav files are 4 seconds long, and the sampling frequency is 16 kHz.
When I computed the SIR for the mono wav file, it was Inf.
As I asked in Issue #12, the following code gives an SIR of Inf.

from scipy.io import wavfile
import numpy as np
import fast_bss_eval

_, ref = wavfile.read("./data/ref.wav")
_, est = wavfile.read("./data/est.wav")

ref = ref[None, ...]
est = est[None, ...]

# compute the metrics
sdr, sir, sar = fast_bss_eval.bss_eval_sources(ref, est, compute_permutation=False)

print('sdr:', sdr)
print('sir:', sir)
print('sar:', sar)

sdr: 14.188884277900977
sir: inf
sar: 14.18888427790095

However, I would like to evaluate the SIR with a mono wav file.
To avoid an infinite SIR, I split the wav file into 4 parts. Does the following code evaluate SIR and SDR correctly?

from scipy.io import wavfile
import numpy as np
import fast_bss_eval

ref = np.zeros((4, 16000))
est = np.zeros((4, 16000))

# read the four 1-second segments (16 kHz) into the rows of ref and est
for i in range(4):
    _, ref[i] = wavfile.read(f"./data/ref{i + 1}.wav")
    _, est[i] = wavfile.read(f"./data/est{i + 1}.wav")

# compute the metrics
sdr, sir, sar = fast_bss_eval.bss_eval_sources(ref, est, compute_permutation=False)

print('sdr:', sdr.mean())
print('sir:', sir.mean())
print('sar:', sar.mean())

sdr: 16.156123610321156
sir: 28.957842593289392
sar: 16.444840346137177

What signals are needed for each channel of ref and est?
Best regards.
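In bss_eval, each row of ref is one source, and SIR measures how much of the other reference sources leaks into an estimate. With a single reference row there is nothing that can leak in, so the interference term vanishes and SIR is infinite, as in the output above. The NumPy sketch below illustrates this with a simplified, hypothetical bss_eval_zero_lag that uses zero-lag projections instead of the FIR distortion filters of the real metrics, so its numbers will not match fast_bss_eval's. It also shows that stacking voice and noise as two reference rows yields a finite SIR.

```python
import numpy as np


def bss_eval_zero_lag(ref, est):
    """Simplified bss_eval using zero-lag projections (distortion filter of length 1).

    ref, est: arrays of shape (n_src, n_samples); est[j] is scored against ref[j].
    Returns (sdr, sir, sar) in dB, one value per source.
    """
    gram = ref @ ref.T  # (n_src, n_src) correlations between the references
    sdr, sir, sar = [], [], []
    for j in range(ref.shape[0]):
        corr = ref @ est[j]  # correlation of estimate j with every reference
        s_target = corr[j] / gram[j, j] * ref[j]  # projection onto ref[j] only
        p_all = np.linalg.solve(gram, corr) @ ref  # projection onto all references
        e_interf = p_all - s_target  # leakage from the other references
        e_artif = est[j] - p_all  # what no reference can explain
        with np.errstate(divide="ignore"):
            sdr.append(10 * np.log10(np.sum(s_target**2) / np.sum((e_interf + e_artif) ** 2)))
            sir.append(10 * np.log10(np.sum(s_target**2) / np.sum(e_interf**2)))
            sar.append(10 * np.log10(np.sum((s_target + e_interf) ** 2) / np.sum(e_artif**2)))
    return np.array(sdr), np.array(sir), np.array(sar)


rng = np.random.default_rng(0)
voice = rng.standard_normal(16000)
noise = rng.standard_normal(16000)
artifacts = 0.05 * rng.standard_normal(16000)

# one reference row: the interference term is zero (up to rounding),
# so SIR is infinite or an enormous, meaningless number
_, sir_mono, _ = bss_eval_zero_lag(voice[None, :], (voice + 0.1 * noise + artifacts)[None, :])

# voice and noise as two reference rows: leakage of the other source is measured
ref = np.stack([voice, noise])
est = np.stack([voice + 0.1 * noise + artifacts, noise + 0.1 * voice + artifacts])
sdr, sir, sar = bss_eval_zero_lag(ref, est)
print(sir_mono, sir)
```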

Application Extension: How to use this metric in Short-Time Fourier Transform (STFT) domain?

Many BSS methods estimate coefficients in the Short-Time Fourier Transform (STFT) domain and multiply them with the signal in the same domain. As I understand it, this metric is evaluated in the time domain, and I am wondering how it could be used directly in the STFT domain (perhaps by finding the frequency-domain mapping of the distortion filter?). If that worked, could this metric be extended directly to neural networks that operate in the STFT domain?
Thanks for your brilliant, creative work!

Compatibility problem with torch >= 1.8.0 when torch_complex package is not installed

Hello,
I noticed that with the package (version 0.1.3), I get compatibility issues when passing torch.Tensor inputs to bss_eval_sources, because I do not have the torch_complex package installed. However, torch_complex shouldn't be required in this case, since I use torch 1.10.2.

This happens because, in the __init__.py file, the import of the torch submodule fails, so has_torch ends up False even though torch is installed:

try:
    import torch as pt
    has_torch = True

    from . import torch as torch     # --> this line fails
    from .torch import sdr_pit_loss, si_sdr_pit_loss   
except ImportError:
    has_torch = False

    # dummy pytorch module
    class pt:
        class Tensor:
            def __init__(self):
                pass

    # dummy torch submodule
    class torch:
        bss_eval_sources = None
        sdr = None
        sdr_loss = None

from . import numpy as numpy

Apparently, this happens because the failing line imports the file torch/compatibility.py, which imports torch_complex unconditionally:

try:
    from packaging.version import Version
except [ImportError, ModuleNotFoundError]:
    from distutils.version import LooseVersion as Version

from torch_complex import ComplexTensor # --> this line causes the problem when torch_complex is not installed 

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    try:
        import torch_complex
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )

If I understand correctly, the fix would simply be the following:

try:
    from packaging.version import Version
except (ImportError, ModuleNotFoundError):
    from distutils.version import LooseVersion as Version

import torch

is_torch_1_8_plus = Version(torch.__version__) >= Version("1.8.0")

if not is_torch_1_8_plus:
    try:
        from torch_complex import ComplexTensor 
    except ImportError:
        raise ImportError(
            "When using torch<=1.7, the package torch_complex is required."
            " Install it as `pip install torch_complex`"
        )
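Separately from the torch_complex problem, the quoted compatibility.py catches `except [ImportError, ModuleNotFoundError]:`. In Python 3, an except clause must name an exception class or a tuple of classes; a list is only checked when an exception actually arrives, and at that point Python raises TypeError instead of running the fallback. So if packaging were missing, the distutils fallback would not be reached. The standalone sketch below (function names are hypothetical) contrasts the two forms. Since ModuleNotFoundError is a subclass of ImportError, `except ImportError:` alone covers both.

```python
import importlib


def optional_import(name):
    """Return the module `name`, or None if it cannot be imported."""
    try:
        return importlib.import_module(name)
    except ImportError:  # also catches ModuleNotFoundError, its subclass
        return None


def import_with_list_clause(name):
    """Same as above, but with the exception types written as a list."""
    try:
        return importlib.import_module(name)
    except [ImportError, ModuleNotFoundError]:  # a list, not a tuple
        return None


print(optional_import("a_module_that_does_not_exist_xyz"))  # None
try:
    import_with_list_clause("a_module_that_does_not_exist_xyz")
except TypeError as err:
    print("TypeError:", err)
```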

RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0

Hi,

I'm trying to use the torch.bss_eval_sources and I'm getting the error:

fast_bss_eval/torch/helpers.py", line 142, in _linear_sum_assignment_with_inf 
  m = values.min()
RuntimeError: min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

Any pointers on that?
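The traceback shows `values.min()` failing on an empty tensor inside `_linear_sum_assignment_with_inf`. Assuming, as in the common workaround for assignment solvers that reject infinite costs, that `values` holds the finite entries of the cost matrix (the helper's code is not shown here, so this is an assumption), the error would occur when every entry is infinite. The hypothetical NumPy sketch below guards against that case; a brute-force search over permutations stands in for the Hungarian algorithm and is only practical for a few sources. Matrices mixing +inf and -inf are not treated carefully.

```python
import itertools

import numpy as np


def replace_inf(cost):
    """Replace +/-inf entries of a cost matrix by finite stand-ins."""
    cost = np.array(cost, dtype=np.float64)
    inf_mask = np.isinf(cost)
    if not inf_mask.any():
        return cost
    finite = cost[~inf_mask]
    if finite.size == 0:
        lo, hi = 0.0, 0.0  # all entries infinite: no finite range to extend
    else:
        lo, hi = finite.min(), finite.max()
    # push infinite entries far enough beyond the finite range to dominate any sum
    big = (min(cost.shape) + 1) * (hi - lo + 1.0)
    cost[np.isposinf(cost)] = hi + big
    cost[np.isneginf(cost)] = lo - big
    return cost


def min_cost_permutation(cost):
    """Brute-force minimum-cost assignment for a small square cost matrix."""
    cost = replace_inf(cost)
    n = cost.shape[0]
    return min(
        itertools.permutations(range(n)),
        key=lambda perm: sum(cost[i, perm[i]] for i in range(n)),
    )


inf = np.inf
print(min_cost_permutation([[inf, inf], [inf, inf]]))  # no error; any permutation is optimal
print(min_cost_permutation([[5.0, -inf], [-inf, 3.0]]))  # (1, 0)
```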

Failing for my specific waveforms

Hello,

I have some waveforms on which the evaluation fails. The shape is [64, 2, 44100] (batch size 64, 2 channels, 44100 samples i.e. 1 second of music @ 44100 Hz sample rate)

I've attached the tensors (saved as .pt files), and my test looks like this:

"""
sevagh testing a random waveform
"""
import numpy as np
import torch
import pytest
from mir_eval.separation import bss_eval_sources
import fast_bss_eval

pred = torch.load("/waveform.pt", map_location=torch.device("cuda"))
target = torch.load("/waveform.pt", map_location=torch.device("cuda"))

print(pred)

if __name__ == "__main__":

    print(pred.shape, target.shape)
    print(pred.dtype, target.dtype)
    print(pred.device, target.device)
    print()

    sdr, sir, sar, perm = fast_bss_eval.torch.bss_eval_sources(target, pred, use_cg_iter=10)

    print(sdr, sdr.dtype)
    print()

The error output is:

root@0f6121f288b6:~/fast_bss_eval# python tests/test_sevagh_case.py
tensor([[[-0.1661, -0.1645, -0.1626,  ..., -0.1039, -0.1029, -0.1017],
         [-0.1661, -0.1644, -0.1626,  ..., -0.1039, -0.1029, -0.1018]],

        [[ 0.0318,  0.0322,  0.0325,  ...,  0.0085,  0.0079,  0.0073],
         [ 0.0298,  0.0304,  0.0307,  ...,  0.0115,  0.0108,  0.0102]],

        [[ 0.1203,  0.1211,  0.1217,  ...,  0.0269,  0.0282,  0.0295],
         [ 0.1187,  0.1194,  0.1201,  ...,  0.0259,  0.0272,  0.0285]],

        ...,

        [[ 0.0068,  0.0070,  0.0072,  ...,  0.0069,  0.0069,  0.0070],
         [ 0.0067,  0.0069,  0.0071,  ...,  0.0068,  0.0068,  0.0069]],

        [[-0.0176, -0.0196, -0.0215,  ...,  0.0601,  0.0605,  0.0608],
         [-0.0176, -0.0196, -0.0215,  ...,  0.0601,  0.0605,  0.0608]],

        [[ 0.0084,  0.0076,  0.0069,  ...,  0.0103,  0.0109,  0.0116],
         [ 0.0045,  0.0032,  0.0020,  ...,  0.0170,  0.0179,  0.0188]]],
       device='cuda:0')
tensor([[[-0.1661, -0.1645, -0.1626,  ..., -0.1039, -0.1029, -0.1017],
         [-0.1661, -0.1644, -0.1626,  ..., -0.1039, -0.1029, -0.1018]],

        [[ 0.0318,  0.0322,  0.0325,  ...,  0.0085,  0.0079,  0.0073],
         [ 0.0298,  0.0304,  0.0307,  ...,  0.0115,  0.0108,  0.0102]],

        [[ 0.1203,  0.1211,  0.1217,  ...,  0.0269,  0.0282,  0.0295],
         [ 0.1187,  0.1194,  0.1201,  ...,  0.0259,  0.0272,  0.0285]],

        ...,

        [[ 0.0068,  0.0070,  0.0072,  ...,  0.0069,  0.0069,  0.0070],
         [ 0.0067,  0.0069,  0.0071,  ...,  0.0068,  0.0068,  0.0069]],

        [[-0.0176, -0.0196, -0.0215,  ...,  0.0601,  0.0605,  0.0608],
         [-0.0176, -0.0196, -0.0215,  ...,  0.0601,  0.0605,  0.0608]],

        [[ 0.0084,  0.0076,  0.0069,  ...,  0.0103,  0.0109,  0.0116],
         [ 0.0045,  0.0032,  0.0020,  ...,  0.0170,  0.0179,  0.0188]]],
       device='cuda:0')
torch.Size([64, 2, 44100]) torch.Size([64, 2, 44100])
torch.float32 torch.float32
cuda:0 cuda:0

Traceback (most recent call last):
  File "tests/test_sevagh_case.py", line 23, in <module>
    sdr, sir, sar, perm = fast_bss_eval.torch.bss_eval_sources(target, pred, use_cg_iter=10)
  File "/root/fast_bss_eval/fast_bss_eval/torch/metrics.py", line 654, in bss_eval_sources
    coh_sdr, coh_sar = square_cosine_metrics(
  File "/root/fast_bss_eval/fast_bss_eval/torch/metrics.py", line 565, in square_cosine_metrics
    sol = block_toeplitz_conjugate_gradient(acf, xcorr, n_iter=use_cg_iter, x=x0)
  File "/root/fast_bss_eval/fast_bss_eval/torch/cgd.py", line 403, in block_toeplitz_conjugate_gradient
    precond = BlockCirculantPreconditionerOperator(acf)
  File "/root/fast_bss_eval/fast_bss_eval/torch/cgd.py", line 184, in __init__
    self.C = inv(C)
  File "/root/fast_bss_eval/fast_bss_eval/torch/compatibility.py", line 145, in inv
    return torch.linalg.inv(*args, **kwargs)
torch._C._LinAlgError: linalg.inv: (Batch element 770): The diagonal element 2 is zero, the inversion could not be completed because the input matrix is singular.

The waveforms are normal stuff (segments extracted from MUSDB18-HQ). I've attached waveform.pt in a zip file. Can you help me figure it out? Thanks in advance.
sevagh-bss-eval-error.zip
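In the printed tensors, several batch items have two nearly identical channels (for example, the first item and the second-to-last item shown). One plausible reading, not confirmed here, is that near-duplicate channels make a correlation matrix built from them (nearly) singular, so the inversion fails. Diagonal loading, that is, adding a small multiple of the identity before inverting, is a generic way to keep such matrices invertible. The NumPy sketch below shows it on the exactly singular 2x2 correlation matrix of two identical unit-power channels; the helper name and the relative loading rule are hypothetical, not fast_bss_eval options.

```python
import numpy as np


def diag_loaded_inv(C, load=1e-6):
    """Invert C + eps * I, with eps proportional to the mean diagonal of C.

    Returns the inverse and the eps that was used.
    """
    n = C.shape[-1]
    mean_diag = np.trace(C).real / n
    eps = load * mean_diag if mean_diag > 0 else load  # absolute fallback for all-zero C
    return np.linalg.inv(C + eps * np.eye(n, dtype=C.dtype)), eps


# correlation matrix of two identical, unit-power channels: rank one, singular
C = np.array([[1.0, 1.0], [1.0, 1.0]])

try:
    np.linalg.inv(C)
except np.linalg.LinAlgError as err:
    print("plain inverse failed:", err)

C_inv, eps = diag_loaded_inv(C)
print(np.allclose((C + eps * np.eye(2)) @ C_inv, np.eye(2)))
```

The loaded inverse is well defined but large when C is close to singular, so the loading amount trades numerical stability against bias in whatever is computed from the inverse.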

Compatibility with "bsseval_sources_version"

Hey there,
thanks for your great and super useful work!

I just stumbled upon a small problem related to the windowing question in #10: I'd like to use your package as a replacement for museval.
When evaluating with fast_bss_eval.bss_eval_sources(ref, est) and museval.evaluate(ref, est), I get different SDR values.
After some checking, I found that the main cause is the bsseval_sources_version parameter, which museval.evaluate sets to False; when it is set to True, museval produces exactly the same results as fast_bss_eval.bss_eval_sources.

My question is: is there a parameter, or any suggestion on how to change the output of fast_bss_eval accordingly, so that the results are similar?

Thanks in advance!

Gives NaN results with the PyTorch version for some inputs

Hello, I found that fast_bss_eval (version 0.1.0) sometimes gives NaN results.

The test code:

import numpy as np
import torch
from mir_eval.separation import bss_eval_sources
import fast_bss_eval

x = np.load('debug.npz')
preds = torch.tensor(x['preds'])
target = torch.tensor(x['target'])
print(preds.shape, target.shape)

sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(target, preds)
print(sdr)

sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(target.numpy(), preds.numpy())
print(sdr)

sdr,_,_,_ = bss_eval_sources(target.numpy(), preds.numpy(), False)
print(sdr)

the results:

torch.Size([2, 64000]) torch.Size([2, 64000])
tensor([-2.6815,     nan])
[-2.6815615 44.575493 ]
[-2.68156071 44.58523729]

The data and the debug code are zipped in debug.zip.

Results of the SIR Evaluation

Hello.
I have a question about the SIR evaluation.

wav.zip
Attached are the wav files used for the evaluation:

  • voice_ref.wav: voice only
  • noise_ref.wav: noise only
  • mix.wav: voice and noise mixed
  • eval.wav: voice estimated from mix.wav

I evaluated the SIR of eval.wav using voice_ref.wav and noise_ref.wav as reference signals. The SIR was 0.659 dB.
Next, I evaluated the SIR of mix.wav using the same reference signals. The SIR was then 3.864 dB.

I had understood that the SIR increases as the noise decreases. However, I got the opposite result.
Why does this happen? Is there a problem with my evaluation process?

Best regards.

Any plans to support a windowing method?

Some existing libraries, such as museval (https://github.com/sigsep/sigsep-mus-eval/blob/master/museval/metrics.py) and mir_eval (https://github.com/craffel/mir_eval/blob/master/mir_eval/separation.py), have a parameter named 'window'.
It splits large inputs into multiple chunks, computes the metrics (such as SDR) on each chunk, and aggregates them.

I tried simply replacing museval.evaluate() with fast_bss_eval.bss_eval_sources(),
but ran into an out-of-memory error (it required 800 GB of memory).
If this library provided a windowing method to control memory usage, that would be great and would make it easier to use.

Anyway, thanks for your awesome implementation!
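The windowing described above amounts to evaluating the metric chunk by chunk and then aggregating. A minimal, hypothetical sketch of that loop follows; si_sdr is only a self-contained stand-in metric, and in principle `metric` could instead be a small wrapper around fast_bss_eval.bss_eval_sources (an untested assumption). Aggregating with the median is one common choice; np.nanmedian would skip chunks where the metric is undefined, such as silent references.

```python
import numpy as np


def si_sdr(ref, est):
    """Scale-invariant SDR in dB for 1-D signals (a stand-in metric for this sketch)."""
    alpha = np.dot(est, ref) / np.dot(ref, ref)  # optimal scaling of the reference
    target = alpha * ref
    residual = est - target
    return 10 * np.log10(np.sum(target**2) / np.sum(residual**2))


def windowed_metric(ref, est, metric, window, hop=None):
    """Evaluate `metric` on successive chunks of `window` samples.

    The metric only ever sees one chunk, so its working memory scales with
    `window` rather than with the full signal length. A trailing chunk
    shorter than `window` is dropped for simplicity.
    """
    hop = window if hop is None else hop
    n_samples = ref.shape[-1]
    values = []
    for start in range(0, n_samples - window + 1, hop):
        chunk = slice(start, start + window)
        values.append(metric(ref[..., chunk], est[..., chunk]))
    return np.array(values)


rng = np.random.default_rng(0)
fs = 8000
ref = rng.standard_normal(10 * fs)
est = ref + 0.1 * rng.standard_normal(10 * fs)

per_chunk = windowed_metric(ref, est, si_sdr, window=fs)  # 1-second chunks
print(per_chunk.shape, np.median(per_chunk))
```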

ValueError: einstein sum subscripts string contains too many subscripts for operand 0

Hello.
I ran the following Python code, based on the sample code.

from scipy.io import wavfile
import fast_bss_eval

fs, ref = wavfile.read("./test/ref.wav")
_,  est = wavfile.read("./test/est.wav")

sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref.T, est.T)

However, the following error occurred:

(fast_bss_eval) C:\Users\4020737\Documents\git\FastBssEval>python eval.py
C:\Users\4020737\Documents\git\FastBssEval\eval.py:4: WavFileWarning: Chunk (non-data) not understood, skipping it.
  fs, ref = wavfile.read("./test/ref.wav")
C:\Users\4020737\Documents\git\FastBssEval\eval.py:5: WavFileWarning: Chunk (non-data) not understood, skipping it.
  _,  est = wavfile.read("./test/est.wav")
Traceback (most recent call last):
  File "C:\Users\4020737\Documents\git\FastBssEval\eval.py", line 8, in <module>
    sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref.T, est.T)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\__init__.py", line 365, in bss_eval_sources
    return _dispatch_backend(
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\__init__.py", line 304, in _dispatch_backend
    return f_numpy(*args, **kwargs)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\metrics.py", line 657, in bss_eval_sources
    coh_sdr, coh_sar = square_cosine_metrics(
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\metrics.py", line 522, in square_cosine_metrics
    acf, xcorr = compute_stats_2(ref, est, length=filter_length)
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\fast_bss_eval\numpy\metrics.py", line 173, in compute_stats_2
    prod = np.einsum("...cn,...dn->...ncd", X, X.conj())
  File "<__array_function__ internals>", line 180, in einsum
  File "C:\Users\4020737\Anaconda3\envs\fast_bss_eval\lib\site-packages\numpy\core\einsumfunc.py", line 1359, in einsum
    return c_einsum(*operands, **kwargs)
ValueError: einstein sum subscripts string contains too many subscripts for operand 0

Suspecting that the wav files were the problem, I modified the code as follows, but the result was the same.

from scipy.io import wavfile
import numpy as np
import fast_bss_eval

ref = np.random.randint(1000, 10000, 160000)
est = np.random.randint(1000, 10000, 160000)

#compute the metrics
sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(ref.T, est.T)

The list of libraries in my environment is as follows.

# packages in environment at C:\Users\4020737\Anaconda3\envs\fast_bss_eval:
#
# Name                    Version                   Build  Channel
blas                      1.0                         mkl
bzip2                     1.0.8                he774522_0
ca-certificates           2022.4.26            haa95532_0
certifi                   2022.5.18.1     py310haa95532_0
fast-bss-eval             0.1.4                      py_0    wietsedv
icc_rt                    2019.0.0             h0cc432a_1
intel-openmp              2021.4.0          haa95532_3556
libffi                    3.4.2                hd77b12b_4
mkl                       2021.4.0           haa95532_640
mkl-service               2.4.0           py310h2bbff1b_0
mkl_fft                   1.3.1           py310ha0764ea_0
mkl_random                1.2.2           py310h4ed8f06_0
numpy                     1.22.3          py310h6d2d95c_0
numpy-base                1.22.3          py310h206c741_0
openssl                   1.1.1o               h2bbff1b_0
pip                       21.2.4          py310haa95532_0
python                    3.10.4               hbb2ffb3_0
scipy                     1.7.3           py310h6d2d95c_0
setuptools                61.2.0          py310haa95532_0
six                       1.16.0             pyhd3eb1b0_1
sqlite                    3.38.3               h2bbff1b_0
tk                        8.6.12               h2bbff1b_0
tzdata                    2022a                hda174b7_0
vc                        14.2                 h21ff451_1
vs2015_runtime            14.27.29016          h5e58377_2
wheel                     0.37.1             pyhd3eb1b0_0
wincertstore              0.2             py310haa95532_2
xz                        5.2.5                h8cc25b3_1
zlib                      1.2.12               h8cc25b3_2

Best regards.
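The traceback ends in `np.einsum("...cn,...dn->...ncd", ...)`, whose subscripts need at least two dimensions (channels and samples). The np.random.randint test signal is one-dimensional, and .T leaves a 1-D array unchanged; the same holds for a mono wav file read with wavfile.read. The NumPy snippet below reproduces the same ValueError with the einsum pattern alone and shows that adding a leading channel axis, as in x[None, :], avoids it. It does not call fast_bss_eval itself.

```python
import numpy as np

# subscripts copied from the failing line in compute_stats_2
SUBSCRIPTS = "...cn,...dn->...ncd"

x = np.random.default_rng(0).integers(1000, 10000, 160000).astype(np.float64)
print(x.shape, x.T.shape)  # transposing a 1-D array does not change its shape

try:
    np.einsum(SUBSCRIPTS, x, x.conj())
except ValueError as err:
    print("1-D input:", err)

x2 = x[None, :]  # add a channel axis: shape (1, n_samples)
print("2-D input:", np.einsum(SUBSCRIPTS, x2, x2.conj()).shape)
```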
