2021-02-08 12:16:41 -08:00
|
|
|
import pytest
|
|
|
|
import itertools
|
|
|
|
import triton
|
|
|
|
import torch
|
|
|
|
|
|
|
|
# Tile configurations exercised by ``test_op``, one tuple per case:
# (TM, TN, TK, SPLITK, NWARP, M, N, K).
# A ``None`` for M, N or K means the test derives that dimension from the
# tile sizes (M = TM, N = TN, K = TK * SPLITK).
_MATMUL_TILE_CONFIGS = (
    # 1 warp
    (16, 16, 16, 1, 1, None, None, None),
    (32, 16, 16, 1, 1, None, None, None),
    (16, 32, 16, 1, 1, None, None, None),
    (16, 16, 32, 1, 1, None, None, None),
    (32, 16, 32, 1, 1, None, None, None),
    (16, 32, 32, 1, 1, None, None, None),
    (16, 16, 64, 1, 1, None, None, None),
    (64, 16, 64, 1, 1, None, None, None),
    (16, 64, 64, 1, 1, None, None, None),
    # 2 warps
    (64, 32, 64, 1, 2, None, None, None),
    (32, 64, 64, 1, 2, None, None, None),
    (64, 32, 16, 1, 2, None, None, None),
    (32, 64, 16, 1, 2, None, None, None),
    (128, 32, 32, 1, 2, None, None, None),
    (32, 128, 32, 1, 2, None, None, None),
    # 4 warps
    (128, 64, 16, 1, 4, None, None, None),
    (64, 128, 16, 1, 4, None, None, None),
    (128, 32, 32, 1, 4, None, None, None),
    (32, 128, 32, 1, 4, None, None, None),
    (128, 32, 64, 1, 4, None, None, None),
    (32, 128, 64, 1, 4, None, None, None),
    # 8 warps -- currently disabled:
    #   (128, 256, 16, 1, 8, None, None, None)
    #   (256, 128, 16, 1, 8, None, None, None)
    #   (256, 128, 32, 1, 8, None, None, None)
    # split-k
    (64, 64, 16, 2, 4, None, None, None),
    (64, 64, 16, 4, 4, None, None, None),
    (64, 64, 16, 8, 4, None, None, None),
    # explicit problem sizes, including ones that are not tile multiples
    (128, 128, 32, 1, 4, 1024, 1024, 1024),
    (128, 128, 32, 1, 4, 384, 128, 640),
    (128, 128, 32, 1, 4, 107, 233, 256),
    (128, 128, 32, 1, 4, 107, 233, 311),
)


@pytest.mark.parametrize(
    "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE",
    [
        tile_config + (AT, BT, DTYPE)
        # product() varies its last argument fastest, so cases are grouped
        # by DTYPE, then AT, then BT, with every tile config inside a group.
        for DTYPE, AT, BT in itertools.product(
            ("float16", "float32"), (False, True), (False, True)
        )
        for tile_config in _MATMUL_TILE_CONFIGS
    ],
)
def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE):
    """Compare ``triton.ops.matmul`` with ``torch.matmul`` for one configuration.

    TM, TN, TK and SPLITK are passed to Triton as string-valued defines and
    NWARP as ``num_warps``. M, N and K give the problem size; any of them
    left as ``None`` falls back to TM, TN and TK * SPLITK respectively.
    AT / BT select whether the corresponding operand is allocated with
    swapped dimensions and handed over as a ``.t()`` view. DTYPE is the
    string ``"float16"`` or ``"float32"``.
    """
    torch_dtype = {"float16": torch.float16, "float32": torch.float32}[DTYPE]
    torch.manual_seed(0)

    tile_defines = {
        name: str(value)
        for name, value in (("TM", TM), ("TN", TN), ("TK", TK), ("SPLITK", SPLITK))
    }
    # Replace the matmul op's module-level kernel dict and config list so
    # that only the configuration under test is registered.
    # NOTE(review): presumably this forces a fresh kernel build -- confirm.
    triton.ops._matmul._kernels = dict()
    triton.ops._matmul._CONFIGS = [
        triton.config(defines=tile_defines, num_warps=NWARP)
    ]

    # Unspecified problem dimensions default to the tile dimensions.
    M = TM if M is None else M
    N = TN if N is None else N
    K = TK * SPLITK if K is None else K

    # A transposed operand is allocated with swapped dimensions, then viewed
    # through .t() so that it has the logical (M, K) / (K, N) shape.
    a_shape = (K, M) if AT else (M, K)
    b_shape = (N, K) if BT else (K, N)
    a = torch.randn(a_shape, device="cuda", dtype=torch_dtype)
    b = torch.randn(b_shape, device="cuda", dtype=torch_dtype)
    if AT:
        a = a.t()
    if BT:
        b = b.t()

    expected = torch.matmul(a, b)
    actual = triton.ops.matmul(a, b)
    assert triton.testing.allclose(expected, actual)
|