63 lines
3.1 KiB
Python
63 lines
3.1 KiB
Python
import pytest
|
|
import itertools
|
|
import triton
|
|
import torch
|
|
|
|
@pytest.mark.parametrize(
|
|
"TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
|
|
itertools.chain(*[
|
|
[
|
|
# 1 warp
|
|
(16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
(16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE),
|
|
# 2 warp
|
|
(64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
|
(32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE),
|
|
(64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
|
(32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE),
|
|
(128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
|
(32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE),
|
|
# 4 warp
|
|
(128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
|
(64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE),
|
|
(128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
|
(32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE),
|
|
(128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
|
(32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE),
|
|
# 8 warp
|
|
# (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
|
# (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE),
|
|
# (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE),
|
|
# split-k
|
|
(64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE),
|
|
(64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE),
|
|
(64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE),
|
|
# variable input
|
|
(128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE),
|
|
(128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE),
|
|
(128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE),
|
|
(128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE)
|
|
] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True]
|
|
]))
|
|
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
|
|
DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
|
|
torch.manual_seed(0)
|
|
triton.ops._matmul._kernels = dict()
|
|
triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
|
|
if M is None: M = TM
|
|
if N is None: N = TN
|
|
if K is None: K = TK * TZ
|
|
a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5
|
|
b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5
|
|
a = a.t() if AT else a
|
|
b = b.t() if BT else b
|
|
th_c = torch.matmul(a, b)
|
|
tt_c = triton.ops.matmul(a, b)
|
|
assert triton.testing.allclose(th_c, tt_c) |