[RUNTIME] Config hook v2.0 (#373)

* Add pre_hook to triton.Config
* Use argument names in triton.heuristics
* Update base perf
* Remove meta from heuristics
This commit is contained in:
daadaada
2021-11-22 03:20:59 +08:00
committed by GitHub
parent 5693b582ea
commit 1296eb877b
7 changed files with 44 additions and 40 deletions

View File

@@ -29,16 +29,16 @@ matmul_data = {
(1024, 1024, 1024 ) : {'v100': 0.466},
(2048, 2048, 2048 ) : {'v100': 0.680},
(4096, 4096, 4096 ) : {'v100': 0.831},
(8192, 8192, 8192 ) : {'v100': 0.841},
(8192, 8192, 8192 ) : {'v100': 0.849},
# tall-skinny
(16 , 1024, 1024 ) : {'v100': 0.0128},
(16 , 4096, 4096 ) : {'v100': 0.0558},
(16 , 4096, 4096 ) : {'v100': 0.0883},
(16 , 8192, 8192 ) : {'v100': 0.101},
(64 , 1024, 1024 ) : {'v100': 0.049},
(64 , 4096, 4096 ) : {'v100': 0.211},
(64 , 1024, 1024 ) : {'v100': 0.073},
(64 , 4096, 4096 ) : {'v100': 0.228},
(64 , 8192, 8192 ) : {'v100': 0.360},
(1024, 64 , 1024 ) : {'v100': 0.0469},
(4096, 64 , 4096 ) : {'v100': 0.198},
(1024, 64 , 1024 ) : {'v100': 0.0692},
(4096, 64 , 4096 ) : {'v100': 0.223},
(8192, 64 , 8192 ) : {'v100': 0.323},
# # deep reductions
# (64 , 64 , 16384) : {'v100': 0.},
@@ -56,7 +56,7 @@ def test_matmul(M, N, K):
a = torch.randn((M, K), dtype=torch.float16, device='cuda')
b = torch.randn((K, N), dtype=torch.float16, device='cuda')
fn = lambda: triton.ops.matmul(a, b)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=1000)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=1000)
cur_gpu_perf = 2.*M*N*K/ms * 1e-9
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)
@@ -101,7 +101,7 @@ def test_elementwise(N):
y = torch.randn_like(z)
grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), )
fn = lambda: _add[grid](x, y, z, N, BLOCK_SIZE=1024)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=10, rep=250)
ms = triton.testing.do_bench(fn, percentiles=None, warmup=25, rep=250)
cur_gpu_perf = 3.*N*z.element_size()/ms*1e-6
cur_gpu_util = cur_gpu_perf / max_gpu_perf
triton.testing.assert_almost_equal(cur_gpu_util, ref_gpu_util, decimal=2)

View File

@@ -67,7 +67,8 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
torch.manual_seed(0)
# nuke kernel decorators -- will set meta-parameters manually
kwargs = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K}
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE)]
pre_hook = None if SPLIT_K == 1 else lambda nargs: nargs['C'].zero_()
configs = [triton.Config(kwargs=kwargs, num_warps=NWARP, num_stages=NSTAGE, pre_hook=pre_hook)]
kernel = triton.ops._matmul.kernel
decorators = kernel.kernel_decorators
kernel.kernel_decorators = []