Fixed CoW RuntimeError in DecodingTask.run() (#240)

This commit is contained in:
Corentin Jemine
2022-10-04 17:49:31 +02:00
committed by GitHub
parent 02b74308ff
commit 9e653bd0ea

View File

@@ -615,7 +615,7 @@ class DecodingTask:
n_audio: int = mel.shape[0]
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).expand(n_audio, -1)
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)