Hi there! I was hoping to use Whisper's word-level timestamps alongside the audio-event tagging your work provides. Currently I get an exception when I pass `word_timestamps=True` to `model.transcribe()`.
```python
import whisper_at as whisper

audio_path = "clip.wav"  # placeholder; any local audio file reproduces this
model = whisper.load_model("small")
result = model.transcribe(audio_path, at_time_res=10, word_timestamps=True)
```
```
/usr/local/lib/python3.10/dist-packages/whisper_at/transcribe.py in transcribe(model, audio, verbose, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, condition_on_previous_text, initial_prompt, word_timestamps, prepend_punctuations, append_punctuations, at_time_res, **decode_options)
    344 
    345     if word_timestamps:
--> 346         add_word_timestamps(
    347             segments=current_segments,
    348             model=model,

/usr/local/lib/python3.10/dist-packages/whisper_at/timing.py in add_word_timestamps(segments, model, tokenizer, mel, num_frames, prepend_punctuations, append_punctuations, **kwargs)
    310 
    311     text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
--> 312     alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    313     merge_punctuations(alignment, prepend_punctuations, append_punctuations)
    314 

/usr/local/lib/python3.10/dist-packages/whisper_at/timing.py in find_alignment(model, tokenizer, text_tokens, mel, num_frames, medfilt_width, qk_scale)
    193 
    194     with torch.no_grad():
--> 195         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
    196         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
    197         token_probs = sampled_logits.softmax(dim=-1)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/whisper_at/model.py in forward(self, mel, tokens)
    271         self, mel: torch.Tensor, tokens: torch.Tensor
    272     ) -> Dict[str, torch.Tensor]:
--> 273         return self.decoder(tokens, self.encoder(mel))
    274 
    275     @property

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/whisper_at/model.py in forward(self, x, xa, kv_cache)
    210             + self.positional_embedding[offset : offset + x.shape[-1]]
    211         )
--> 212         x = x.to(xa.dtype)
    213 
    214         for block in self.blocks:

AttributeError: 'tuple' object has no attribute 'dtype'
```
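
From the traceback, it looks like the word-timestamp path calls `model(mel, tokens)`, whose `forward` passes `self.encoder(mel)` straight to the decoder; since the Whisper-AT encoder appears to return a tuple (the transcription features plus the audio-tagging representations) rather than a plain tensor, the decoder fails on `xa.dtype`. In the meantime I've been working around it by running two separate passes: word timestamps with stock openai-whisper and tag timestamps with whisper_at. A rough sketch of that workaround is below (the file path, `top_k`, and `language` values are just placeholders; the `parse_at_label` call follows the README example, if I'm reading it right):

```python
# Workaround sketch, assuming stock openai-whisper is installed alongside whisper_at.
import whisper       # stock openai-whisper
import whisper_at

audio_path = "clip.wav"  # placeholder

# Pass 1: word-level timestamps from plain Whisper (this path doesn't hit the tuple issue).
asr_model = whisper.load_model("small")
asr_result = asr_model.transcribe(audio_path, word_timestamps=True)

# Pass 2: audio-tagging timestamps from Whisper-AT, without word_timestamps.
at_model = whisper_at.load_model("small")
at_result = at_model.transcribe(audio_path, at_time_res=10)
audio_tags = whisper_at.parse_at_label(at_result, language="follow_asr", top_k=5)
```

That works, but it means loading and running the model twice, so native `word_timestamps=True` support in whisper_at would be much appreciated.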