@@ -15,7 +15,7 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
|
Reference in New Issue
Block a user