From a714b6b8561dbadf9e11eb30423159fe8bc30fe9 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 16 Aug 2021 22:56:21 -0700 Subject: [PATCH] [PYTHON] re-activated auto-tuner configurations for triton.ops.matmul (#212) --- python/test/test_matmul.py | 2 +- python/triton/ops/matmul.py | 40 +++++++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index e5c1961ef..dbf1974ce 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -66,7 +66,7 @@ import torch def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): torch.manual_seed(0) # nuke kernel decorators -- will set meta-parameters manually - META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8} + META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K} configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)] kernel = triton.ops._matmul.kernel decorators = kernel.kernel_decorators diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index f5bc6afa3..7cac0a264 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -8,19 +8,25 @@ import triton }) @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4), - # triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ - # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ - # triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ - # triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4), - # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\ - # triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),\ - # triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), ], - key=['M', 'N', 'K'] + key=['M', 'N', 'K'], ) @triton.jit -def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, LOCKS, **META): +def _kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + LOCKS, **META): # extract meta-parameters BLOCK_M = META['BLOCK_M'] BLOCK_N = META['BLOCK_N'] @@ -40,12 +46,14 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride pid_n = (pid % width) // (group_size) # do matrix multiplication rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) rk = tl.arange(0, BLOCK_K) # pointers K = K // SPLIT_K - A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn) + A = A + (pid_z * K * stride_ak + ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rbn[None, :] * stride_bn) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(K, 0, -BLOCK_K): if META['EVEN_K']: @@ -106,7 +114,13 @@ class _matmul(torch.autograd.Function): locks = _matmul._locks[device] # launch kernel grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), locks) + _kernel[grid](a, b, c, + M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + locks, + GROUP_M=8) # done return c