[CODEGEN] Bugfixes with FP32 async copy
@@ -3,11 +3,12 @@ import triton
 import pytest
 
 @pytest.mark.parametrize(
-    "MODE, TRANS_A, TRANS_B, BLOCK",
-    [(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True]
-     for block in [16, 32, 64]],
+    "MODE, TRANS_A, TRANS_B, BLOCK, DTYPE",
+    [(mode, at, bt, block, dtype) for dtype in ["float16", "float32"] for mode in ["sdd", "dsd", "dds"]
+     for at in [False, True] for bt in [False, True] for block in [16, 32, 64]],
 )
-def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384):
+def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=128, N=256, K=384):
+    DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
     # set seed
     torch.random.manual_seed(0)
     # create inputs
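Outside the diff itself, here is a minimal, self-contained sketch of the parametrization idiom used in test_matmul above: the dtype arrives as a string from pytest.mark.parametrize and is mapped to the corresponding torch dtype inside the test. The test name and tensor shape below are illustrative, not part of this commit.

```python
import pytest
import torch

# Same idiom as the DTYPE lookup added to test_matmul: map the string
# parameter to a torch dtype before building inputs.
_DTYPES = {"float16": torch.float16, "float32": torch.float32}

@pytest.mark.parametrize("DTYPE", ["float16", "float32"])
def test_dtype_lookup(DTYPE):  # hypothetical test, for illustration only
    dtype = _DTYPES[DTYPE]
    x = torch.randn(16, 16).to(dtype)  # inputs created in the requested precision
    assert x.dtype == dtype
```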
@@ -44,7 +44,7 @@ import torch
         (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
         (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
         (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE),
-    ] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True]
+    ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True]
 ]),
 )
 def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
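For reference, a hedged sketch (names and shapes are hypothetical, not taken from this commit) of the same grid-expansion pattern used in test_op: adding "float32" to the DTYPE axis of the nested comprehension doubles the generated cases, and pytest -k float32 selects only the new ones since the dtype string appears in each test ID.

```python
import pytest

# Hypothetical reduced grid with the same structure as test_op's parametrize:
# every shape tuple is combined with each (DTYPE, AT, BT) setting.
CASES = [
    (m, n, k, at, bt, dtype)
    for dtype in ["float16", "float32"]  # the axis extended by this commit
    for at in [False, True]
    for bt in [False, True]
    for (m, n, k) in [(128, 128, 32), (107, 233, 256)]
]

@pytest.mark.parametrize("M, N, K, AT, BT, DTYPE", CASES)
def test_grid(M, N, K, AT, BT, DTYPE):  # 2 dtypes * 2 * 2 flags * 2 shapes = 16 cases
    assert DTYPE in ("float16", "float32")
```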