[RUNTIME] Config hook v2.0 (#373)
* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
@@ -29,16 +29,16 @@ matmul_data = {
 (1024, 1024, 1024 ) : {'v100': 0.466},
 (2048, 2048, 2048 ) : {'v100': 0.680},
 (4096, 4096, 4096 ) : {'v100': 0.831},
-(8192, 8192, 8192 ) : {'v100': 0.841},
+(8192, 8192, 8192 ) : {'v100': 0.849},
 # tall-skinny
 (16 , 1024, 1024 ) : {'v100': 0.0128},
-(16 , 4096, 4096 ) : {'v100': 0.0558},
+(16 , 4096, 4096 ) : {'v100': 0.0883},
+(16 , 8192, 8192 ) : {'v100': 0.101},
-(64 , 1024, 1024 ) : {'v100': 0.049},
-(64 , 4096, 4096 ) : {'v100': 0.211},
+(64 , 1024, 1024 ) : {'v100': 0.073},
+(64 , 4096, 4096 ) : {'v100': 0.228},
+(64 , 8192, 8192 ) : {'v100': 0.360},
-(1024, 64 , 1024 ) : {'v100': 0.0469},
-(4096, 64 , 4096 ) : {'v100': 0.198},
+(1024, 64 , 1024 ) : {'v100': 0.0692},
+(4096, 64 , 4096 ) : {'v100': 0.223},
+(8192, 64 , 8192 ) : {'v100': 0.323},
 # # deep reductions
 # (64 , 64 , 16384) : {'v100': 0.},
@@ -56,7 +56,7 @@ def test_matmul(M, N, K):
 a = torch.randn((M, K), dtype=torch.float16, device='cuda')
 b = torch.randn((K, N), dtype=torch.float16, device='cuda')
 fn = lambda: triton.ops.matmul(a, b)
-ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000)
+ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
 cur_gpu_perf = 2.*M*N*K/ms * 1e-9
 cur_gpu_util = cur_gpu_perf / max_gpu_perf
 triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
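As a sanity check on the formula above: cur_gpu_perf is in TFLOPS, since ms is a time in milliseconds and flops-per-millisecond times 1e-9 equals teraflops-per-second. A worked example with illustrative numbers (not taken from the table):

    M = N = K = 4096
    ms = 1.5                      # hypothetical measured runtime, in milliseconds
    flops = 2. * M * N * K        # one multiply and one add per (m, n, k) triple
    print(flops / ms * 1e-9)      # ~91.6 TFLOPS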
@@ -101,7 +101,7 @@ def test_elementwise(N):
 y = torch.randn_like(z)
 grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
 fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
-ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
+ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
 cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
 cur_gpu_util = cur_gpu_perf / max_gpu_perf
 triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
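The elementwise metric is memory bandwidth in GB/s: z = x + y reads two tensors and writes one, moving 3*N elements, and bytes-per-millisecond times 1e-6 equals gigabytes-per-second. Illustrative numbers:

    N = 1024 * 1024
    element_size = 2                          # bytes per float16 element
    ms = 0.05                                 # hypothetical measured runtime
    print(3. * N * element_size / ms * 1e-6)  # ~125.8 GB/s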
@@ -67,7 +67,8 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
 torch.manual_seed(0)
 # nuke kernel decorators -- will set meta-parameters manually
 kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
-configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
+pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
+configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
 kernel = triton.ops._matmul.kernel
 decorators = kernel.kernel_decorators
 kernel.kernel_decorators = []
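The hook is needed only for SPLIT_K > 1 because split-K programs accumulate partial products directly into C, so stale output values would corrupt the result; the benchmark must re-zero C before every timed launch. A minimal sketch of how the runtime invokes it, assuming nargs maps argument names to the values passed at launch:

    if config.pre_hook is not None:
        config.pre_hook(nargs)    # e.g. zeroes nargs['C'] in place
    # ... launch the kernel ...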
@@ -574,7 +574,7 @@ class Kernel:
 prototype = _triton.ir.type.make_function(ret_type, arg_types)
 # generate Triton-IR
 # export symbols visible from self.fn into code-generator object
-gscope = sys.modules[self.fn.module].__dict__
+gscope = self.fn.fn.__globals__
 generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
 try:
     generator.visit(self.fn.parse())
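For a module-level function the two lookups agree, but __globals__ is the dict the function actually closes over, so it also works when the defining module is not reachable by name. A small illustration (the helper here is hypothetical):

    import sys

    def helper():
        return 42

    assert helper.__globals__ is sys.modules[helper.__module__].__dict__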
@@ -698,6 +698,7 @@ class Autotuner:
 for i in self.reset_idx:
     args[i].zero_()
 self.hook = _hook
+self.arg_names = arg_names

 def _bench(self, *args, config, **meta):
     # check for conflicts, i.e. meta-parameters both provided
@@ -711,11 +712,14 @@ class Autotuner:
 # augment meta-parameters with tunable ones
 current = dict(meta, **config.kwargs)
 def kernel_call():
+    if config.pre_hook:
+        config.pre_hook(self.nargs)
     self.hook(args)
     self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
 return triton.testing.do_bench(kernel_call)

 def __call__(self, *args, **kwargs):
+    self.nargs = dict(zip(self.arg_names, args))
     if len(self.configs) > 1:
         key = tuple([args[i] for i in self.key_idx])
         if key not in self.cache:
@@ -726,6 +730,8 @@ class Autotuner:
         config = self.cache[key]
     else:
         config = self.configs[0]
+    if config.pre_hook != None:
+        config.pre_hook(self.nargs)
     return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

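Taken together, the two hunks above give the pre-hook its calling convention: __call__ rebuilds a name-to-value mapping from arg_names on every launch, and the hook receives that dict. A condensed sketch of the flow, with illustrative names:

    arg_names = ['A', 'B', 'C', 'M', 'N', 'K']      # captured from the kernel signature
    def call(*args, config):
        nargs = dict(zip(arg_names, args))          # positional args, now addressable by name
        if config.pre_hook is not None:
            config.pre_hook(nargs)                  # e.g. nargs['C'].zero_()
        # ... launch the kernel with config.num_warps / config.num_stages ...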
@@ -893,11 +899,14 @@ class Config:
 :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
     Mostly useful for matrix multiplication workloads on SM80+ GPUs.
 :type num_stages: int
+:ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
+    function are args.
 """
-def __init__(self, kwargs, num_warps=4, num_stages=2):
+def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
     self.kwargs = kwargs
     self.num_warps = num_warps
     self.num_stages = num_stages
+    self.pre_hook = pre_hook


 def autotune(configs, key, reset_to_zero=None):
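A hedged usage example of the extended constructor, mirroring the split-K matmul configs later in this commit (the meta-parameter names are those of the matmul kernel):

    config = triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 4},
                           num_warps=2,
                           pre_hook=lambda nargs: nargs['C'].zero_())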
@@ -963,7 +972,7 @@ def heuristics(values):
 def fun(*args, **meta):
     for v, heur in values.items():
         assert v not in meta
-        meta[v] = heur(*args, **meta)
+        meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
     return kernel(*args, **meta)

 return fun
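Under the new convention a heuristic takes a single dict of named arguments (plus any meta-parameters computed so far) instead of positional *args and **meta. A sketch of the resulting decorator style, with an illustrative kernel:

    @triton.heuristics({'EVEN_K': lambda nargs: nargs['K'] % nargs['BLOCK_K'] == 0})
    @triton.jit
    def kernel(A, B, C, M, N, K, BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr):
        ...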
@@ -12,7 +12,7 @@ import torch
 # ********************************************************

 @triton.heuristics({
-    'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0,
+    'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
 })
 @triton.jit
 def _sdd_kernel(
@@ -11,8 +11,8 @@ def num_warps(n):
     return 16


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
+@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
 @triton.jit
 def _forward(
     X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
@@ -71,8 +71,8 @@ def _forward(
 tl.store(px, x, mask=check)


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
+@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
 @triton.jit
 def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
     pidhm = tl.program_id(0)
@@ -23,8 +23,8 @@ def num_warps(N):
     return 16


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
-@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
 @triton.jit
 def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
@@ -48,8 +48,8 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
 tl.store(LOSS + row, probs)


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
-@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
 @triton.jit
 def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
@@ -2,9 +2,11 @@ import torch
 import triton.language as tl
 import triton

+def init_to_zero(name):
+    return lambda nargs: nargs[name].zero_()

 @triton.heuristics({
-    'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
+    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
 })
 @triton.autotune(
     configs=[
@@ -18,6 +20,14 @@ import triton
     triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
     triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
     triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
+    triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
 ],
 key=['M', 'N', 'K'],
 )
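init_to_zero is a tiny hook factory: given an argument name it returns a pre-hook that zeroes that argument in place. It can be exercised directly, e.g.:

    import torch
    hook = init_to_zero('C')
    c = torch.ones(2, 2)
    hook({'C': c})                       # mutates the named tensor in place
    assert c.abs().sum().item() == 0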
@@ -26,7 +36,6 @@ def _kernel(A, B, C, M, N, K,
     stride_am, stride_ak,
     stride_bk, stride_bn,
     stride_cm, stride_cn,
-    LOCKS,
     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
     GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
     # matrix multiplication
@@ -70,18 +79,7 @@ def _kernel(A, B, C, M, N, K,
 if SPLIT_K == 1:
     tl.store(C, acc, mask=mask)
 else:
-    LOCKS = LOCKS + tl.program_id(0)
-    COUNT = LOCKS + tl.num_programs(0)
-    while tl.atomic_cas(LOCKS, 0, 1) == 1:
-        pass
-    count = tl.load(COUNT)
-    if count == 0:
-        tl.store(C, acc, mask=mask)
-    else:
-        curr = tl.load(C, mask=mask, other=0.)
-        tl.store(C, acc + curr, mask=mask)
-    tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
-    tl.atomic_xchg(LOCKS, 0)
+    tl.atomic_add(C, acc, mask=mask)


 class _matmul(torch.autograd.Function):
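The spin-lock epilogue is gone because the pre-hook guarantees C starts at zero, so every split-K program can unconditionally accumulate its partial product with one atomic add. A NumPy model of the resulting semantics (names are for exposition only):

    import numpy as np

    def matmul_split_k(a, b, split_k=4):
        M, K = a.shape
        _, N = b.shape
        c = np.zeros((M, N), a.dtype)          # pre_hook: output starts at zero
        for s in range(split_k):               # each chunk runs as one kernel program
            ks = slice(s * K // split_k, (s + 1) * K // split_k)
            c += a[:, ks] @ b[ks, :]           # tl.atomic_add in the real kernel
        return c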
@@ -103,17 +101,13 @@ class _matmul(torch.autograd.Function):
 _, N = b.shape
 # allocates output
 c = torch.empty((M, N), device=device, dtype=a.dtype)
-# allocate locks for split-k
-if a.device not in _matmul._locks:
-    _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
-locks = _matmul._locks[device]
 # launch kernel
 grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
 _kernel[grid](a, b, c, M, N, K,
               a.stride(0), a.stride(1),
               b.stride(0), b.stride(1),
               c.stride(0), c.stride(1),
-              locks, GROUP_M=8)
+              GROUP_M=8)
 return c

 @staticmethod