diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index f3fc41298..b3a0b8039 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -82,6 +82,7 @@ public:
   void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector<size_t>& grid) const;
   // getters
   const std::vector<arg_type>& get_sig() const { return sig_; }
+  const std::vector<std::string>& get_arg_names() const { return arg_names_; }
   std::string get_asm(asm_mode_t mode);
 
 private:
@@ -96,6 +97,7 @@ private:
   driver::device* dev_;
   // signature
   std::vector<arg_type> sig_;
+  std::vector<std::string> arg_names_;
   // triton context for parsing
   ir::context ctx_;
   // handles
@@ -114,7 +116,7 @@ private:
   static void do_loop_nest(std::vector<size_t> const & ranges,
                            std::function<void(std::vector<size_t> const &)> const & f);
 public:
-  function(const std::string& src, const options_space_t& opt, driver::device *device);
+  function(const std::string& src, const options_space_t& opt, driver::device *device, const std::vector<std::string> &autotune_key = {});
   void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream);
   void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream);
   // auto-tuning
@@ -129,6 +131,9 @@ private:
 private:
   std::vector<std::pair<options_t, std::shared_ptr<kernel>>> kernels_;
   std::map<std::vector<uint64_t>, kernel*> cache_;
+  std::vector<int> key_idxs_;
+  std::vector<int> arg_size_;
+  std::vector<int> arg_off_;
 };
 
 }
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index 3f8c7cafc..05049244b 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -211,6 +211,8 @@ kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev
   init_ir(preheader() + src);
   init_ker();
   init_sig();
+  for(auto arg: ir_->get_function_list()[0]->args())
+    arg_names_.push_back(arg->get_name());
 }
 
 void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector<size_t>& _grid) const{
@@ -328,7 +330,11 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_
   if(kernels_.size() == 1)
     return &*kernels_.begin()->second;
   // auto-tuning key
-  std::vector<uint64_t> key;
+  std::vector<uint64_t> key(key_idxs_.size());
+  for(size_t i = 0; i < key.size(); i++){
+    int idx = key_idxs_[i];
+    std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]);
+  }
   auto it = cache_.find(key);
   if(it != cache_.end())
     return it->second;
@@ -350,8 +356,26 @@
   return it->second;
 }
 
-function::function(const std::string& src, const options_space_t& opt, driver::device *device) {
+function::function(const std::string& src, const options_space_t& opt,
+                   driver::device *device, const std::vector<std::string>& autotune_key) {
   init_kernels(src, opt, device);
+  auto arg_names = kernels_.at(0).second->get_arg_names();
+  for(const std::string& name: autotune_key){
+    auto it = std::find(arg_names.begin(), arg_names.end(), name);
+    if(it == arg_names.end())
+      throw std::runtime_error(name + " is not a valid argument name");
+    key_idxs_.push_back(std::distance(arg_names.begin(), it));
+  }
+  // argument size and offset
+  auto tys = kernels_.at(0).second->get_sig();
+  size_t curr = 0;
+  for(arg_type ty: tys){
+    arg_size_.push_back(size_of(ty));
+    arg_off_.push_back(curr);
+    curr += arg_size_.back();
+  }
+
+
 }
 
 void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) {
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index fbd41ef4c..2372ac089 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -45,7 +45,8 @@ void delete_grid(const map_key_t& key) {
 void register_fn(int op_id,
                  int dev_id,
                  const std::string& src,
-                 const rt::options_space_t& opt) {
+                 const rt::options_space_t& opt,
+                 const std::vector<std::string>& autotune_key) {
   if(tt_devices.find(dev_id) == tt_devices.end()) {
     driver::device* device;
     driver::stream* stream;
@@ -61,7 +62,7 @@ void register_fn(int op_id,
     tt_streams[dev_id].reset(stream);
   }
   if(id_fn_map.find(op_id) == id_fn_map.end()){
-    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id]));
+    id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_key));
   }
   for(const auto& k: id_fn_map[op_id]->get_kernels()){
     const rt::options_t* opt = &k.first;
diff --git a/python/tests/test_matmul.py b/python/tests/test_matmul.py
index e00e6b5b1..2857cde2f 100644
--- a/python/tests/test_matmul.py
+++ b/python/tests/test_matmul.py
@@ -78,21 +78,30 @@ def do_bench(fn, flops = 0, warmup = 10, rep = 50):
     end_event.record()
     th.cuda.synchronize()
     time_ms = start_event.elapsed_time(end_event) / rep
-    return time_ms, flops/time_ms*1e-9, ret
+    return time_ms
 
 def perf_op(dtype=th.float16, warmup=10, rep=50):
-    AT, BT = False, False
     import pandas as pd
+    AT, BT = False, False
     df = pd.DataFrame(columns=['AT', 'BT', 'N', 'TRITON', 'TORCH'])
-    Ns = [128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192]
+    # Ns = [128, 256, 512, 1024, 2048, 3072, 4096, 6144, 8192]
+    Ns = [8192]
     configs = [(AT, BT, N, N, N) for AT in [False, True] for BT in [False, True] for N in Ns]
     for AT, BT, M, N, K in configs:
         a = th.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5
         b = th.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5
         if AT: a = a.t()
         if BT: b = b.t()
-        TH_MS, TH_TFLOPS, _ = do_bench(lambda: th.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
-        TT_MS, TT_TFLOPS, _ = do_bench(lambda: tt.ops.matmul(a, b), flops = M*N*K*2, warmup = warmup, rep = rep)
-        df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': TT_TFLOPS, 'TORCH': TH_TFLOPS}, ignore_index=True)
+        # benchmarks
+        torch_ms = do_bench(lambda: th.matmul(a, b), warmup = warmup, rep = rep)
+        triton_ms = do_bench(lambda: tt.ops.matmul(a, b), warmup = warmup, rep = rep)
+        # store result
+        num_flops = 2*M*N*K
+        torch_tflops = num_flops / torch_ms * 1e-9
+        triton_tflops = num_flops / triton_ms * 1e-9
+        #print(min(alpha*bandwidth*1e-12, max_tflops), triton_tflops)
+        #./tools/profiler/cutlass_profiler --m=8192 --n=8192 --k=8192 --A=f16:column --B=f16:column --C=f16:column --accum=f32 --operation=gemm
+        df = df.append({'AT': AT, 'BT': BT, 'N': N, 'TRITON': triton_tflops, 'TORCH': torch_tflops}, ignore_index=True)
+    pd.options.display.float_format = lambda x: '{:.2f}'.format(x)
     print(df)
\ No newline at end of file
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index b8c8a7373..5ed00faa2 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -49,7 +49,7 @@ def read(path, kernel_names=[]):
 
 class kernel:
-    def __init__(self, src, device, defines = dict(), num_warps = [4]):
+    def __init__(self, src, device, defines = dict(), num_warps = [4], autotune_key = []):
         # check if src is empty
         if src == '':
             raise ValueError('Kernel source code is empty')
@@ -65,7 +65,7 @@ class kernel:
         self.device = -1
         # C++ function wrapper
         self.op_id = libtriton.make_op_id()
-        libtriton.register_fn(self.op_id, self.device, self.src, self.opt)
+        libtriton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_key)
         # debug mode
         self.is_debug = 'TRITON_DEBUG' in os.environ
         # signature
diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py
index c7fcfd2a2..d1df80e96 100644
--- a/python/triton/ops/matmul.py
+++ b/python/triton/ops/matmul.py
@@ -59,7 +59,7 @@ class _matmul(torch.autograd.Function):
                 'TZ' : _matmul.TZ,
                 'IS_TK_DIV_K' : int(is_tk_div_k)
             }
-            _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines)
+            _matmul._kernels[key] = triton.kernel(_matmul.src, device, num_warps=_matmul.num_warps, defines=defines, autotune_key=['M', 'N', 'K'])
         kernel = _matmul._kernels[key]
         # # locks for split-k
         if device not in _matmul._locks:
diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc
index 8f4e67643..6169a8e6d 100644
--- a/tutorials/01-matmul.cc
+++ b/tutorials/01-matmul.cc
@@ -173,7 +173,6 @@ float triton_dot(drv::context* context, drv::stream* stream,
   opts.defines.push_back({"TK", {"32"}});
   opts.defines.push_back({"TZ", {"1"}});
   opts.num_warps = {4};
-  // arguments
   std::stringstream oss;
   rt::add_arg(oss, *da->cu());
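
The change, end to end: rt::function now records each argument's size and byte offset in the packed argument buffer (derived from the kernel signature) together with the indices of the arguments named in autotune_key; function::autotune copies those bytes into a vector of 64-bit words and uses it as the kernel-cache key, so tuning is redone only for a new key value. The Python layer just forwards autotune_key through triton.kernel and register_fn, e.g. autotune_key=['M', 'N', 'K'] in ops/matmul.py. Below is a minimal standalone C++ sketch of that key-packing step; the pointer/int32 signature, the 28-byte buffer layout, and the little-endian byte order are illustrative assumptions, not taken from the patch.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // Hypothetical packed argument buffer for a kernel with signature
  // (ptr A, ptr B, int32 M, int32 N, int32 K): 8 + 8 + 4 + 4 + 4 = 28 bytes.
  char args[28] = {};
  int32_t M = 8192, N = 8192, K = 8192;
  std::memcpy(args + 16, &M, sizeof(M));
  std::memcpy(args + 20, &N, sizeof(N));
  std::memcpy(args + 24, &K, sizeof(K));
  // Per-argument sizes and cumulative byte offsets, as function::function
  // derives them from the signature with size_of().
  std::vector<int> arg_size = {8, 8, 4, 4, 4};
  std::vector<int> arg_off  = {0, 8, 16, 20, 24};
  // Indices of the arguments named in autotune_key (here 'M', 'N', 'K').
  std::vector<int> key_idxs = {2, 3, 4};
  // Key construction, mirroring function::autotune: copy each keyed argument's
  // bytes into a zero-initialized 64-bit slot (little-endian assumed).
  std::vector<uint64_t> key(key_idxs.size(), 0);
  for (size_t i = 0; i < key.size(); i++) {
    int idx = key_idxs[i];
    std::memcpy(&key[i], args + arg_off[idx], arg_size[idx]);
  }
  // The key indexes a per-function cache, so each distinct (M, N, K) is autotuned once.
  std::map<std::vector<uint64_t>, int> cache;
  cache[key] = 0;  // stands in for the kernel* the runtime would store
  std::cout << key[0] << " " << key[1] << " " << key[2] << "\n";  // prints: 8192 8192 8192
  return 0;
}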