[CODEGEN] Bugfixes with FP32 async copy

This commit is contained in:
Philippe Tillet
2021-02-24 13:36:26 -05:00
parent 11215f0f03
commit 567a1a3d17
3 changed files with 12 additions and 11 deletions

View File

@@ -3,11 +3,12 @@ import triton
import pytest
@pytest.mark.parametrize(
"MODE, TRANS_A, TRANS_B, BLOCK",
[(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
for block in [16, 32, 64]],
"MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
[(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"]
for at in [False, True] for bt in [False, True] for block in [16, 32, 64]],
)
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384):
DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
# set seed
torch.random.manual_seed(0)
# create inputs

View File

@@ -44,7 +44,7 @@ import torch
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
]),
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):