.
This commit is contained in:
@@ -326,6 +326,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
print(ref_dk, tri_dk)
|
||||
print(ref_dq, tri_dq)
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
|
Reference in New Issue
Block a user