[RUNTIME] Config hook v2.0 (#373)
* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
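In short: `triton.Config` gains an optional `pre_hook` callable that the autotuner invokes right before every launch of that config, and `triton.heuristics` lambdas now take a single dict mapping argument names to values instead of positional `*args` plus `**meta`. A minimal sketch of the new style, mirroring the matmul changes below (kernel body elided, names illustrative):

    import triton
    import triton.language as tl

    @triton.heuristics({'EVEN_K': lambda nargs: nargs['K'] % nargs['BLOCK_K'] == 0})
    @triton.autotune(
        configs=[triton.Config({'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2,
                               # zero the output before each launch of this config
                               pre_hook=lambda nargs: nargs['C'].zero_())],
        key=['M', 'N', 'K'],
    )
    @triton.jit
    def kernel(A, B, C, M, N, K,
               BLOCK_K: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
        ...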
@@ -29,16 +29,16 @@ matmul_data = {
     (1024, 1024, 1024 ) : {'v100': 0.466},
     (2048, 2048, 2048 ) : {'v100': 0.680},
     (4096, 4096, 4096 ) : {'v100': 0.831},
-    (8192, 8192, 8192 ) : {'v100': 0.841},
+    (8192, 8192, 8192 ) : {'v100': 0.849},
     # tall-skinny
     (16  , 1024, 1024 ) : {'v100': 0.0128},
-    (16  , 4096, 4096 ) : {'v100': 0.0558},
+    (16  , 4096, 4096 ) : {'v100': 0.0883},
     (16  , 8192, 8192 ) : {'v100': 0.101},
-    (64  , 1024, 1024 ) : {'v100': 0.049},
-    (64  , 4096, 4096 ) : {'v100': 0.211},
+    (64  , 1024, 1024 ) : {'v100': 0.073},
+    (64  , 4096, 4096 ) : {'v100': 0.228},
     (64  , 8192, 8192 ) : {'v100': 0.360},
-    (1024, 64  , 1024 ) : {'v100': 0.0469},
-    (4096, 64  , 4096 ) : {'v100': 0.198},
+    (1024, 64  , 1024 ) : {'v100': 0.0692},
+    (4096, 64  , 4096 ) : {'v100': 0.223},
     (8192, 64  , 8192 ) : {'v100': 0.323},
     # # deep reductions
     # (64  , 64  , 16384) : {'v100': 0.},
@@ -56,7 +56,7 @@ def test_matmul(M, N, K):
     a = torch.randn((M, K), dtype=torch.float16, device='cuda')
     b = torch.randn((K, N), dtype=torch.float16, device='cuda')
     fn = lambda: triton.ops.matmul(a, b)
-    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
     cur_gpu_perf = 2.*M*N*K/ms * 1e-9
     cur_gpu_util = cur_gpu_perf / max_gpu_perf
     triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
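For context, `do_bench` with `percentiles=None` returns a single scalar time in milliseconds, which the test converts to TFLOPS; the longer warmup gives clocks and caches more time to settle before timing. A standalone sketch of the same measurement (sizes illustrative):

    import torch
    import triton

    a = torch.randn((512, 512), dtype=torch.float16, device='cuda')
    b = torch.randn((512, 512), dtype=torch.float16, device='cuda')
    # percentiles=None -> one scalar time in ms, as the perf tests above rely on
    ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b),
                                 percentiles=None, warmup=25, rep=1000)
    tflops = 2. * 512**3 / ms * 1e-9  # same formula as cur_gpu_perf above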
@@ -101,7 +101,7 @@ def test_elementwise(N):
     y = torch.randn_like(z)
     grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
     fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
-    ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
+    ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
     cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
     cur_gpu_util = cur_gpu_perf / max_gpu_perf
     triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
@@ -67,7 +67,8 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
     torch.manual_seed(0)
     # nuke kernel decorators -- will set meta-parameters manually
     kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
-    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
+    pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
+    configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
     kernel = triton.ops._matmul.kernel
     decorators = kernel.kernel_decorators
     kernel.kernel_decorators = []
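The test builds the hook inline: with SPLIT_K == 1 the kernel stores its result directly and no initialization is needed, while SPLIT_K > 1 accumulates partial products into C, so C must start at zero. The same pattern in isolation (names as in the test):

    # plain store needs no hook; accumulation needs C zeroed before each launch
    pre_hook = None if SPLIT_K == 1 else (lambda nargs: nargs['C'].zero_())
    config = triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE,
                           pre_hook=pre_hook)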
@@ -574,7 +574,7 @@ class Kernel:
         prototype = _triton.ir.type.make_function(ret_type, arg_types)
         # generate Triton-IR
         # export symbols visible from self.fn into code-generator object
-        gscope = sys.modules[self.fn.module].__dict__
+        gscope = self.fn.fn.__globals__
         generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
         try:
             generator.visit(self.fn.parse())
@@ -698,6 +698,7 @@ class Autotuner:
                 for i in self.reset_idx:
                     args[i].zero_()
             self.hook = _hook
+        self.arg_names = arg_names

     def _bench(self, *args, config, **meta):
         # check for conflicts, i.e. meta-parameters both provided
@@ -711,11 +712,14 @@ class Autotuner:
         # augment meta-parameters with tunable ones
         current = dict(meta, **config.kwargs)
         def kernel_call():
+            if config.pre_hook:
+                config.pre_hook(self.nargs)
             self.hook(args)
             self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
         return triton.testing.do_bench(kernel_call)

     def __call__(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
         if len(self.configs) > 1:
             key = tuple([args[i] for i in self.key_idx])
             if key not in self.cache:
@@ -726,6 +730,8 @@ class Autotuner:
             config = self.cache[key]
         else:
             config = self.configs[0]
+        if config.pre_hook != None:
+            config.pre_hook(self.nargs)
         return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

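The mechanism behind `self.nargs`: the autotuner zips the kernel's argument names with the positional launch arguments, so hooks can address tensors by name rather than by brittle index. A sketch of just that mapping (names illustrative):

    arg_names = ['A', 'B', 'C', 'M', 'N', 'K']   # captured from the jitted function
    args = (a, b, c, M, N, K)                    # positional launch arguments
    nargs = dict(zip(arg_names, args))           # what pre_hook receives
    pre_hook = lambda nargs: nargs['C'].zero_()
    pre_hook(nargs)                              # zeroes c in place before the launch

Note the hook fires on both paths: inside `kernel_call` while benchmarking candidate configs, and in `__call__` before the chosen config's real launch.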
@@ -893,11 +899,14 @@ class Config:
     :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                        Mostly useful for matrix multiplication workloads on SM80+ GPUs.
     :type num_stages: int
+    :ivar pre_hook: a function that will be called before the kernel is launched. It takes a
+                    single argument: a dict mapping kernel argument names to their values.
     """

-    def __init__(self, kwargs, num_warps=4, num_stages=2):
+    def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
         self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_stages = num_stages
+        self.pre_hook = pre_hook


 def autotune(configs, key, reset_to_zero=None):
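Constructing a config with the new keyword looks like this (values illustrative; `init_to_zero` is the helper defined in the matmul hunk further down):

    cfg = triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 4},
                        num_warps=2, pre_hook=init_to_zero('C'))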
@@ -963,7 +972,7 @@ def heuristics(values):
     def fun(*args, **meta):
         for v, heur in values.items():
             assert v not in meta
-            meta[v] = heur(*args, **meta)
+            meta[v] = heur({**dict(zip(fn.arg_names, args)), **meta})
         return kernel(*args, **meta)

     return fun
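This is the change that lets heuristics drop positional indices: positional arguments and meta-parameters are merged into one name-keyed dict before the heuristic runs. Before and after, taken from the matmul hunk below:

    # before: fragile positional index into *args
    'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
    # after: lookup by argument name
    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,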
@@ -12,7 +12,7 @@ import torch
 # ********************************************************

 @triton.heuristics({
-    'EVEN_K': lambda *args, **meta: args[15] % meta['TILE_K'] == 0,
+    'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
 })
 @triton.jit
 def _sdd_kernel(
@@ -11,8 +11,8 @@ def num_warps(n):
     return 16


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[7] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[7] * meta['BLOCK'])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
+@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
 @triton.jit
 def _forward(
     X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
@@ -71,8 +71,8 @@ def _forward(
     tl.store(px, x, mask=check)


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4] * meta['BLOCK'])})
-@triton.heuristics({'TN': lambda *args, **meta: triton.next_power_of_2(args[4]) * meta['BLOCK']})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
+@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
 @triton.jit
 def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
     pidhm = tl.program_id(0)
@@ -23,8 +23,8 @@ def num_warps(N):
     return 16


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[4])})
-@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[4])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
 @triton.jit
 def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
@@ -48,8 +48,8 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
     tl.store(LOSS + row, probs)


-@triton.heuristics({'num_warps': lambda *args, **meta: num_warps(args[3])})
-@triton.heuristics({'BLOCK': lambda *args, **meta: next_power_of_2(args[3])})
+@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
 @triton.jit
 def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
     row = tl.program_id(0)
@@ -2,9 +2,11 @@ import torch
 import triton.language as tl
 import triton


+def init_to_zero(name):
+    return lambda nargs: nargs[name].zero_()
+
+
 @triton.heuristics({
-    'EVEN_K': lambda *args, **meta: args[5] % (meta['BLOCK_K'] * meta['SPLIT_K']) == 0,
+    'EVEN_K': lambda nargs: nargs['K'] % (nargs['BLOCK_K'] * nargs['SPLIT_K']) == 0,
 })
 @triton.autotune(
     configs=[
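`init_to_zero` is a small factory: it captures an argument name and returns a hook that zeroes that argument in place. Spelled out (tensor is illustrative):

    hook = init_to_zero('C')        # equivalent to: lambda nargs: nargs['C'].zero_()
    hook({'C': torch.empty(4, 4)})  # zeroes whatever tensor is bound to 'C'

The split-K configs added below attach it so the accumulation target is cleared before every launch.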
@@ -18,6 +20,14 @@ import triton
         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
         triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
         triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 2}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 4}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 8}, num_warps=2, pre_hook=init_to_zero('C')),
+        triton.Config({'BLOCK_M': 16 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'SPLIT_K': 16}, num_warps=2, pre_hook=init_to_zero('C')),
     ],
     key=['M', 'N', 'K'],
 )
@@ -26,7 +36,6 @@ def _kernel(A, B, C, M, N, K,
             stride_am, stride_ak,
             stride_bk, stride_bn,
             stride_cm, stride_cn,
-            LOCKS,
             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
             GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr):
     # matrix multiplication
@@ -70,18 +79,7 @@ def _kernel(A, B, C, M, N, K,
     if SPLIT_K == 1:
         tl.store(C, acc, mask=mask)
     else:
-        LOCKS = LOCKS + tl.program_id(0)
-        COUNT = LOCKS + tl.num_programs(0)
-        while tl.atomic_cas(LOCKS, 0, 1) == 1:
-            pass
-        count = tl.load(COUNT)
-        if count == 0:
-            tl.store(C, acc, mask=mask)
-        else:
-            curr = tl.load(C, mask=mask, other=0.)
-            tl.store(C, acc + curr, mask=mask)
-        tl.atomic_xchg(COUNT, (count + 1) % SPLIT_K)
-        tl.atomic_xchg(LOCKS, 0)
+        tl.atomic_add(C, acc, mask=mask)


 class _matmul(torch.autograd.Function):
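Design note: the lock-and-count protocol (spinning on `tl.atomic_cas`, then serializing a read-modify-write of C) is replaced by a single `tl.atomic_add` per program, which is simpler and needs no LOCKS buffer; the trade-off is that C must be zeroed beforehand, which is exactly what `pre_hook=init_to_zero('C')` guarantees. A plain-PyTorch reference for the split-K semantics (illustrative sizes):

    import torch
    M, N, K, SPLIT_K = 64, 64, 256, 4
    a, b = torch.randn(M, K), torch.randn(K, N)
    c = torch.zeros(M, N)                         # the pre_hook's job
    for s in range(SPLIT_K):                      # each slice mimics one program's atomic_add
        ks = slice(s * K // SPLIT_K, (s + 1) * K // SPLIT_K)
        c += a[:, ks] @ b[ks, :]
    assert torch.allclose(c, a @ b, atol=1e-3)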
@@ -103,17 +101,13 @@ class _matmul(torch.autograd.Function):
         _, N = b.shape
         # allocates output
         c = torch.empty((M, N), device=device, dtype=a.dtype)
-        # allocate locks for split-k
-        if a.device not in _matmul._locks:
-            _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device)
-        locks = _matmul._locks[device]
         # launch kernel
         grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
         _kernel[grid](a, b, c, M, N, K,
                       a.stride(0), a.stride(1),
                       b.stride(0), b.stride(1),
                       c.stride(0), c.stride(1),
-                      locks, GROUP_M=8)
+                      GROUP_M=8)
         return c

     @staticmethod