[CI] Moved from assert_allclose to assert_almost_equal (#200)

This commit is contained in:
Philippe Tillet
2021-08-12 12:00:30 -07:00
committed by GitHub
parent 70e28ff380
commit b120d70a0a
5 changed files with 24 additions and 16 deletions

View File

@@ -79,11 +79,11 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
K = BLOCK_K * SPLIT_K if K is None else K
# allocate/transpose inputs
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = .1*torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE)
b = .1*torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE)
a = a.t() if AT else a
b = b.t() if BT else b
# run test
th_c = torch.matmul(a, b)
tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest)
assert triton.testing.allclose(th_c, tt_c)
triton.testing.assert_almost_equal(th_c, tt_c)