
whisper-at's Issues

Support for whisper-large-v3

Hello,

First of all, nice work!

Is it possible to release a checkpoint trained with whisper-large-v3? The reason I'm interested in this is that large-v3 is trained on a new dataset with 5 million hours of audio. I'm interested to see how that scaling will impact whisper-at.

Thank you.

Exception using word_timestamps=True in model.transcribe

Hi there! I was hoping to use whisper's ability to provide timestamps around the audio events your work captures.

I'm currently getting an exception when I pass word_timestamps=True to model.transcribe().

Thanks!

import whisper_at as whisper
model = whisper.load_model("small")
result = model.transcribe(audio_path, at_time_res=10, word_timestamps=True)

Traceback

/usr/local/lib/python3.10/dist-packages/whisper_at/transcribe.py in transcribe(model, audio, verbose, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, condition_on_previous_text, initial_prompt, word_timestamps, prepend_punctuations, append_punctuations, at_time_res, **decode_options)
    344 
    345             if word_timestamps:
--> 346                 add_word_timestamps(
    347                     segments=current_segments,
    348                     model=model,

/usr/local/lib/python3.10/dist-packages/whisper_at/timing.py in add_word_timestamps(segments, model, tokenizer, mel, num_frames, prepend_punctuations, append_punctuations, **kwargs)
    310 
    311     text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
--> 312     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    313     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
    314 

/usr/local/lib/python3.10/dist-packages/whisper_at/timing.py in find_alignment(model, tokenizer, text_tokens, mel, num_frames, medfilt_width, qk_scale)
    193 
    194     with torch.no_grad():
--> 195         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
    196         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
    197         token_probs = sampled_logits.softmax(dim=-1)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/whisper_at/model.py in forward(self, mel, tokens)
    271         self, mel: torch.Tensor, tokens: torch.Tensor
    272     ) -> Dict[str, torch.Tensor]:
--> 273         return self.decoder(tokens, self.encoder(mel))
    274 
    275     @property

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/whisper_at/model.py in forward(self, x, xa, kv_cache)
    210             + self.positional_embedding[offset : offset + x.shape[-1]]
    211         )
--> 212         x = x.to(xa.dtype)
    213 
    214         for block in self.blocks:

AttributeError: 'tuple' object has no attribute 'dtype'
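From the traceback, whisper-at's encoder appears to return a tuple (presumably the audio features plus the layer-wise representations used for audio tagging), which then reaches a decoder path that expects a plain tensor. A hedged workaround sketch for Whisper.forward in whisper_at/model.py, assuming the first tuple element is the decoder's expected input (not verified against the repo):

# Sketch of a patch inside the Whisper class in whisper_at/model.py;
# not the official fix.
def forward(self, mel: torch.Tensor, tokens: torch.Tensor):
    encoded = self.encoder(mel)
    if isinstance(encoded, tuple):
        # keep only the audio features; drop the audio-tagging branch output
        encoded = encoded[0]
    return self.decoder(tokens, encoded)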

writing json file error

Hi,

When I try to run the JSON writer, it fails because the results dictionary contains tensors, which are not JSON serializable.

from whisper_at.utils import get_writer

json_writer = get_writer("json", ".")
writer_args = {"highlight_words": False, "max_line_count": None, "max_line_width": None}
json_writer(result, "test", writer_args)

TypeError: Object of type Tensor is not JSON serializable

The simple solution would be to pop the audio_tags entry and create a new writer for the audio tags (a sketch is below), but any other workaround is appreciated.

Thanks!
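A minimal sketch of the pop-and-serialize workaround described above, continuing the snippet in this issue (result, json_writer, and writer_args come from that snippet; make_json_safe is a hypothetical helper, and the audio_tags key name is taken from the issue text):

import json
import torch

def make_json_safe(obj):
    # Recursively convert tensors into plain lists so json can serialize them.
    if isinstance(obj, torch.Tensor):
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: make_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [make_json_safe(v) for v in obj]
    return obj

audio_tags = result.pop("audio_tags", None)   # remove the tensor-bearing entry
json_writer(result, "test", writer_args)      # now serializes cleanly
with open("test_audio_tags.json", "w") as f:  # write the tags separately
    json.dump(make_json_safe(audio_tags), f)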

invalid for input of size 95904000?

In the training stage it still fails:

File "/opt/whisper/whisper-at/src/whisper_at_train/models.py", line 172, in forward
audio_rep = audio_rep.reshape(B*self.n_layer, audio_rep.shape[2], audio_rep.shape[3]) # [B*32, 25, 1280]
RuntimeError: shape '[192, 25, 80]' is invalid for input of size 95904000
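A quick sanity check on the numbers (my own arithmetic, not from the repo): the saved feature tensor's element count is not a multiple of 192 × 25 × 80, so this exact reshape can never succeed. That usually points to features extracted with a different Whisper variant or time resolution than the training config expects.

# Element-count check for the failing reshape above.
expected = 192 * 25 * 80   # 384,000 elements for shape [192, 25, 80]
actual = 95_904_000        # size reported by the RuntimeError
print(actual % expected)   # 288,000 -> not evenly divisible, so the
                           # features cannot fit this layout at all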

downloading forms a bad path

In some conditions, the URL passed to _download (in __init__.py) has the dl=1 GET parameter appended, which causes the download_target filename to be erroneously set to include it too (e.g., it ends up like "..... .cache\whisper\large-v2_ori.pth?dl=1").
A fix could be to parse the URL first:

from urllib.parse import urlparse
parsed_url = urlparse(url)

and then use parsed_url.path, i.e., instead of:

download_target = os.path.join(root, os.path.basename(url))

use:

download_target = os.path.join(root, os.path.basename(parsed_url.path))
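A self-contained demo of the fix above (the URL is hypothetical): basename(url) keeps the query string, while basename(urlparse(url).path) does not.

import os
from urllib.parse import urlparse

url = "https://www.dropbox.com/s/xxxx/large-v2_ori.pth?dl=1"  # hypothetical

print(os.path.basename(url))                 # large-v2_ori.pth?dl=1  (bad)
print(os.path.basename(urlparse(url).path))  # large-v2_ori.pth       (good)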

Occasional IndexError on empty segments

I ran into this error: IndexError: arrays used as indices must be of integer (or boolean) type. It happens when a segment is empty (the same issue as in openai's whisper). A fix has been implemented in openai's whisper here: openai/whisper#1317

Could you implement the same fix for whisper-at, too? Thanks!
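For context, a self-contained illustration of the failure mode (my own example, not the repo's code): indexing with an empty array, whose dtype defaults to float64, raises exactly this IndexError. As I understand it, the upstream fix amounts to guarding against the empty case before indexing.

import numpy as np

probs = np.arange(5)
empty_words = np.array([])  # dtype defaults to float64 for an empty segment
try:
    probs[empty_words]
except IndexError as e:
    print(e)  # arrays used as indices must be of integer (or boolean) type

# A guard in this spirit skips alignment when there is nothing to align.
if empty_words.size == 0:
    print("empty segment: skip alignment")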

JAX-models

Is transcription with JAX-able models and subsequent labeling with whisper-at possible?
Or does the transcription result need to come from whisper-at? I'm wondering because of how fast whisper-jax is: I could use it to transcribe, but use your model to label (speech, laughter, etc.) afterward.
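One hedged sketch of such a decoupled setup, with the API taken from the whisper-at README quick start (treat the exact arguments as assumptions and verify locally): run whisper-at only for the audio tags, and take the transcript from any faster backend such as whisper-jax (not shown here).

import whisper_at as whisper

at_model = whisper.load_model("tiny")  # assumption: a small model suffices for tags
at_result = at_model.transcribe("audio.mp3", at_time_res=10)
tags = whisper.parse_at_label(at_result, language="follow_asr",
                              top_k=5, p_threshold=-1)
print(tags)  # per-segment tags to merge with the whisper-jax transcript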

RuntimeError: torch.cat(): expected a non-empty list of Tensors

Number of Classes is 3
Now load features from /data/sls/scratch/yuangong/whisper-a/feat_as_full/whisper_tiny
Dataset has 31 samples
Using Label Smoothing: 0.0
Using Following Mask: 0 Freq, 0 Time
Using Mix-up with Rate 0.000000
Now Process as-full
Number of Classes is 3
Now load features from /data/sls/scratch/yuangong/whisper-a/feat_as_eval/whisper_tiny
val_loader configuration!!!!!!!! /data/ph/test.json
test in tin!!!!!!!!!!!!!1y
lw_tr_1_8 tiny 4 384

Creating experiment directory: ./exp/test-as-full-whisper-whisper-high-lw_tr_1_8-tiny-5e-5-15-0.75-bs48-ldaFalse-mix0.5-0-10
Now starting training for 30 epochs
running on cuda
Total parameter number is : 3.550 million
Total trainable parameter number is : 3.550 million
The learning rate scheduler starts at 15 epoch with decay rate of 0.750 every 5 epoches
now training with as-full, main metrics: mAP, loss function: BCEWithLogitsLoss(), learning rate scheduler: <torch.optim.lr_scheduler.MultiStepLR object at 0x7f400ed87490>
current #steps=0, #epochs=1
start training...

2024-01-17 08:59:12.731168
current #epochs=1, #steps=0
test shape
torch.Size([48, 4, 25, 384])
start validation
111111111111
11112222211
val_loader is <torch.utils.data.dataloader.DataLoader object at 0x7f400ed878e0>
[] []
Traceback (most recent call last):
File "/opt/whisper/whisper-at/src/whisper_at_train/./run.py", line 162, in
train(audio_model, train_loader, val_loader, args)
File "/opt/whisper/whisper-at/src/whisper_at_train/traintest.py", line 137, in train
stats, valid_loss = validate(audio_model, test_loader, args)
File "/opt/whisper/whisper-at/src/whisper_at_train/traintest.py", line 241, in validate
audio_output = torch.cat(A_predictions)
RuntimeError: torch.cat(): expected a non-empty list of Tensors


What's the problem with this training stage?

'Whisper' object has no attribute 'transcribe_audio'

File "/opt/whisper/whisper-at/src/noise_robust_asr/intermediate_feat_extract/as_full/extract_as_full_whisper_all.py", line 35, in extract_audio
_, audio_rep = mdl.transcribe_audio(wav)
File "/root/miniconda3/envs/whisper-at-new/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in getattr
raise AttributeError(f"'{type(self).name}' object has no attribute '{name}'")
AttributeError: 'Whisper' object has no attribute 'transcribe_audio'

Missing train & eval data json file

Hi,
Could you provide the whole_train_data.json and eval_data.json for reproducing your results using run_as_full_train.sh?
If the entire JSON files can't be provided, could you please at least provide the corresponding JSON file for ESC-50, the dataset for which you provide extracted features?

Thanks in advance

How to run batch inference?

Hello, thank you for sharing really nice code.

However, I cannot find batch-wise inference code for transcribing.

(I referred to the quick-start example code in the README.)

Is there any batch-wise code for inference?

Best regards
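For what it's worth, the upstream openai-whisper API that whisper-at mirrors transcribes one audio at a time; absent an official batched path, a simple (unoptimized) fallback is a loop, as in this sketch with hypothetical file names:

import whisper_at as whisper

model = whisper.load_model("small")
audio_paths = ["a.wav", "b.wav", "c.wav"]  # hypothetical inputs

# Sequential "batch": one transcribe call per file. True GPU batching across
# files would require changes inside transcribe() itself.
results = [model.transcribe(p, at_time_res=10) for p in audio_paths]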

whisper-at does not recognize laughter

I was testing whisper-at on Hugging Face and found that it does not recognize laughter.
I'm specifically looking for a solution to recognize laughter.
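A hedged suggestion: laughter is part of the AudioSet label space whisper-at predicts, so it may simply be falling below the default score threshold. Lowering p_threshold and raising top_k (arguments per the README quick start) may surface it; index 16 is "Laughter" in the standard AudioSet class ordering, but verify against the label CSV shipped with the repo.

import whisper_at as whisper

model = whisper.load_model("large-v1")
result = model.transcribe("audio.mp3", at_time_res=10)
tags = whisper.parse_at_label(result, language="follow_asr",
                              top_k=10, p_threshold=-5,   # more permissive
                              include_class_list=[16])    # Laughter only
print(tags)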

Use with a fine-tuned model

Thanks for the great code!
Can you explain if and how it's possible to train a whisper-at model based on a fine-tuned Whisper model?
Maybe a more general question: if we have a Whisper model (transformers version), what's the process for training a model?
From what I can see, the training dataset isn't in the repo. Where can we find it?

Missing balanced sampling weight file

Hi, Yuan

I'm training with full AS, and encountered some issues. I have attached my training log for reference.

  1. I couldn't find the *_weight.csv mentioned in line 94 of https://github.com/YuanGongND/whisper-at/blob/main/src/whisper_at_train/run.py:

    samples_weight = np.loadtxt(args.data_train[:-5]+'_weight.csv', delimiter=',')

    Could you please provide it to help me reproduce the results? (A sketch of how such a weight file is commonly generated follows this list.)

  2. The growth curve of the mAP in my training log differs significantly from the one you provided: my first epoch shows 0.017 mAP, while yours is 0.26. Does this indicate an error?
    my_large-v2_ori.txt

  3. I would like to know approximately how many GPUs and how much time it would take to run a single experiment.
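Regarding item 1, a hedged sketch of how such a per-sample weight file is commonly generated for AudioSet balanced sampling. The JSON schema (a "data" list with comma-separated "labels" strings) is an assumption based on the author's related repos, and the exact recipe may differ from the file run.py expects:

import json
import numpy as np

# Count label occurrences over the training set, then weight each sample by
# the inverse frequency of its labels (a common AudioSet balancing recipe).
with open("whole_train_data.json") as f:
    samples = json.load(f)["data"]

label_count = {}
for s in samples:
    for lb in s["labels"].split(","):
        label_count[lb] = label_count.get(lb, 0) + 1

weights = [
    sum(1000.0 / label_count[lb] for lb in s["labels"].split(","))
    for s in samples
]
# Matches the args.data_train[:-5] + '_weight.csv' naming in run.py.
np.savetxt("whole_train_data_weight.csv", np.array(weights), delimiter=",")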

How to Use Temporal Pooling Layer?

I use time_pooling = nn.AvgPool2d((60,1)) as the temporal pooling layer for the Whisper-large pre-trained model (encoder output size is [batch, 1500, 1280]), but in my tests on the ESC-50 dataset the 'last_mlp' and 'last_tr' methods cannot reach the accuracy reported in the paper. So I would like to ask whether my settings are correct.
The detailed code is as follows:

def forward(self, x):
    outputs = self.model.encoder(x)       # large: [bs, 1500, 1280]
    outputs = self.time_pooling(outputs)  # large: [bs, 25, 1280]
    outputs = self.time_tr(outputs)       # temporal transformer
    outputs = torch.mean(outputs, dim=1)  # large: [bs, 1280]
    logits = self.mlp_layer(outputs)      # large: [bs, 50]
    return logits

Another question: what does "Linear Projection" mean? Is it nn.Linear? I didn't find this part in the models.py file in your released code.
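For reference, a quick shape check of the pooling described above (my own check, independent of the repo): on a 3D input, AvgPool2d treats the tensor as (C, H, W), so a (60, 1) kernel pools the 1500-frame axis down to 25 and leaves the 1280-dim feature axis intact.

import torch
import torch.nn as nn

x = torch.randn(2, 1500, 1280)  # stand-in for the large encoder output
pool = nn.AvgPool2d((60, 1))
print(pool(x).shape)            # torch.Size([2, 25, 1280])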
