allowing nonzero initial temperature

Jong Wook Kim
2022-09-29 18:05:12 -07:00
parent 30dc5c581b
commit 7cb4cc21bf
2 changed files with 26 additions and 30 deletions

whisper/decoding.py

@@ -94,7 +94,7 @@ class DecodingOptions:
 
     # timestamp sampling options
     without_timestamps: bool = False              # use <|notimestamps|> to sample text tokens only
-    max_initial_timestamp: Optional[float] = 0.0  # the initial timestamp cannot be later than this
+    max_initial_timestamp: Optional[float] = 1.0  # the initial timestamp cannot be later than this
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
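
Raising the default from 0.0 to 1.0 lets the first timestamp token land up to one second into each 30-second window instead of being pinned to its very first frame. A minimal sketch of exercising the option through the package's public decoding API (the model size and audio path are placeholders; fp16 is disabled so the sketch also runs on CPU):

import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("speech.wav"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# max_initial_timestamp=1.0 restates the new default; 0.0 reproduces the old behavior
options = whisper.DecodingOptions(max_initial_timestamp=1.0, fp16=False)
result = whisper.decode(model, mel, options)
print(result.text)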

whisper/transcribe.py

@@ -92,41 +92,37 @@ def transcribe(
         if verbose is not None:
             print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
-    mel = mel.unsqueeze(0)
     language = decode_options["language"]
     task = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
 
-    def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]:
+    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
         temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
-        kwargs = {**decode_options}
-        t = temperatures[0]
-        if t == 0:
-            best_of = kwargs.pop("best_of", None)
-        else:
-            best_of = kwargs.get("best_of", None)
+        decode_result = None
 
-        options = DecodingOptions(**kwargs, temperature=t)
-        results = model.decode(segment, options)
+        for t in temperatures:
+            kwargs = {**decode_options}
+            if t > 0:
+                # disable beam_size and patience when t > 0
+                kwargs.pop("beam_size", None)
+                kwargs.pop("patience", None)
+            else:
+                # disable best_of when t == 0
+                kwargs.pop("best_of", None)
 
-        kwargs.pop("beam_size", None)  # no beam search for t > 0
-        kwargs.pop("patience", None)  # no patience for t > 0
-        kwargs["best_of"] = best_of  # enable best_of for t > 0
-        for t in temperatures[1:]:
-            needs_fallback = [
-                compression_ratio_threshold is not None
-                and result.compression_ratio > compression_ratio_threshold
-                or logprob_threshold is not None
-                and result.avg_logprob < logprob_threshold
-                for result in results
-            ]
-            if any(needs_fallback):
-                options = DecodingOptions(**kwargs, temperature=t)
-                retries = model.decode(segment[needs_fallback], options)
-                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
-                    results[original_index] = retries[retry_index]
+            options = DecodingOptions(**kwargs, temperature=t)
+            decode_result = model.decode(segment, options)
 
-        return results
+            needs_fallback = False
+            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+                needs_fallback = True  # too repetitive
+            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+                needs_fallback = True  # average log probability is too low
+
+            if not needs_fallback:
+                break
+
+        return decode_result
 
     seek = 0
     input_stride = exact_div(
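
With this rewrite, each segment is decoded at the first temperature and re-decoded at the next one only while the compression-ratio or average-logprob check fails, so any entry in the schedule, including the first, may now be nonzero. A minimal usage sketch against the public API (the model size and file path are placeholders, and the keyword values merely restate transcribe()'s defaults):

import whisper

model = whisper.load_model("base")
result = model.transcribe(
    "speech.wav",
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # fallback schedule, tried in order
    compression_ratio_threshold=2.4,             # fall back when the output is too repetitive
    logprob_threshold=-1.0,                      # fall back when the model is too uncertain
)
print(result["text"])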
@@ -175,11 +171,11 @@ def transcribe(
     with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
-            segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
+            segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
             segment_duration = segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE
 
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
-            result = decode_with_fallback(segment)[0]
+            result: DecodingResult = decode_with_fallback(segment)
             tokens = torch.tensor(result.tokens)
 
             if no_speech_threshold is not None:
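
Dropping the mel.unsqueeze(0) call means the spectrogram stays 2-D throughout transcribe(), which is why the window slice here loses its leading batch axis and decode_with_fallback returns a single DecodingResult rather than a one-element list. A minimal sketch of the shape contract (the random tensor stands in for a real log-mel spectrogram):

import torch
from whisper.audio import N_FRAMES, pad_or_trim

mel = torch.randn(80, 5000)                  # (n_mels, n_frames): no batch dimension
segment = pad_or_trim(mel[:, 0:], N_FRAMES)  # one 30-second window, padded/trimmed to 3000 frames
assert segment.shape == (80, 3000)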