This commit is contained in:
Phil Tillet
2023-01-02 23:13:12 -08:00
parent 05920e0b8b
commit 5c01c567b9
3 changed files with 81 additions and 82 deletions

View File

@@ -326,6 +326,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
triton.testing.assert_almost_equal(ref_dv, tri_dv)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
print(ref_dk, tri_dk)
print(ref_dq, tri_dq)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4