[CODEGEN] Bugfixes with FP32 async copy

This commit is contained in:
Philippe Tillet
2021-02-24 13:36:26 -05:00
parent 11215f0f03
commit 567a1a3d17
3 changed files with 12 additions and 11 deletions

View File

@@ -3,11 +3,12 @@ import triton
import pytest
@pytest.mark.parametrize(
"MODE, TRANS_A, TRANS_B, BLOCK",
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
for block in [16, 32, 64]],
"MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
[(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"]
for at in [False, True] for bt in [False, True] for block in [16, 32, 64]],
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
# set seed
torch.random.manual_seed(0)
# create inputs