Files
triton/python/test/test_matmul.py
2021-07-27 12:38:49 -07:00

63 lines
3.1 KiB
Python

import pytest
import itertools
import triton
import torch
# Tile/launch configurations exercised by test_op, as
# (TM, TN, TK, TZ, NWARP, M, N, K). TZ is the split-K factor; M, N, K of None
# mean "use a single tile": M = TM, N = TN, K = TK * TZ (see test_op).
_TILE_CONFIGS = [
    # 1 warp
    (16, 16, 16, 1, 1, None, None, None),
    (32, 16, 16, 1, 1, None, None, None),
    (16, 32, 16, 1, 1, None, None, None),
    (16, 16, 32, 1, 1, None, None, None),
    (32, 16, 32, 1, 1, None, None, None),
    (16, 32, 32, 1, 1, None, None, None),
    (16, 16, 64, 1, 1, None, None, None),
    (64, 16, 64, 1, 1, None, None, None),
    (16, 64, 64, 1, 1, None, None, None),
    # 2 warp
    (64, 32, 64, 1, 2, None, None, None),
    (32, 64, 64, 1, 2, None, None, None),
    (64, 32, 16, 1, 2, None, None, None),
    (32, 64, 16, 1, 2, None, None, None),
    (128, 32, 32, 1, 2, None, None, None),
    (32, 128, 32, 1, 2, None, None, None),
    # 4 warp
    (128, 64, 16, 1, 4, None, None, None),
    (64, 128, 16, 1, 4, None, None, None),
    (128, 32, 32, 1, 4, None, None, None),
    (32, 128, 32, 1, 4, None, None, None),
    (128, 32, 64, 1, 4, None, None, None),
    (32, 128, 64, 1, 4, None, None, None),
    # 8 warp -- currently disabled (reason not recorded here)
    # (128, 256, 16, 1, 8, None, None, None),
    # (256, 128, 16, 1, 8, None, None, None),
    # (256, 128, 32, 1, 8, None, None, None),
    # split-k
    (64, 64, 16, 2, 4, None, None, None),
    (64, 64, 16, 4, 4, None, None, None),
    (64, 64, 16, 8, 4, None, None, None),
    # variable input
    (128, 128, 32, 1, 4, 256, 256, 256),
    (128, 128, 32, 1, 4, 384, 128, 640),
    (128, 128, 32, 1, 4, 107, 233, 256),
    (128, 128, 32, 1, 4, 107, 233, 311),
]


@pytest.mark.parametrize(
    "TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE",
    [
        # Same case order as before: DTYPE outermost, then AT, then BT, with
        # every tile configuration run for each (DTYPE, AT, BT) combination.
        (*config, AT, BT, DTYPE)
        for DTYPE, AT, BT in itertools.product(['float16'], [False, True], [False, True])
        for config in _TILE_CONFIGS
    ],
)
def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE):
    """Check ``triton.ops.matmul`` against ``torch.matmul`` for one forced config.

    The kernel cache and config list on ``triton.ops._matmul`` are overridden so
    that exactly the parametrized tile (TM x TN x TK, split-K factor TZ) and
    warp count NWARP are used. Both attributes are restored afterwards so the
    override does not leak into other tests that call ``triton.ops.matmul``.

    M, N, K default to a single tile (TM, TN, TK * TZ) when None. When AT / BT
    is true, A / B is created with swapped dimensions and passed as a
    transposed view (``.t()``) of that tensor.
    """
    dtype = {'float16': torch.float16, 'float32': torch.float32}[DTYPE]
    if M is None:
        M = TM
    if N is None:
        N = TN
    if K is None:
        K = TK * TZ

    matmul_cls = triton.ops._matmul
    saved_kernels, saved_configs = matmul_cls._kernels, matmul_cls._CONFIGS
    # Start from an empty kernel cache so the forced config below is compiled
    # instead of reusing a kernel built from a different config.
    matmul_cls._kernels = dict()
    matmul_cls._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)]
    try:
        torch.manual_seed(0)
        # Scale by 1/sqrt(K) so the K-long dot products stay small in magnitude,
        # which keeps low-precision (float16) comparisons meaningful.
        a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
        b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
        a = a.t() if AT else a
        b = b.t() if BT else b
        th_c = torch.matmul(a, b)
        tt_c = triton.ops.matmul(a, b)
        assert triton.testing.allclose(th_c, tt_c)
    finally:
        # Undo the global override even when compilation or the check fails.
        matmul_cls._kernels = saved_kernels
        matmul_cls._CONFIGS = saved_configs