This commit is contained in:
Phil Tillet
2023-01-03 18:34:05 -08:00
parent 8df1fa5e5b
commit 645fa5c1cd
2 changed files with 7 additions and 9 deletions

View File

@@ -326,8 +326,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
print(ref_dk, tri_dk)
print(ref_dq, tri_dq)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4