allowing nonzero initial temperature
This commit is contained in:
@@ -94,7 +94,7 @@ class DecodingOptions:
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 0.0 # the initial timestamp cannot be later than this
|
||||
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
|
@@ -92,41 +92,37 @@ def transcribe(
|
||||
if verbose is not None:
|
||||
print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
|
||||
|
||||
mel = mel.unsqueeze(0)
|
||||
language = decode_options["language"]
|
||||
task = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
kwargs = {**decode_options}
|
||||
t = temperatures[0]
|
||||
if t == 0:
|
||||
best_of = kwargs.pop("best_of", None)
|
||||
else:
|
||||
best_of = kwargs.get("best_of", None)
|
||||
decode_result = None
|
||||
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
results = model.decode(segment, options)
|
||||
for t in temperatures:
|
||||
kwargs = {**decode_options}
|
||||
if t > 0:
|
||||
# disable beam_size and patience when t > 0
|
||||
kwargs.pop("beam_size", None)
|
||||
kwargs.pop("patience", None)
|
||||
else:
|
||||
# disable best_of when t == 0
|
||||
kwargs.pop("best_of", None)
|
||||
|
||||
kwargs.pop("beam_size", None) # no beam search for t > 0
|
||||
kwargs.pop("patience", None) # no patience for t > 0
|
||||
kwargs["best_of"] = best_of # enable best_of for t > 0
|
||||
for t in temperatures[1:]:
|
||||
needs_fallback = [
|
||||
compression_ratio_threshold is not None
|
||||
and result.compression_ratio > compression_ratio_threshold
|
||||
or logprob_threshold is not None
|
||||
and result.avg_logprob < logprob_threshold
|
||||
for result in results
|
||||
]
|
||||
if any(needs_fallback):
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
retries = model.decode(segment[needs_fallback], options)
|
||||
for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
|
||||
results[original_index] = retries[retry_index]
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
return results
|
||||
needs_fallback = False
|
||||
if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
|
||||
needs_fallback = True # too repetitive
|
||||
if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
|
||||
needs_fallback = True # average log probability is too low
|
||||
|
||||
if not needs_fallback:
|
||||
break
|
||||
|
||||
return decode_result
|
||||
|
||||
seek = 0
|
||||
input_stride = exact_div(
|
||||
@@ -175,11 +171,11 @@ def transcribe(
|
||||
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||
while seek < num_frames:
|
||||
timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
|
||||
segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
|
||||
segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result = decode_with_fallback(segment)[0]
|
||||
result: DecodingResult = decode_with_fallback(segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
|
Reference in New Issue
Block a user