From 6fb4800f57e1dc52d54b2f6aa8536398b09ff935 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 3 Feb 2021 13:37:21 -0800 Subject: [PATCH] Improvements w/ Auto-Tuning and standard benchmarks (#57) [PYTHON] Bug-fixes in the auto-tuning module and improvement of the existing API for it --- include/triton/runtime/function.h | 14 ++-- lib/runtime/function.cc | 48 +++++++------ python/src/bindings.cc | 13 ++-- python/tests/test_blocksparse.py | 71 ++++++++++++++++++ python/tests/test_matmul.py | 91 +++++++++++++----------- python/triton/kernel.py | 16 ++--- python/triton/ops/blocksparse/matmul.py | 34 ++++----- python/triton/ops/blocksparse/softmax.py | 4 +- python/triton/ops/conv.py | 10 +-- python/triton/ops/matmul.c | 4 +- python/triton/ops/matmul.py | 31 ++++---- tutorials/01-matmul.cc | 28 ++++---- 12 files changed, 215 insertions(+), 149 deletions(-) diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index b3a0b8039..bc21059c4 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -54,19 +54,13 @@ enum asm_mode_t { ASM_NV_SASS }; -struct options_space_t { - typedef std::pair> define_t; - std::vector defines; - std::vector num_warps; -}; - struct options_t { template T D(const std::string& name) const { return convert(defines.at(name)); } std::unordered_map defines; - size_t num_warps; + int num_warps; }; @@ -111,12 +105,14 @@ public: typedef std::function grid_fn_ty; typedef std::pair> kernel_pair_t; typedef std::map, kernel*> cache_t; + typedef std::vector, int>> autotune_vals_t; private: static void do_loop_nest(std::vector const & ranges, std::function const &)> const & f); public: - function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector &autotune_key = {}); + function(const std::string& src, const options_t& opt, driver::device *device, + const autotune_vals_t& autotune_vals = {}, const std::vector &autotune_key = {}); void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream); // auto-tuning @@ -126,7 +122,7 @@ public: const std::vector get_kernels() { return kernels_; } private: - void init_kernels(const std::string& src, const options_space_t& opt, driver::device *device); + void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device); private: std::vector kernels_; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 2dc9c1a5a..59effcef1 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -224,7 +224,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co for(size_t i = 0; i < 3; i++) grid[i] = (i < _grid.size()) ? _grid[i] : 1; // enqueue - stream->enqueue(&*ker_, grid, {opt.num_warps * 32, 1, 1}, args, args_size); + stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size); } std::string kernel::get_asm(asm_mode_t mode) { @@ -282,35 +282,35 @@ void function::do_loop_nest(std::vector const & ranges, return; values[i--] = 0; } - i = D - 1; + i = D - 1; options_t opt; + } } -void function::init_kernels(const std::string& src, const options_space_t& opts, driver::device *device) { - // all ranges - std::vector ranges; - ranges.push_back(opts.num_warps.size()); - for(const auto& x: opts.defines) - ranges.push_back(x.second.size()); - // functor for source with given option +void function::init_kernels(const std::string& src, const options_t& opt, + const autotune_vals_t& confs, driver::device *device) { + // list of all possible configs + // just augment `opt` with each define of `confs` + // and override warp count + size_t num_opts = std::max(confs.size(), (size_t)1); + std::vector opts(num_opts, opt); + for(size_t i = 0; i < confs.size(); i++){ + opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end()); + opts[i].num_warps = confs[i].second; + } + // compile all possible configs + // compilation errors (e.g., too much shared mem) + // will populate `err` std::vector> err; - auto do_make = [&](std::vector params) { - // compilation options - unsigned i = 0; - options_t opt; - opt.num_warps = opts.num_warps[params[i++]]; - for(auto D: opts.defines) - opt.defines[D.first] = D.second[params[i++]]; - // compile + for(const options_t& opt: opts) { try{ kernels_.push_back({opt, std::make_shared(src, opt, device)}); }catch(const exception::base& e){ err.push_back({opt, e.what()}); } - }; - // multi-threaded compilation - do_loop_nest(ranges, do_make); + } + // throw an exception if `err` is not empty if(kernels_.empty()){ std::ostringstream dbg; dbg << "Auto-Tuner could not find any valid configuration:" << std::endl; @@ -357,9 +357,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_ return it->second; } -function::function(const std::string& src, const options_space_t& opt, - driver::device *device, const std::vector& autotune_key) { - init_kernels(src, opt, device); +function::function(const std::string& src, const options_t &opt, driver::device *device, + const autotune_vals_t& autotune_vals, const std::vector& autotune_key) { + // pre-compile all kernels + init_kernels(src, opt, autotune_vals, device); + // find indices of autotune keys auto arg_names = kernels_.at(0).second->get_arg_names(); for(const std::string& name: autotune_key){ auto it = std::find(arg_names.begin(), arg_names.end(), name); diff --git a/python/src/bindings.cc b/python/src/bindings.cc index 2372ac089..069555d91 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) { void register_fn(int op_id, int dev_id, const std::string& src, - const rt::options_space_t& opt, + const rt::options_t& opt, + const rt::function::autotune_vals_t& autotune_vals, const std::vector& autotune_key) { if(tt_devices.find(dev_id) == tt_devices.end()) { driver::device* device; @@ -62,7 +63,7 @@ void register_fn(int op_id, tt_streams[dev_id].reset(stream); } if(id_fn_map.find(op_id) == id_fn_map.end()){ - id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key)); + id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key)); } for(const auto& k: id_fn_map[op_id]->get_kernels()){ const rt::options_t* opt = &k.first; @@ -197,13 +198,9 @@ PYBIND11_MODULE(libtriton, m) { .value("sass", rt::ASM_NV_SASS); pybind11::class_(m, "options", pybind11::dynamic_attr()) - .def_readwrite("num_warps", &rt::options_t::num_warps) - .def_readwrite("defines" , &rt::options_t::defines); - - pybind11::class_(m, "options_space") .def(pybind11::init<>()) - .def_readwrite("num_warps", &rt::options_space_t::num_warps) - .def_readwrite("defines" , &rt::options_space_t::defines); + .def_readwrite("defines" , &rt::options_t::defines) + .def_readwrite("num_warps", &rt::options_t::num_warps); // hooks into triton constructs since frameworks may not use pybind11 m.def("extract_kernels", &extract_kernels); diff --git a/python/tests/test_blocksparse.py b/python/tests/test_blocksparse.py index ba8a98d65..b218e9543 100644 --- a/python/tests/test_blocksparse.py +++ b/python/tests/test_blocksparse.py @@ -15,6 +15,12 @@ def mask_tensor(x, mask, block, value = 0): ret[:, h, i*block: (i+1)*block, j*block: (j+1)*block] = value return ret + + +## ----------------------------------------------------------------------------- +## Unit Tests +## ----------------------------------------------------------------------------- + @pytest.mark.parametrize("MODE, TRANS_A, TRANS_B, BLOCK", [ (mode, at, bt, block) for mode in ['sdd', 'dsd', 'dds']\ @@ -87,3 +93,68 @@ def test_softmax(BLOCK, WIDTH, DTYPE = torch.float16): rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[DTYPE] assert torch.allclose(ry , ty, rtol=rtol, atol=atol) + + +## ----------------------------------------------------------------------------- +## Performance Tests +## ----------------------------------------------------------------------------- + +def do_bench(fn, warmup = 10, rep = 50): + import torch as th + start_event = th.cuda.Event(enable_timing=True) + end_event = th.cuda.Event(enable_timing=True) + ret = fn() + for i in range(warmup): + fn() + th.cuda.synchronize() + start_event.record() + for i in range(rep): + fn() + end_event.record() + th.cuda.synchronize() + time_ms = start_event.elapsed_time(end_event) / rep + return time_ms + +def perf_matmul(BLOCK=64, LAYOUT_MODE = 'tril', OP_MODE = 'sdd', TRANS_A=False, TRANS_B=False, DTYPE = torch.float16, warmup=10, rep=50): + Z, H = 1, 1 + K = 512 + make_layout = { + 'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), + 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), + }[LAYOUT_MODE] + for N in [128, 256, 512, 1024, 2048, 4096]: + # create layout + M, N, K = N, N, N + shape = {'sdd': (M, N), + 'dsd': (K, M) if TRANS_A else (M, K), + 'dds': (N, K) if TRANS_B else (K, N)}[OP_MODE] + layout = make_layout(H, shape[0]//BLOCK, shape[1]//BLOCK) + # create op + op = tt.ops.blocksparse.matmul(layout, BLOCK, OP_MODE, trans_a=TRANS_A, trans_b=TRANS_B) + # inputs + a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device='cuda') + b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device='cuda') + a = sparsify_tensor(a, layout, BLOCK) if OP_MODE == 'dsd' else a + b = sparsify_tensor(b, layout, BLOCK) if OP_MODE == 'dds' else b + ms = do_bench(lambda: op(a, b), warmup=warmup, rep=rep) + num_flops = {'sdd': 2 * Z * K * float(layout.sum()) * BLOCK * BLOCK * 1e-12, + 'dsd': 2 * Z * N * float(layout.sum()) * BLOCK * BLOCK * 1e-12, + 'dds': 2 * Z * M * float(layout.sum()) * BLOCK * BLOCK * 1e-12}[OP_MODE] + triton_tflops = num_flops / ms * 1e3 + +def perf_softmax(BLOCK=64, LAYOUT_MODE = 'tril', DTYPE = torch.float16, warmup=10, rep=50): + Z, H = 1, 1 + K = 512 + make_layout = { + 'tril' : lambda H, M, N: torch.tril(torch.ones((H, M, N), dtype=torch.int64)), + 'dense': lambda H, M, N: torch.ones(H, M, N, dtype=torch.int64), + }[LAYOUT_MODE] + for N in [128, 256, 512, 1024, 2048, 4096]: + layout = make_layout(H, N//BLOCK, N//BLOCK) + a = torch.randn((Z, H, N, N), dtype=DTYPE, device='cuda') + a = sparsify_tensor(a, layout, BLOCK) + op = tt.ops.blocksparse.softmax(layout, BLOCK) + ms = do_bench(lambda: op(a), warmup=warmup, rep=rep) + nbytes = 2 * a.numel() * a.element_size() + triton_gbyps = (nbytes*1e-9) / (ms*1e-3) + print(triton_gbyps) diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py index fb48ce311..9a3e71c3d 100644 --- a/python/tests/test_matmul.py +++ b/python/tests/test_matmul.py @@ -3,57 +3,58 @@ import itertools import triton as tt import torch as th -@pytest.mark.parametrize("TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[ +@pytest.mark.parametrize("TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[ [ # 1 warp - (16, 16, 16, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 16, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 16, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 32, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 32, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 32, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 64, 1, None, None, None, AT, BT, DTYPE), - (64, 16, 64, 1, None, None, None, AT, BT, DTYPE), - (16, 64, 64, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), + (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE), # 2 warp - (64, 32, 64, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 64, 2, None, None, None, AT, BT, DTYPE), - (64, 32, 16, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 16, 2, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 2, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 2, None, None, None, AT, BT, DTYPE), + (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE), + (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE), # 4 warp - (128, 64, 16, 4, None, None, None, AT, BT, DTYPE), - (64, 128, 16, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 4, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 64, 4, None, None, None, AT, BT, DTYPE), - (32, 128, 64, 4, None, None, None, AT, BT, DTYPE), + (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE), + (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE), + (128, 32, 64, 1, 4, None, None, None, AT, BT, DTYPE), + (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE), # 8 warp - (128, 256, 16, 8, None, None, None, AT, BT, DTYPE), - (256, 128, 16, 8, None, None, None, AT, BT, DTYPE), - (256, 128, 32, 8, None, None, None, AT, BT, DTYPE), + (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE), + (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE), + # split-k + (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE), # variable input - (128, 128, 32, 4, 256, 256, 256 , AT, BT, DTYPE), - (128, 128, 32, 4, 384, 128, 640 , AT, BT, DTYPE), - (128, 128, 32, 4, 107, 233, 256 , AT, BT, DTYPE), - (128, 128, 32, 4, 107, 233, 311 , AT, BT, DTYPE) + (128, 128, 32, 1, 4, 256, 256, 256 , AT, BT, DTYPE), + (128, 128, 32, 1, 4, 384, 128, 640 , AT, BT, DTYPE), + (128, 128, 32, 1, 4, 107, 233, 256 , AT, BT, DTYPE), + (128, 128, 32, 1, 4, 107, 233, 311 , AT, BT, DTYPE) ] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True] ])) -def test_op(TM, TN, TK, NWARP, M, N, K, AT, BT, DTYPE): +def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE): DTYPE = {'float16': th.float16, 'float32': th.float32}[DTYPE] th.manual_seed(0) - tt.ops._matmul.kernel = dict() - tt.ops._matmul.TM = [TM] - tt.ops._matmul.TN = [TN] - tt.ops._matmul.TK = [TK] - tt.ops._matmul.num_warps = [NWARP] + tt.ops._matmul._kernels = dict() + tt.ops._matmul._CONFIGS = [({'TM': str(TM) , 'TN': str(TN) , 'TK': str(TK), 'TZ': str(TZ)}, NWARP)] if M is None: M = TM if N is None: N = TN - if K is None: K = TK + if K is None: K = TK*TZ a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5 b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5 a = a.t() if AT else a @@ -81,13 +82,13 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50): return time_ms -def perf_op(dtype=th.float16, warmup=10, rep=50): +def perf_op(AT=False, BT=False, MODE='square', dtype=th.float16, warmup=10, rep=50): import pandas as pd + import matplotlib.pyplot as plt import os - AT, BT = False, False has_cutlass = 'CUTLASS_PROFILER' in os.environ - df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH', 'CUTLASS']) - Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144] + df = pd.DataFrame(columns=['N', 'Triton', 'Torch', 'CUTLASS']) + Ns = [128, 256, 512, 1024, 1536, 2048, 2560, 3072, 4096, 5120, 6144] configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns] for AT, BT, M, N, K in configs: a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 @@ -120,6 +121,10 @@ def perf_op(dtype=th.float16, warmup=10, rep=50): cutlass_tflops = max(df_c['GFLOPs'])/1e3 else: cutlass_tflops = None - df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True) - pd.options.display.float_format = lambda x: '{:.2f}'.format(x) - print(df) \ No newline at end of file + df = df.append({'N': N, 'Triton': triton_tflops, 'Torch': torch_tflops, 'CUTLASS': cutlass_tflops}, ignore_index=True) + # name + AT = {True: 'T', False: 'N'}[AT] + BT = {True: 'T', False: 'N'}[BT] + name = f'{AT}{BT}' + df.plot.line(x='N', y=['Triton', 'Torch', 'CUTLASS'], title = f'{AT}{BT}', ax=ax[0,0], color=['purple', 'blue', 'green']) + plt.savefig(f'matmul-{mode}-{name}.pdf') \ No newline at end of file diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 5ed00faa2..2bb6a98be 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -26,10 +26,8 @@ def th_to_triton(obj): torch.float64: 'double' } if isinstance(obj, torch.dtype): - return [tys[obj]] - if isinstance(obj, list): - return [th_to_triton(x)[0] for x in obj] - return [str(obj)] + return tys[obj] + return str(obj) def cdiv(a, b): return libtriton.cdiv(a, b) @@ -45,17 +43,15 @@ def read(path, kernel_names=[]): source = libtriton.extract_kernels(source, kernel_names) return source - - class kernel: - def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []): + def __init__(self, src, device, defines = dict(), num_warps = 4, autotune_vals = [], autotune_key = []): # check if src is empty if src == '': raise ValueError('Kernel source code is empty') self.src = src - self.opt = libtriton.options_space() - self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()] + self.opt = libtriton.options() + self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()} self.opt.num_warps = num_warps # device assert device.type in ['cuda', 'cpu'] @@ -65,7 +61,7 @@ class kernel: self.device = -1 # C++ function wrapper self.op_id = libtriton.make_op_id() - libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key) + libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key) # debug mode self.is_debug = 'TRITON_DEBUG' in os.environ # signature diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index a3ae38d7e..7d5abe948 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -81,7 +81,7 @@ class _matmul(torch.autograd.Function): @staticmethod def make_sdd_lut(layout, block, dtype, device): - start_width = 64 // block + start_width = 128 // block superblocks = libtriton.superblock(layout.type(torch.int32), start_width) luts, widths, packs = [], [], [] for size, nnz in superblocks: @@ -126,22 +126,18 @@ class _matmul(torch.autograd.Function): num_lock = 1 key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple) if key not in _matmul.sdd_cache: - F32TK = [8, 16] - #F16TK = [16] - #F16TK += [32] if is_32_multiple else [] - #F16TK += [64] if is_64_multiple else [] - F16TK = [64] - TK = {torch.float32: F32TK, - torch.float16: F16TK}[dtype] - defines = {'TM': block*pack, 'TN': block*pack, 'TMN': block*block*pack*pack, 'BLOCK': block, - 'TK': TK, 'TYPE': dtype, + defines = {'TM': block*pack, 'TN': block*pack, + 'TMN': block*block*pack*pack, + 'BLOCK': block, + 'TK': 32, + 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1', 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc', 'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'} - _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines, num_warps=[1, 2, 4]) + _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines) kernel = _matmul.sdd_cache[key] # create output @@ -270,9 +266,9 @@ class _matmul(torch.autograd.Function): # kernel key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) if key not in _matmul.dds_cache: - TM = [64, 128] if dtype == torch.float32 else [64, 128, 256] - TK = [8] if dtype == torch.float32 else [16] - defines = {'TM': TM, 'TN': block, 'TK': TK, + defines = {'TM': 128, + 'TN': block, + 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda', @@ -283,7 +279,7 @@ class _matmul(torch.autograd.Function): 'STRIDE_CN': 'ldc' if trans_c else '1', 'NAME': 'dds_kernel', 'DDS': True} - _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4]) + _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines) kernel = _matmul.dds_cache[key] # output CS0 = AS0 @@ -315,9 +311,9 @@ class _matmul(torch.autograd.Function): # kernel key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) if key not in _matmul.dsd_cache: - TN = [64, 128] if dtype == torch.float32 else [64, 128] - TK = [8] if dtype == torch.float32 else [16] - defines = {'TM': block, 'TN': TN, 'TK': TK, + defines = {'TM': block, + 'TN': 128, + 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block, @@ -328,7 +324,7 @@ class _matmul(torch.autograd.Function): 'STRIDE_CN': 'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True} - _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines, num_warps=[4]) + _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines) kernel = _matmul.dsd_cache[key] # output CS0 = BS0 diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index 6a2cfa251..45cd7bdf9 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -48,7 +48,7 @@ class _softmax(torch.autograd.Function): # just-in-time compile kernel key = (block, device, dtype, num_warps, TN, apply_scale, apply_rpe, apply_kp_mask, apply_attn_mask, kp_mask_mode, attn_mask_mode) if key not in cache: - defines = {'TM': [1], 'TN': [TN], 'TYPE': dtype, 'BLOCK': block, + defines = {'TM': 1, 'TN': TN, 'TYPE': dtype, 'BLOCK': block, 'INFINITY': {torch.float32: 'F32_INFINITY', torch.float16: 'F16_INFINITY'}[dtype]} if apply_scale: @@ -63,7 +63,7 @@ class _softmax(torch.autograd.Function): defines['APPLY_ATTN_MASK'] = True if attn_mask_mode == 'mul': defines['ATTN_MASK_MUL'] = True - kernel = triton.kernel(src, device=device, defines=defines, num_warps=[num_warps]) + kernel = triton.kernel(src, device=device, defines=defines, num_warps=num_warps) cache[key] = kernel return cache[key] diff --git a/python/triton/ops/conv.py b/python/triton/ops/conv.py index 4e2961898..95a1ad201 100644 --- a/python/triton/ops/conv.py +++ b/python/triton/ops/conv.py @@ -29,10 +29,10 @@ class _conv(torch.autograd.Function): TK = 16 defines = { 'TYPE' : dtype, - 'TM' : [32, 64, 128], - 'TN' : [32, 64, 128], - 'TK' : [TK], - 'TZ' : [1], + 'TM' : 64, + 'TN' : 64, + 'TK' : TK, + 'TZ' : 1, 'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R, } idx = torch.arange(CI*R*S) @@ -40,7 +40,7 @@ class _conv(torch.autograd.Function): nci, nr, ns = _conv.unpack(idx + TK, CI, R, S) delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3) delta = delta.type(torch.int32).cuda() - _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, num_warps=[4], defines=defines)) + _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, device=device, defines=defines)) delta, kernel = _conv.kernel[dtype] # allocate output c = torch.empty([Z, CO, P, Q], dtype=dtype, device=device) diff --git a/python/triton/ops/matmul.c b/python/triton/ops/matmul.c index 95f21fde2..2875e54e2 100644 --- a/python/triton/ops/matmul.c +++ b/python/triton/ops/matmul.c @@ -83,8 +83,8 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16), *?(checkc) pc = c; #else // accumulate partial result using spin-locks - int *plock = locks + rid; - int *pcount = plock + get_num_programs(0) * get_num_programs(1); + int *plock = locks + pid; + int *pcount = plock + get_num_programs(0); for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); int count = *pcount; if(count == 0) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index d1df80e96..7912badec 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -5,11 +5,21 @@ import os class _matmul(torch.autograd.Function): src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) - TM = [128] - TN = [128] - TK = [32] - TZ = 1 - num_warps = [4] + _DEFAULT_CONFIGS = [ + ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4), + ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4), + ({'TM': '128', 'TN': '64' , 'TK': '32', 'TZ': '1'}, 4), + ({'TM': '64' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 4), + ({'TM': '32' , 'TN': '128', 'TK': '64', 'TZ': '1'}, 4), + ({'TM': '128', 'TN': '32' , 'TK': '64', 'TZ': '1'}, 4), + ({'TM': '64' , 'TN': '32' , 'TK': '64', 'TZ': '1'}, 2), + ({'TM': '32' , 'TN': '64' , 'TK': '64', 'TZ': '1'}, 2), + ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4), + ({'TM': '32' , 'TN': '128', 'TK': '32', 'TZ': '2'}, 4), + ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4), + ({'TM': '128' , 'TN': '32', 'TK': '32', 'TZ': '4'}, 4), + ] + _CONFIGS = _DEFAULT_CONFIGS @staticmethod def largest_pow2_divisor(N): @@ -41,7 +51,7 @@ class _matmul(torch.autograd.Function): lda_pow2_div = _matmul.largest_pow2_divisor(lda) ldb_pow2_div = _matmul.largest_pow2_divisor(ldb) ldc_pow2_div = _matmul.largest_pow2_divisor(ldc) - is_tk_div_k = K % 32 == 0 + is_tk_div_k = K % 64 == 0 key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k) if key not in _matmul._kernels: defines = { @@ -53,13 +63,10 @@ class _matmul(torch.autograd.Function): 'LDA_POW2_DIV': lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, - 'TM' : _matmul.TM, - 'TN' : _matmul.TN, - 'TK' : _matmul.TK, - 'TZ' : _matmul.TZ, 'IS_TK_DIV_K' : int(is_tk_div_k) } - _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K']) + _matmul._kernels[key] = triton.kernel(_matmul.src, device, defines=defines, + autotune_vals = _matmul._CONFIGS, autotune_key=['M', 'N', 'K']) kernel = _matmul._kernels[key] # # locks for split-k if device not in _matmul._locks: @@ -68,7 +75,7 @@ class _matmul(torch.autograd.Function): # enqueue alpha = 1. args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()] - grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, 1] + grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ] kernel(*args, grid=grid) return c diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc index 6169a8e6d..c878a5a54 100644 --- a/tutorials/01-matmul.cc +++ b/tutorials/01-matmul.cc @@ -158,21 +158,17 @@ float triton_dot(drv::context* context, drv::stream* stream, stream->write(&*da, true, 0, ha); stream->write(&*db, true, 0, hb); // macros - rt::options_space_t opts; - // A access patterns - opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda" }}); - opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1" }}); - // B access patterns - opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1" }}); - opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb" }}); - // data-type - opts.defines.push_back({"TYPE", {ty}}); - // tile sizes - opts.defines.push_back({"TM", {"128"}}); - opts.defines.push_back({"TN", {"128"}}); - opts.defines.push_back({"TK", {"32"}}); - opts.defines.push_back({"TZ", {"1"}}); - opts.num_warps = {4}; + rt::options_t opt; + opt.defines["STRIDE_AK"] = AT? "1" : "lda"; + opt.defines["STRIDE_AM"] = AT? "lda" : "1"; + opt.defines["STRIDE_BK"] = BT? "ldb" : "1"; + opt.defines["STRIDE_BN"] = BT? "1" : "ldb"; + opt.defines["TYPE"] = ty; + opt.defines["TM"] = "128"; + opt.defines["TN"] = "128"; + opt.defines["TK"] = "32" ; + opt.defines["TZ"] = "1"; + opt.num_warps = 4; // arguments std::stringstream oss; rt::add_arg(oss, *da->cu()); @@ -187,7 +183,7 @@ float triton_dot(drv::context* context, drv::stream* stream, rt::add_arg(oss, ldc); rt::add_arg(oss, *dlocks->cu()); // function - rt::function function(src::dot, opts, device); + rt::function function(src::dot, opt, device); // std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl; // grid auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };