[PYTHON] re-activated auto-tuner configurations for triton.ops.matmul (#212)
This commit is contained in:
@@ -66,7 +66,7 @@ import torch
|
|||||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
# nuke kernel decorators -- will set meta-parameters manually
|
# nuke kernel decorators -- will set meta-parameters manually
|
||||||
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8}
|
META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
|
||||||
configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
|
configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)]
|
||||||
kernel = triton.ops._matmul.kernel
|
kernel = triton.ops._matmul.kernel
|
||||||
decorators = kernel.kernel_decorators
|
decorators = kernel.kernel_decorators
|
||||||
|
@@ -8,19 +8,25 @@ import triton
|
|||||||
})
|
})
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=4),\
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
# triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),\
|
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
# triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 64, 'SPLIT_K': 1, 'GROUP_M': 8}, num_warps=2),
|
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||||
|
triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||||
],
|
],
|
||||||
key=['M', 'N', 'K']
|
key=['M', 'N', 'K'],
|
||||||
)
|
)
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, LOCKS, **META):
|
def _kernel(A, B, C, M, N, K,
|
||||||
|
stride_am, stride_ak,
|
||||||
|
stride_bk, stride_bn,
|
||||||
|
stride_cm, stride_cn,
|
||||||
|
LOCKS, **META):
|
||||||
# extract meta-parameters
|
# extract meta-parameters
|
||||||
BLOCK_M = META['BLOCK_M']
|
BLOCK_M = META['BLOCK_M']
|
||||||
BLOCK_N = META['BLOCK_N']
|
BLOCK_N = META['BLOCK_N']
|
||||||
@@ -40,12 +46,14 @@ def _kernel(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride
|
|||||||
pid_n = (pid % width) // (group_size)
|
pid_n = (pid % width) // (group_size)
|
||||||
# do matrix multiplication
|
# do matrix multiplication
|
||||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||||
rk = tl.arange(0, BLOCK_K)
|
rk = tl.arange(0, BLOCK_K)
|
||||||
# pointers
|
# pointers
|
||||||
K = K // SPLIT_K
|
K = K // SPLIT_K
|
||||||
A = A + (pid_z * K * stride_ak + rm[:, None] * stride_am + rk[None, :] * stride_ak)
|
A = A + (pid_z * K * stride_ak + ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||||
B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rn[None, :] * stride_bn)
|
B = B + (pid_z * K * stride_bk + rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||||
for k in range(K, 0, -BLOCK_K):
|
for k in range(K, 0, -BLOCK_K):
|
||||||
if META['EVEN_K']:
|
if META['EVEN_K']:
|
||||||
@@ -106,7 +114,13 @@ class _matmul(torch.autograd.Function):
|
|||||||
locks = _matmul._locks[device]
|
locks = _matmul._locks[device]
|
||||||
# launch kernel
|
# launch kernel
|
||||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||||
_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), locks)
|
_kernel[grid](a, b, c,
|
||||||
|
M, N, K,
|
||||||
|
a.stride(0), a.stride(1),
|
||||||
|
b.stride(0), b.stride(1),
|
||||||
|
c.stride(0), c.stride(1),
|
||||||
|
locks,
|
||||||
|
GROUP_M=8)
|
||||||
# done
|
# done
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user