From 53807677fe8bc892fe0fff970e8b24ddd3848596 Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Fri, 30 Dec 2022 01:53:06 -0700
Subject: [PATCH] MultiHeadAttention to return qk as well

---
 whisper/model.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/whisper/model.py b/whisper/model.py
index ca3928e..be73a4a 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -82,8 +82,8 @@ class MultiHeadAttention(nn.Module):
             k = kv_cache[self.key]
             v = kv_cache[self.value]
 
-        wv = self.qkv_attention(q, k, v, mask)
-        return self.out(wv)
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
 
     def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
         n_batch, n_ctx, n_state = q.shape
@@ -95,9 +95,10 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()
 
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
 class ResidualAttentionBlock(nn.Module):
@@ -121,9 +122,9 @@ class ResidualAttentionBlock(nn.Module):
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x
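
A minimal usage sketch (not part of the patch): after this change,
MultiHeadAttention.forward returns an (output, qk) tuple, where qk holds the
detached, pre-softmax attention logits in float32. The module construction
and tensor shapes below are illustrative, not taken from the commit.

    import torch
    from whisper.model import MultiHeadAttention

    attn = MultiHeadAttention(n_state=512, n_head=8)
    x = torch.randn(1, 10, 512)  # (batch, n_ctx, n_state)

    # Self-attention call: xa, mask, and kv_cache default to None.
    out, qk = attn(x)
    print(out.shape)  # torch.Size([1, 10, 512])
    print(qk.shape)   # torch.Size([1, 8, 10, 10]), per-head logits before softmax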