
frechet-audio-distance's Introduction

Frechet Audio Distance in PyTorch

A lightweight library for Frechet Audio Distance (FAD) calculation.

Currently, we support:

  • FAD scores with VGGish, PANN, CLAP, or EnCodec embeddings
  • CLAP score
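For reference, the Frechet Audio Distance paper listed under References fits a multivariate Gaussian to the embeddings of each set and reports the Frechet distance between the two Gaussians. Writing μ_b, Σ_b for the mean and covariance of the background-set embeddings and μ_e, Σ_e for those of the evaluation set (subscripts chosen here for illustration):

```latex
\mathrm{FAD} = \lVert \mu_b - \mu_e \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_b + \Sigma_e - 2 \left( \Sigma_b \Sigma_e \right)^{1/2} \right)
```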

Installation

pip install frechet_audio_distance

Example

For FAD:

from frechet_audio_distance import FrechetAudioDistance

# Choose ONE of the configurations below; each assignment replaces `frechet`.
# to use `vggish`
frechet = FrechetAudioDistance(
    model_name="vggish",
    sample_rate=16000,
    use_pca=False, 
    use_activation=False,
    verbose=False
)
# to use `PANN`
frechet = FrechetAudioDistance(
    model_name="pann",
    sample_rate=16000,
    use_pca=False, 
    use_activation=False,
    verbose=False
)
# to use `CLAP`
frechet = FrechetAudioDistance(
    model_name="clap",
    sample_rate=48000,
    submodel_name="630k-audioset",  # for CLAP only
    verbose=False,
    enable_fusion=False,            # for CLAP only
)
# to use `EnCodec`
frechet = FrechetAudioDistance(
    model_name="encodec",
    sample_rate=48000,
    channels=2,
    verbose=False,
)

fad_score = frechet.score(
    "/path/to/background/set", 
    "/path/to/eval/set", 
    dtype="float32"
)
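score() compares two directories of audio files. If you just want something to point it at, the sketch below writes a background set of clean sine tones and an eval set of the same tones with added noise, using only the standard library's wave module. The function names, frequencies, amplitude, clip length and noise level are arbitrary assumptions for illustration; this is not the distorted-sine test set used under Result validation.

```python
import math
import random
import struct
import wave
from pathlib import Path


def write_sine_wav(path, freq_hz, sample_rate=16000, seconds=1.0, noise=0.0, seed=0):
    """Write a mono 16-bit PCM WAV file containing a sine tone plus optional uniform noise."""
    rng = random.Random(seed)
    n_frames = int(sample_rate * seconds)
    frames = bytearray()
    for i in range(n_frames):
        x = 0.5 * math.sin(2 * math.pi * freq_hz * i / sample_rate)
        x += noise * rng.uniform(-1.0, 1.0)
        x = max(-1.0, min(1.0, x))  # clip to [-1, 1] before quantizing
        frames += struct.pack("<h", int(round(x * 32767)))
    with wave.open(str(path), "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wf.setframerate(sample_rate)
        wf.writeframes(bytes(frames))


def make_sine_sets(root, freqs=(220, 440, 880), noise=0.1, sample_rate=16000):
    """Create root/background (clean tones) and root/eval (noisy tones)."""
    root = Path(root)
    background, evaluation = root / "background", root / "eval"
    background.mkdir(parents=True, exist_ok=True)
    evaluation.mkdir(parents=True, exist_ok=True)
    for k, f in enumerate(freqs):
        write_sine_wav(background / f"sine_{f}.wav", f, sample_rate)
        write_sine_wav(evaluation / f"sine_{f}.wav", f, sample_rate, noise=noise, seed=k)
    return background, evaluation
```

With `bg, ev = make_sine_sets("/tmp/fad_demo")`, you would call `frechet.score(str(bg), str(ev), dtype="float32")`. Three one-second clips are far too few for a stable covariance estimate, so increase `freqs` and `seconds` for anything beyond a smoke test.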

You can also have a look at this notebook for a better understanding of how each model is used.

For CLAP score:

from frechet_audio_distance import CLAPScore

clap = CLAPScore(
    submodel_name="630k-audioset",
    verbose=True,
    enable_fusion=False,
)

clap_score = clap.score(
    text_path="./text1/text.csv",
    audio_dir="./audio1",
    text_column="caption",
)

For more info, kindly refer to this notebook.
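CLAP score is commonly described as the mean cosine similarity between the CLAP embedding of each caption and the embedding of its paired audio clip. Assuming that definition (this README does not say how the library pairs captions with files or whether it rescales the result), the core arithmetic looks roughly like this. `mean_cosine_similarity` is a hypothetical name, and the arrays stand in for real CLAP embeddings.

```python
import numpy as np


def mean_cosine_similarity(text_embds: np.ndarray, audio_embds: np.ndarray) -> float:
    """Average cosine similarity between row i of text_embds and row i of audio_embds."""
    if text_embds.shape != audio_embds.shape:
        raise ValueError("text and audio embeddings must be paired row by row")
    # L2-normalize each row so the row-wise dot product is the cosine similarity
    t = text_embds / np.linalg.norm(text_embds, axis=1, keepdims=True)
    a = audio_embds / np.linalg.norm(audio_embds, axis=1, keepdims=True)
    return float(np.mean(np.sum(t * a, axis=1)))
```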

Save pre-computed embeddings

When computing the Frechet Audio Distance, you can choose to save the embeddings for future use.

Reusing saved embeddings keeps results consistent across evaluations and can substantially reduce computation time, especially when you evaluate against the same dataset multiple times.

# `frechet` is a FrechetAudioDistance instance created as in the example above.
# Specify the paths to your saved embeddings
background_embds_path = "/path/to/saved/background/embeddings.npy"
eval_embds_path = "/path/to/saved/eval/embeddings.npy"

# Compute the FAD score, reusing the saved embeddings if they exist,
# or computing and saving them to these paths if they do not exist yet
fad_score = frechet.score(
    "/path/to/background/set",
    "/path/to/eval/set",
    background_embds_path=background_embds_path,
    eval_embds_path=eval_embds_path,
    dtype="float32"
)
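The same compute-once, reuse-later pattern can be sketched in a few lines of NumPy. This is a hypothetical helper illustrating the idea, not the library's internal code; in practice you would simply pass `background_embds_path` and `eval_embds_path` as shown above. Keep in mind that a cached file is only valid for the model and settings that produced it.

```python
from pathlib import Path
from typing import Callable

import numpy as np


def load_or_compute_embeddings(path: str, compute: Callable[[], np.ndarray]) -> np.ndarray:
    """Return embeddings cached at `path`, computing and saving them on the first call."""
    cache = Path(path)
    if cache.suffix != ".npy":
        # np.save appends ".npy" to other file names, which would defeat the exists() check
        raise ValueError("cache path must end in .npy")
    if cache.exists():
        return np.load(cache)
    embds = compute()
    cache.parent.mkdir(parents=True, exist_ok=True)
    np.save(cache, embds)
    return embds
```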

Result validation

Test 1: Distorted sine waves on vggish (as provided here) [notes]

FAD scores compared with the original implementation in google-research/frechet-audio-distance:

Implementation           baseline vs test1   baseline vs test2
google-research          12.4375             4.7680
frechet_audio_distance   12.7398             4.9815
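To put the Test 1 gap in proportion, the relative difference between the two implementations can be computed directly from the table values. It comes out at roughly 2.4% for test1 and 4.5% for test2, with this library's scores slightly higher in both cases.

```python
# FAD scores copied from the Test 1 table (baseline vs test1 / baseline vs test2)
google_research = {"test1": 12.4375, "test2": 4.7680}
this_library = {"test1": 12.7398, "test2": 4.9815}

# Relative gap of this library's score with respect to the reference implementation
relative_gap = {
    key: (this_library[key] - google_research[key]) / google_research[key]
    for key in google_research
}
for key, gap in relative_gap.items():
    print(f"baseline vs {key}: {gap:+.2%}")
```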

Test 2: Distorted sine waves on PANN

Implementation           baseline vs test1   baseline vs test2
frechet_audio_distance   0.000465            0.00008594

To contribute

Contributions are welcome! Please open a PR and make sure all CI checks pass.

NOTE: For now, CI only tests VGGish, because PANN takes a long time to download.

References

VGGish in PyTorch: https://github.com/harritaylor/torchvggish

Frechet distance implementation: https://github.com/mseitzer/pytorch-fid

Frechet Audio Distance paper: https://arxiv.org/abs/1812.08466

PANN paper: https://arxiv.org/abs/1912.10211

frechet-audio-distance's People

Contributors

balintlaczko, gudgud96, ivanlmh, mcomunita, thibaultcastells, zhvng


frechet-audio-distance's Issues

Is the transformers<=4.30 requirement necessary?

Thanks for the great work!

I'm wondering whether pinning transformers to <=4.30 is actually necessary.
I ask because some newer audio models (e.g., generative models, or models such as EnCodec in transformers) are only available in newer versions of transformers.

Is there potential for conflict between the newer version of transformers and your implementation?

Set up CI

As we have more and more pull requests coming in, I need to set up proper CI for testing >.<

The unit test folders are in place; I just need to prepare GitHub Actions for automated CI.

Issue with 44.1 kHz audio

Hi there,

Thanks for making this!

I'm running into problems when calculating the score for WAV files sampled at 44.1 kHz. The error thrown is:

exception thrown, Input signal length=2 is too small to resample from 44100->16000

I've tried to replicate the loading process as done in this repo.

wav, sr = sf.read(file_name, dtype='int16')

But this only returns an array of mostly zeros.

The WAV files are not corrupted, as far as I can see.
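One plausible cause, offered as a guess rather than a confirmed diagnosis: soundfile.read() returns multi-channel audio with shape (frames, channels), and resampy.resample() works along the last axis by default, so a stereo file looks like a signal of length 2, which matches "length=2" in the error. Downmixing to mono before resampling would avoid that. A minimal NumPy sketch of the downmix step (`to_mono` is a hypothetical name):

```python
import numpy as np


def to_mono(wav: np.ndarray) -> np.ndarray:
    """Average the channels of a (frames, channels) array; 1-D input is returned unchanged."""
    if wav.ndim == 1:
        return wav
    if wav.ndim != 2:
        raise ValueError(f"expected 1-D or 2-D audio, got shape {wav.shape}")
    return wav.mean(axis=1)  # average across channels, keeping the time axis
```

The result could then be passed to `resampy.resample(wav, sr, 16000)`. The mostly-zero int16 read may be a separate symptom if the files store floating-point samples, as described in the next issue.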

Handling float32 wav files

Floating-point PCM WAV files are currently not handled correctly. Such files are not very common, but they are created when passing float32 NumPy arrays to scipy.io.wavfile.write() or torchaudio.save().

What currently happens on loading is this:

  1. wav_data, sr = sf.read(fname, dtype='int16')
    This reads the floating point samples and then rounds them to the nearest integer (so wav_data will usually only contain -1, 0, and +1 samples), in line with the documentation of soundfile.read().
  2. assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
    As far as I can tell, this assertion can never fail, because the dtype is enforced in the previous line.
  3. wav_data = wav_data / 32768.0  # Convert to [-1.0, +1.0]
    This converts the samples to -1/32768, 0 and +1/32768.

One solution would be to not pass a dtype argument to soundfile.read() and then convert whatever comes back, something like:

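import numpy as np
import soundfile as sf

wav_data, sr = sf.read(fname)  # soundfile returns float64 in [-1.0, 1.0] by default
if np.issubdtype(wav_data.dtype, np.floating):
    wav_data = np.asarray(wav_data, dtype=np.float32)
else:
    # integer PCM: scale by the dtype's maximum value
    wav_data = np.divide(wav_data, np.iinfo(wav_data.dtype).max, dtype=np.float32)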

Another solution would be to pass dtype=np.float32 to soundfile.read(). Was there a specific reason to avoid this route? With an example OGG file, I've seen that dtype=np.float32 produced values between -1.2 and +1.4, while dtype=np.int16 scaled them by 32767 and clipped values that were too high or too low. I'm not sure the latter is preferable in any case.
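Building on the first suggestion, here is a sketch of a conversion helper that leaves floating-point data unscaled (so values like the -1.2/+1.4 above are preserved, not clipped) and maps signed integer PCM into [-1.0, 1.0) by dividing by the magnitude of the type's minimum value. For int16 that is 32768, matching the existing `/ 32768.0` step. The name `pcm_to_float32` is hypothetical, and unsigned 8-bit PCM, which is offset rather than signed, is deliberately rejected.

```python
import numpy as np


def pcm_to_float32(samples: np.ndarray) -> np.ndarray:
    """Convert decoded PCM samples to float32, scaling signed integers into [-1.0, 1.0)."""
    if np.issubdtype(samples.dtype, np.floating):
        # already floating point: keep values as-is, only change precision
        return samples.astype(np.float32, copy=False)
    if np.issubdtype(samples.dtype, np.signedinteger):
        scale = -float(np.iinfo(samples.dtype).min)  # 32768.0 for int16, 2**31 for int32
        return (samples / scale).astype(np.float32)
    raise TypeError(f"unsupported sample type: {samples.dtype}")
```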

VGGish is not available on PyTorch Hub

Hello, I was trying to use the VGGish model, but got the following error. It seems VGGish is not available on PyTorch Hub.

UserWarning: You are about to download and run code from an untrusted repository. In a future release, this won't be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, or load(..., trust_repo=True), which will assume that the prompt is to be answered with 'yes'. You can also use load(..., trust_repo='check') which will only prompt for confirmation if the repo is not already trusted. This will eventually be the default behaviour
warnings.warn(
Downloading: "https://github.com/harritaylor/torchvggish/zipball/master" to /Users/zihaohe/.cache/torch/hub/master.zip
Traceback (most recent call last):
File "", line 1, in
File "/Users/zihaohe/miniconda3/lib/python3.9/site-packages/torch/hub.py", line 555, in load
repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
File "/Users/zihaohe/miniconda3/lib/python3.9/site-packages/torch/hub.py", line 230, in _get_cache_or_reload
download_url_to_file(url, cached_file, progress=False)
File "/Users/zihaohe/miniconda3/lib/python3.9/site-packages/torch/hub.py", line 611, in download_url_to_file
u = urlopen(req)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 214, in urlopen
return opener.open(url, data, timeout)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 523, in open
response = meth(req, response)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 632, in http_response
response = self.parent.error(
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 555, in error
result = self._call_chain(*args)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 494, in _call_chain
result = func(*args)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 747, in http_error_302
return self.parent.open(new, timeout=req.timeout)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 523, in open
response = meth(req, response)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 632, in http_response
response = self.parent.error(
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 561, in error
return self._call_chain(*args)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 494, in _call_chain
result = func(*args)
File "/Users/zihaohe/miniconda3/lib/python3.9/urllib/request.py", line 641, in http_error_default
raise HTTPError(req.full_url, code, msg, hdrs, fp)
urllib.error.HTTPError: HTTP Error 503: Service Unavailable
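HTTP 503 means "Service Unavailable", a server-side status that is often temporary, so one workaround is to retry the download with exponential backoff. The helper below is a hypothetical sketch, not part of this library; retrying only helps if the outage is transient.

```python
import time
import urllib.error
from typing import Callable, TypeVar

T = TypeVar("T")

# Status codes usually worth retrying: rate limiting and transient server errors.
RETRYABLE_STATUS = {429, 500, 502, 503, 504}


def retry_on_http_error(
    fn: Callable[[], T],
    attempts: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn(), retrying with exponential backoff on retryable HTTP errors."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(attempts):
        try:
            return fn()
        except urllib.error.HTTPError as err:
            # Give up on non-transient errors (e.g. 404) or after the last attempt.
            if err.code not in RETRYABLE_STATUS or attempt == attempts - 1:
                raise
            sleep(base_delay * 2 ** attempt)  # wait 1x, 2x, 4x, ... base_delay
    raise AssertionError("unreachable")
```

For example: `model = retry_on_http_error(lambda: torch.hub.load("harritaylor/torchvggish", "vggish"))`, using the loading call documented by torchvggish.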

[chore] Proper CI needed

We are in dire need of proper CI! Given how much content this repo now has, and how important it is that each score is calculated correctly, CI is crucial as a safety net for future deployments and iterations.

We do have extensive test examples in notebooks, but the goal is to make sure results stay correct and nothing breaks when PRs are merged.

Right now we have only one test (only one...), the VGGish unit test. This can be replicated for other embedding types. The only blocker I faced previously is that downloading some models (e.g. PANN) is extremely slow, so CI would run forever. However, I believe well-known packages like EnCodec and CLAP should be fine.

Replicate scores for major papers using Frechet Audio Distance

As mentioned in fcaspe/ddx7#1, I have not yet been able to replicate the FAD score reported in the paper to a satisfactory level.

This needs further investigation into whether the difference comes from implementation differences relative to the Google version, or from differences outside the FAD calculation. I have therefore decided to look into some major works and benchmark the FAD scores they report against those calculated here. Candidates are still to be listed (starting with DDX7); paper suggestions are welcome.

Improvement suggestions

Hello,
Thank you for providing this tool! It's great!
If you are interested in improvement suggestions, I would like to suggest the following:

  • When using VGGish, the model checkpoint is downloaded automatically, but this is not the case with PANN.
    Therefore, I get the following error:
    FileNotFoundError: [Errno 2] No such file or directory: 'mypath/.cache/torch/hub/Cnn14_16k_mAP%3D0.438.pth'
    If this is intended, I suggest updating the README to explain how to add this model.
  • In the README, it would be nice to indicate how long it takes to compute the FAD with each method.
  • In the README, more examples of how to use the library would be great (I had to read the code to understand how to reuse an embedding across several comparisons).
  • When computing the FAD between a data folder and itself, I do not get exactly 0 (I get a near-zero value of -1.1368683772161603e-13 instead). I think it would be better to get exactly 0 in this case.

Again, thank you for your work!
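The tiny negative value in the last point is consistent with floating-point rounding in the trace term, which should cancel exactly for identical sets. As a sketch (not the library's code; `frechet_distance` and `embedding_stats` are hypothetical names), here is a NumPy Frechet distance that clamps the result at zero. pytorch-fid, referenced above, takes the matrix square root with scipy.linalg.sqrtm; this version swaps in an eigendecomposition-based identity so that only NumPy is needed.

```python
import numpy as np


def _psd_sqrt(m: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh((m + m.T) / 2)  # symmetrize to absorb rounding asymmetry
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def frechet_distance(mu1: np.ndarray, sigma1: np.ndarray,
                     mu2: np.ndarray, sigma2: np.ndarray) -> float:
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2), clamped at zero."""
    diff = mu1 - mu2
    s1_half = _psd_sqrt(sigma1)
    # Tr((S1 S2)^{1/2}) is the sum of square roots of the eigenvalues of the
    # symmetric matrix S1^{1/2} S2 S1^{1/2}, which has the same eigenvalues as S1 S2.
    inner = s1_half @ sigma2 @ s1_half
    eig = np.linalg.eigvalsh((inner + inner.T) / 2)
    tr_covmean = float(np.sum(np.sqrt(np.clip(eig, 0.0, None))))
    d = float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
    # Rounding error can push d slightly below zero (e.g. -1e-13); clamp it.
    return max(d, 0.0)


def embedding_stats(embds: np.ndarray):
    """Mean and covariance of an (n_samples, dim) embedding matrix."""
    return embds.mean(axis=0), np.cov(embds, rowvar=False)
```

Clamping removes negative noise but can still leave a tiny positive value for self-comparison; a caller that needs an exact 0 could additionally short-circuit when both embedding sets are identical.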

Include more models into FAD score calculation

Several recent papers calculate FAD using alternatives to the VGGish model, such as:

  • MusicLM - uses (1) Trill (Shor et al., 2020) and (2) VGGish (Hershey et al., 2017) trained on the YouTube-8M audio event dataset (Abu-El-Haija et al., 2016).

  • AudioLDM - uses PANN (Kong et al., 2020b)

The exact list of models to support will depend on demand from the research community, but we should try to support Trill and PANN first. Basically, we need to:

(i) find open-source PyTorch checkpoints for these models

(ii) refactor code to abstract out audio embedding calculation
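Step (ii) could look roughly like the following: a hypothetical abstract base class that each embedding backend (for example VGGish, PANN, or Trill) would implement, so the FAD computation only ever sees arrays of embeddings. All class, method and parameter names here are illustrative, and the toy backend is a placeholder, not a real audio model.

```python
from abc import ABC, abstractmethod

import numpy as np


class EmbeddingModel(ABC):
    """Hypothetical interface: every backend maps mono audio to an (n_frames, dim) array."""

    sample_rate: int

    @abstractmethod
    def embed(self, audio: np.ndarray) -> np.ndarray:
        """Return embeddings of shape (n_frames, dim) for 1-D audio at self.sample_rate."""


class FrameEnergyModel(EmbeddingModel):
    """Toy stand-in backend: per-frame RMS and mean absolute value (not a real audio model)."""

    def __init__(self, sample_rate: int = 16000, frame_len: int = 1600):
        self.sample_rate = sample_rate
        self.frame_len = frame_len

    def embed(self, audio: np.ndarray) -> np.ndarray:
        n = len(audio) // self.frame_len  # drop any trailing partial frame
        frames = audio[: n * self.frame_len].reshape(n, self.frame_len)
        rms = np.sqrt(np.mean(frames ** 2, axis=1))
        mav = np.mean(np.abs(frames), axis=1)
        return np.stack([rms, mav], axis=1)


def embed_set(model: EmbeddingModel, clips) -> np.ndarray:
    """Stack embeddings of all clips so the FAD code never needs to know the backend."""
    return np.concatenate([model.embed(c) for c in clips], axis=0)
```

Adding a new model would then mean writing one new `EmbeddingModel` subclass, while the statistics and distance code stay unchanged.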
