From 62835a09795ca6e5adfe2bfa70fd1628ebf70948 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 4 Mar 2021 01:51:11 -0500 Subject: [PATCH] [RUNTIME] Added auto-alignment mechanism (#71) This PR adds an automatic memory alignment mechanism in the Triton runtime. Specifically, the JIT compiler detects the alignment (in bytes) of each pointer argument as well as the largest power of two divisor (between 1 and 16) of each integer argument. Proper .aligned and .multipleof attributes are then added to the Triton-IR on-the-fly for all auto-tunable kernels. There is a cache that remembers all the kernels compiled for each possible configuration. This PR also includes substantial cleaning of the Python API. This adds 2-3us overhead, mostly due to accessing integer #defines from the auto-tuned compilation options. The previous solution was slightly faster but hacky and potentially unsafe, so this is preferred for now. --- include/triton/ir/module.h | 5 +- include/triton/runtime/arg.h | 66 +---- include/triton/runtime/function.h | 129 +++++---- lib/ir/module.cc | 4 +- lib/runtime/function.cc | 354 +++++++++++------------- python/bench/bench_matmul.py | 43 ++- python/bench/run.py | 2 +- python/src/torch/utils.cc | 50 +--- python/src/triton.cc | 209 +++++++------- python/test/test_matmul.py | 3 +- python/triton/kernel.py | 81 ++---- python/triton/ops/blocksparse/matmul.c | 26 +- python/triton/ops/blocksparse/softmax.c | 32 +-- python/triton/ops/conv.c | 237 ++++++++-------- python/triton/ops/cross_entropy.c | 11 +- python/triton/ops/matmul.c | 14 +- python/triton/ops/matmul.py | 24 +- python/triton/testing.py | 9 +- python/tutorials/01-vector-add.py | 76 +++++ 19 files changed, 668 insertions(+), 707 deletions(-) create mode 100644 python/tutorials/01-vector-add.py diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 6db128198..0d9c625f1 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -10,6 +10,7 @@ #include #include "triton/ir/builder.h" #include "triton/ir/metadata.h" +#include "triton/ir/context.h" namespace triton{ @@ -60,7 +61,7 @@ private: void push_function(function *fn) { functions_.push_back(fn); } public: - module(const std::string &name, context &ctx); + module(const std::string &name); context& get_context(); builder& get_builder(); // Setters @@ -94,7 +95,7 @@ public: private: std::string name_; - context &context_; + context context_; builder builder_; std::map values_; std::map types_; diff --git a/include/triton/runtime/arg.h b/include/triton/runtime/arg.h index acee97ae8..7ba2f63d3 100644 --- a/include/triton/runtime/arg.h +++ b/include/triton/runtime/arg.h @@ -5,6 +5,7 @@ #include #include +#include namespace triton{ namespace ir{ @@ -17,73 +18,8 @@ namespace driver{ namespace runtime { -enum arg_type { - INT1_T, - INT8_T, - INT16_T, - INT32_T, - INT64_T, - HALF_T, - FLOAT_T, - DOUBLE_T, - BUFFER_T -}; - -arg_type convert(ir::type *ty); -inline size_t size_of(arg_type ty){ - switch(ty){ - case INT1_T: return 1; - case INT8_T: return 1; - case INT16_T: return 2; - case INT32_T: return 4; - case INT64_T: return 8; - case HALF_T: return 2; - case FLOAT_T: return 4; - case DOUBLE_T: return 8; - case BUFFER_T: return 8; - default: throw std::runtime_error("unknown type"); - } -} - -inline bool is_int_type(arg_type ty){ - return ty == INT1_T || ty == INT8_T || ty == INT16_T || - ty == INT32_T || ty == INT64_T; -} - -class arg { -public: - union value_t { - bool int1; - int8_t int8; - int16_t int16; - int32_t int32; - int64_t int64; - 
uint16_t fp16; - float fp32; - double fp64; - driver::buffer* buf; - }; - -public: - // construct from primitive types - arg(arg_type ty, value_t val): ty_(ty) { val_ = val; } - arg(int32_t x): ty_(INT32_T) { val_.int32 = x; } - arg(int64_t x): ty_(INT64_T) { val_.int64 = x; } - arg(float x): ty_(FLOAT_T) { val_.fp32 = x; } - arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; } - arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; } - // accessors - arg_type type() const { return ty_; } - void* data() const { return (void*)&val_; } - driver::buffer* buffer() const { return val_.buf; } - - -private: - arg_type ty_; - value_t val_; -}; } } diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 501bddd39..ddfe78776 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -11,6 +11,7 @@ #include #include // codegen +#include "triton/ir/function.h" #include "triton/ir/context.h" #include "triton/runtime/arg.h" #include "triton/runtime/error.h" @@ -37,63 +38,86 @@ class context; namespace triton{ namespace runtime{ -typedef std::vector grid_t; -typedef std::map params_t; -template inline T convert(const std::string& name); -template<> inline long convert(const std::string& name) { return std::stol(name); } -template<> inline int convert(const std::string& name) { return std::stoi(name); } + +/* ------------------------- */ +/* Compilation options */ +/* ------------------------- */ + +struct options_t { + template + T D(const std::string& name) const { + return std::stoi(defines.at(name)); + } + std::unordered_map defines; + int num_warps; +}; + +/* ------------------------- */ +/* Runtime arguments */ +/* ------------------------- */ + +enum arg_type { + INT1_T, + INT8_T, + INT16_T, + INT32_T, + INT64_T, + HALF_T, + FLOAT_T, + DOUBLE_T, + BUFFER_T +}; + +inline size_t size_of(arg_type ty){ + switch(ty){ + case INT1_T : return 1; + case INT8_T : return 1; + case INT16_T : return 2; + case INT32_T : return 4; + case INT64_T : return 8; + case HALF_T : return 2; + case FLOAT_T : return 4; + case DOUBLE_T: return 8; + case BUFFER_T: return 8; + default: throw std::runtime_error("unknown type"); + } +} template void add_arg(std::stringstream& ss, T arg) { ss.write((char*)&arg, sizeof(T)); } + +/* ------------------------- */ +/* ------------------------- */ + enum asm_mode_t { ASM_LLIR, ASM_NV_PTX, ASM_NV_SASS }; -struct options_t { - template - T D(const std::string& name) const { - return convert(defines.at(name)); - } - std::unordered_map defines; - int num_warps; -}; - - -/* ------------------------- */ - class kernel{ -private: - static std::string preheader(); - static arg_type convert(ir::type *ty); +public: + typedef std::vector grid_t; public: - kernel(const std::string& src, const options_t& opt, driver::device *device); - void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector& grid) const; - // getters - const std::vector& get_sig() const { return sig_; } - const std::vector& get_arg_names() const { return arg_names_; } - std::string get_asm(asm_mode_t mode); + static std::shared_ptr src_to_ir(const std::string& src, const options_t& opt); + static std::tuple, + std::shared_ptr, + size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt); -private: - void init_ir (const std::string &src); - void init_ker(); - void init_sig(); +public: + kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map &attrs = {}); + void operator()(const 
std::string& args, driver::stream *stream, const grid_t& grid) const; + std::string get_asm(asm_mode_t mode); public: const options_t opt; private: driver::device* dev_; - // signature - std::vector sig_; - std::vector arg_names_; - // triton context for parsing - ir::context ctx_; // handles std::shared_ptr ir_; std::shared_ptr mod_; @@ -102,36 +126,37 @@ private: size_t shared_mem_; }; +struct config { + std::map defines; + int num_warps; +}; + class function { public: - typedef std::function grid_fn_ty; + typedef std::function grid_fn_ty; typedef std::pair> kernel_pair_t; typedef std::map, kernel*> cache_t; - typedef std::vector, int>> autotune_vals_t; + typedef std::vector autotune_confs_t; -private: - static void do_loop_nest(std::vector const & ranges, - std::function const &)> const & f); public: function(const std::string& src, const options_t& opt, driver::device *device, - const autotune_vals_t& autotune_vals = {}, const std::vector &autotune_key = {}); - void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); - void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream); - // auto-tuning - cache_t::iterator find_in_cache(void* args, size_t args_size); - kernel* autotune(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); - // getters - const std::vector get_kernels() { return kernels_; } + const std::vector& tune_confs = {}, const std::vector &tune_key = {}); + kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream); + void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream); + const std::vector get_signature() { return sig_; } private: - void init_kernels(const std::string& src, const options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device); - -private: - std::vector kernels_; + std::map, std::vector>> kernels_; std::map, kernel*> cache_; + std::vector sig_; + std::vector align_idxs_; + std::vector int_idxs_; std::vector key_idxs_; std::vector arg_size_; std::vector arg_off_; + std::vector opts_; + std::string src_; + driver::device* device_; }; } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 28edb4e4f..d31f329c0 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -10,8 +10,8 @@ namespace triton{ namespace ir{ /* Module */ -module::module(const std::string &name, context &ctx) - : name_(name), context_(ctx), builder_(ctx) { +module::module(const std::string &name) + : name_(name), builder_(context_) { sealed_blocks_.insert(nullptr); } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 5e2caa062..69799e557 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -40,7 +40,6 @@ #include #include -std::mutex mut; namespace triton{ namespace runtime { @@ -49,22 +48,9 @@ namespace runtime { /* --------------------------------- */ /* --------------------------------- */ -arg_type kernel::convert(ir::type *ty) { - if(ty->is_integer_ty(1)) return INT1_T; - if(ty->is_integer_ty(8)) return INT8_T; - if(ty->is_integer_ty(16)) return INT16_T; - if(ty->is_integer_ty(32)) return INT32_T; - if(ty->is_integer_ty(64)) return INT64_T; - if(ty->is_half_ty()) return HALF_T; - if(ty->is_float_ty()) return FLOAT_T; - if(ty->is_double_ty()) return DOUBLE_T; - if(ty->is_pointer_ty()) return BUFFER_T; - throw std::runtime_error("unknown type"); -} - - -std::string kernel::preheader() { - return R"( +std::shared_ptr kernel::src_to_ir(const std::string& _src, const options_t& opt) 
{ + std::string src = +R"( #define bool _Bool #define true 1 #define false 0 @@ -116,9 +102,7 @@ typedef short int16; typedef int int32; typedef long int64; )"; -} - -void kernel::init_ir(const std::string& src) { + src += _src; // pre-process TokenSequence tokens; Preprocessor cpp(&src, true); @@ -129,21 +113,21 @@ void kernel::init_ir(const std::string& src) { Parser parser(tokens); parser.Parse(); // ast -> triton-ir - ir::module* module = new ir::module("", ctx_); + auto ret = std::make_shared(""); Generator gen(&parser); - gen.Gen(module); - ir_.reset(module); + gen.Gen(&*ret); + return ret; } -void kernel::init_ker(){ - // triton-ir -> binary - std::unique_ptr bin; - std::unique_ptr target = dev_->make_target(); +std::tuple, + std::shared_ptr, + size_t> kernel::ir_to_bin(ir::module &ir, driver::device* dev, const options_t& opt) { // generate llvm code llvm::LLVMContext ctx; - std::string name = ir_->get_function_list()[0]->get_name(); + std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); // optimizations + std::unique_ptr target = dev->make_target(); bool cts_use_async = target->as_nvidia()->sm() >= 80; // create passes codegen::analysis::align align; @@ -162,73 +146,61 @@ void kernel::init_ker(){ codegen::transform::coalesce coalesce(&align, &layouts); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps); // run passes - dce.run(*ir_); - pipeline.run(*ir_); - dce.run(*ir_); - disassociate.run(*ir_); - dce.run(*ir_); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - peephole.run(*ir_); - dce.run(*ir_); + dce.run(ir); + pipeline.run(ir); + dce.run(ir); + disassociate.run(ir); + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + peephole.run(ir); + dce.run(ir); if(target->is_gpu()) - cts.run(*ir_); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - coalesce.run(*ir_); - dce.run(*ir_); - align.run(*ir_); - dce.run(*ir_); + cts.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + coalesce.run(ir); + dce.run(ir); + align.run(ir); + dce.run(ir); if(target->is_gpu()){ - reassociate.run(*ir_); - cts.run(*ir_); + reassociate.run(ir); + cts.run(ir); } - dce.run(*ir_); -// ir::print(*ir_, std::cout); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - peephole.run(*ir_); - dce.run(*ir_); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - swizzle.run(*ir_); - liveness.run(*ir_); - allocation.run(*ir_); - shared_mem_ = allocation.allocated_size(); -// if(allocation.allocated_size() > dev_->max_shared_memory()) -// throw exception::out_of_shared_memory(); - barriers.run(*ir_); - isel.visit(*ir_, *llvm); - //if(res->spilled() > 256) - // throw exception::out_of_registers(); - mod_.reset(driver::module::create(dev_, std::move(llvm))); - ker_.reset(driver::kernel::create(&*mod_, name.c_str())); + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + peephole.run(ir); + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + swizzle.run(ir); + liveness.run(ir); + allocation.run(ir); + barriers.run(ir); + isel.visit(ir, *llvm); + std::shared_ptr mod(driver::module::create(dev, std::move(llvm))); + std::shared_ptr ker(driver::kernel::create(&*mod, name.c_str())); + size_t shared_mem = allocation.allocated_size(); + return std::make_tuple(mod, ker, shared_mem); } -void kernel::init_sig() { - ir::function* fn = ir_->get_function_list()[0]; - ir::function_type* ty = fn->get_fn_type(); - for(size_t i = 0; i < 
ty->get_num_params(); i++){ - sig_.push_back(convert(ty->get_param_ty(i))); - if(!fn->has_attr(i+1)) - continue; - } -} - -kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev): +kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev, const std::map &attrs): opt(opt), dev_(dev) { - init_ir(preheader() + src); - init_ker(); - init_sig(); - for(auto arg: ir_->get_function_list()[0]->args()) - arg_names_.push_back(arg->get_name()); + // compile to Triton IR + ir_ = src_to_ir(src, opt); + // add attributes + for(const auto&x: attrs) + ir_->get_function_list()[0]->add_attr(x.first, x.second); + // compile to binary + std::tie(mod_, ker_, shared_mem_) = ir_to_bin(*ir_, dev, opt); } -void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector& _grid) const{ +void kernel::operator()(const std::string& args, driver::stream *stream, const std::vector& _grid) const{ // set grid if(_grid.size() > 3) throw std::runtime_error("grid size must be no greater than 3"); @@ -236,7 +208,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co for(size_t i = 0; i < 3; i++) grid[i] = (i < _grid.size()) ? _grid[i] : 1; // enqueue - stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_); + stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_); } std::string kernel::get_asm(asm_mode_t mode) { @@ -282,124 +254,124 @@ std::string kernel::get_asm(asm_mode_t mode) { /* --------------------------------- */ /* --------------------------------- */ -void function::do_loop_nest(std::vector const & ranges, - std::function const &)> const & f){ - size_t D = ranges.size(); - std::vector values(D, 0); - size_t i = D - 1; - while(true){ - f(values); - while(values[i]++ == ranges[i] - 1){ - if(i == 0) - return; - values[i--] = 0; - } - i = D - 1; options_t opt; - - } -} -void function::init_kernels(const std::string& src, const options_t& opt, - const autotune_vals_t& confs, driver::device *device) { - // list of all possible configs - // just augment `opt` with each define of `confs` - // and override warp count - size_t num_opts = std::max(confs.size(), (size_t)1); - std::vector opts(num_opts, opt); - for(size_t i = 0; i < confs.size(); i++){ - opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end()); - opts[i].num_warps = confs[i].second; - } - // compile all possible configs - // compilation errors (e.g., too much shared mem) - // will populate `err` - std::vector> err; - for(const options_t& opt: opts) { - try{ - kernels_.push_back({opt, std::make_shared(src, opt, device)}); - }catch(const exception::base& e){ - err.push_back({opt, e.what()}); - } - } - // throw an exception if `err` is not empty - if(kernels_.empty()){ - std::ostringstream dbg; - dbg << "Auto-Tuner could not find any valid configuration:" << std::endl; - for(auto x: err){ - dbg << "[ "; - dbg << x.first.num_warps << ", "; - dbg << "{ "; - for(const auto& y: x.first.defines) - dbg << '"' << y.first << "\"= \"" << y.second << "\", "; - dbg << " } ] -> " << x.second << std::endl; - } - throw exception::no_valid_configuration(dbg.str()); - } -} - -kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream* stream) { - // fast path -- no autotuning necessary - if(kernels_.size() == 1) - return &*kernels_.begin()->second; - // auto-tuning key - std::vector key(key_idxs_.size()); - for(size_t i = 0; 
i < key.size(); i++){ - int idx = key_idxs_[i]; - std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]); - } - auto it = cache_.find(key); - if(it != cache_.end()) - return it->second; - // run auto-tuner - double best_ts = INFINITY; - kernel* ret = nullptr; - for(auto &x : kernels_){ - kernel* current = &*x.second; - auto grid = grid_fn(x.first); - while(grid.size() < 3) - grid.push_back(1); - double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); }, - stream, 5, 20); - ret = (ts < best_ts) ? current : ret; - best_ts = std::min(ts, best_ts); - } - stream->synchronize(); - it = cache_.insert({key, ret}).first; - return it->second; -} function::function(const std::string& src, const options_t &opt, driver::device *device, - const autotune_vals_t& autotune_vals, const std::vector& autotune_key) { - // pre-compile all kernels - init_kernels(src, opt, autotune_vals, device); - // find indices of autotune keys - auto arg_names = kernels_.at(0).second->get_arg_names(); - for(const std::string& name: autotune_key){ - auto it = std::find(arg_names.begin(), arg_names.end(), name); - if(it == arg_names.end()) - throw std::runtime_error(name + " is not a valid argument name"); - key_idxs_.push_back(std::distance(arg_names.begin(), it)); + const std::vector &tune_confs, const std::vector& tune_key) + : src_(src), device_(device) { + // kernel options + size_t num_opts = std::max(tune_confs.size(), (size_t)1); + opts_ = std::vector(num_opts, opt); + for(size_t i = 0; i < tune_confs.size(); i++){ + opts_[i].defines.insert(tune_confs[i].defines.begin(), tune_confs[i].defines.end()); + opts_[i].num_warps = tune_confs[i].num_warps; } + std::shared_ptr ir = kernel::src_to_ir(src, opts_[0]); + std::vector args = ir->get_function_list()[0]->args(); + // signature + auto convert = [](ir::type *ty) { + if(ty->is_integer_ty(1)) return INT1_T; + if(ty->is_integer_ty(8)) return INT8_T; + if(ty->is_integer_ty(16)) return INT16_T; + if(ty->is_integer_ty(32)) return INT32_T; + if(ty->is_integer_ty(64)) return INT64_T; + if(ty->is_half_ty()) return HALF_T; + if(ty->is_float_ty()) return FLOAT_T; + if(ty->is_double_ty()) return DOUBLE_T; + if(ty->is_pointer_ty()) return BUFFER_T; + throw std::runtime_error("unknown type"); + }; + for(ir::argument* arg: args) + sig_.push_back(convert(arg->get_type())); + // find indices of autotune keys + for(const std::string& name: tune_key){ + auto pred = [&](ir::argument* arg) { return arg->get_name() == name; }; + auto it = std::find_if(args.begin(), args.end(), pred); + if(it == args.end()) + throw std::runtime_error(name + " is not a valid argument name"); + key_idxs_.push_back(std::distance(args.begin(), it)); + } + // find indices of pointer + for(size_t i = 0; i < args.size(); i++) + if(args[i]->get_type()->is_pointer_ty() || + args[i]->get_type()->is_integer_ty()) + align_idxs_.push_back(i); // argument size and offset - auto tys = kernels_.at(0).second->get_sig(); size_t curr = 0; - for(arg_type ty: tys){ + for(arg_type ty: sig_){ arg_size_.push_back(size_of(ty)); arg_off_.push_back(curr); curr += arg_size_.back(); } - - } -void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) { - runtime::kernel* fn = autotune(args, args_size, grid_fn, stream); - (*fn)(args, args_size, stream, grid_fn(fn->opt)); +uint64_t pow2_divisor(uint64_t N){ + if(N % 16 == 0) return 16; + if(N % 8 == 0) return 8; + if(N % 4 == 0) return 4; + if(N % 2 == 0) return 2; + return 1; } -void 
function::operator()(void* args, size_t args_size, const grid_t& grid, driver::stream* stream) { - return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream); +kernel* function::autotune(const std::string &args, const grid_fn_ty& grid_fn, driver::stream* stream) { + // align key + std::vector rt_key(align_idxs_.size(), 0); + for(size_t i = 0; i < align_idxs_.size(); i++){ + int idx = align_idxs_[i]; + uint64_t tmp = 0; + std::memcpy((void*)&tmp, (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]); + rt_key[i] = pow2_divisor(tmp); + } + // auto-tuning key + std::vector at_key(key_idxs_.size(), 0); + for(size_t i = 0; i < at_key.size(); i++){ + int idx = key_idxs_[i]; + std::memcpy((void*)&at_key[i], (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]); + } + // cache key + std::vector cache_key; + cache_key.reserve(rt_key.size() + at_key.size()); + cache_key.insert(cache_key.end(), rt_key.begin(), rt_key.end()); + cache_key.insert(cache_key.end(), at_key.begin(), at_key.end()); + auto it = cache_.find(cache_key); + if(it != cache_.end()) + return it->second; + // compile kernels + if(kernels_.find(rt_key) == kernels_.end()){ + std::map attrs; + for(size_t i = 0; i < align_idxs_.size(); i++){ + bool is_ptr = sig_[align_idxs_[i]] == BUFFER_T; + attrs.insert({align_idxs_[i] + 1, ir::attribute(is_ptr ? ir::aligned : ir::multiple_of, rt_key[i])}); + } + for(const options_t& opt: opts_) + kernels_[rt_key].emplace_back(new kernel(src_, opt, device_, attrs)); + } + // run auto-tuner + double best_ts = INFINITY; + auto& kernels = kernels_.at(rt_key); + kernel* ret = nullptr; + if(kernels.size() == 1) + ret = &*kernels.back(); + else{ + for(auto ¤t : kernels_.at(rt_key)){ + auto grid = grid_fn(current->opt); + while(grid.size() < 3) + grid.push_back(1); + double ts = tools::bench([&]() { (*current)(args, stream, grid); }, + stream, 5, 20); + ret = (ts < best_ts) ? 
&*current : ret; + best_ts = std::min(ts, best_ts); + } + stream->synchronize(); + } + it = cache_.insert({cache_key, ret}).first; + return it->second; +} + +void function::operator()(const std::string& args, const grid_fn_ty& grid_fn, driver::stream *stream) { + runtime::kernel* fn = autotune(args, grid_fn, stream); + (*fn)(args, stream, grid_fn(fn->opt)); } diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index 25da0f68e..b79030c40 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -1,12 +1,19 @@ import triton import torch +import os -# square benchmarks +def rounded_linspace(low, high, steps, div): + ret = torch.linspace(low, high, steps) + ret = (ret.int() + div - 1) // div * div + ret = torch.unique(ret) + return list(map(int, ret)) + +# Square benchmarks nt = {False: "n", True: "t"} square_confs = [ triton.testing.Benchmark( x_names=["M", "N", "K"], - x_vals=[512 * i for i in range(1, 16)], + x_vals=rounded_linspace(512, 8192, 17, 128), y_name="provider", y_vals=["torch", "triton", "cutlass"], y_lines=["Torch", "Triton", "CUTLASS"], @@ -17,16 +24,29 @@ square_confs = [ ) for AT in [False, True] for BT in [False, True] ] -@triton.testing.perf_report(square_confs) -def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): - import os +# Transformer training benchmarks +transformer_confs = [ + triton.testing.Benchmark( + x_names=[x], + x_vals = rounded_linspace(NK//16, NK, 33, 128), + y_name="provider", + y_vals=["torch", "triton", "cutlass"], + y_lines=["Torch", "Triton", "CUTLASS"], + ylabel="TFLOPS", + loglog=False, + plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", + args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} + ) for NK in [8192]\ + for i, x in enumerate(["N", "K"])\ + for M in [2048] +] +@triton.testing.perf_report(square_confs) +def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40): a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) - if AT: - a = a.t() - if BT: - b = b.t() + if AT: a = a.t() + if BT: b = b.t() num_flops = 2 * M * N * K if provider == "torch": torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep) @@ -40,7 +60,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): import subprocess import tempfile import pandas as pd - # run program specified by CUTLASS_PROFILER env variable layout_a = "column" if AT else "row" layout_b = "column" if BT else "row" @@ -61,6 +80,7 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): f"--warmup-iterations={warmup}", f"--profiling-iterations={rep}", f"--output={fname}", + "--dist=uniform,min:0,max:1,scale:-1", "--verbose=false", ] # run cmd @@ -70,6 +90,3 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): cutlass_tflops = max(df_c["GFLOPs"]) / 1e3 return cutlass_tflops return None - -if __name__ == "__main__": - bench_op.run() diff --git a/python/bench/run.py b/python/bench/run.py index 17784947a..df52a07c4 100644 --- a/python/bench/run.py +++ b/python/bench/run.py @@ -38,4 +38,4 @@ def main(args): run_all(args.result_dir, args.with_plots, args.names) if __name__ == '__main__': - main(sys.argv[1:]) \ No newline at end of file + main(sys.argv[1:]) diff --git a/python/src/torch/utils.cc b/python/src/torch/utils.cc index 8deb8d4f2..c7bf74105 100644 --- a/python/src/torch/utils.cc +++ b/python/src/torch/utils.cc @@ -5,38 +5,16 @@ #include #include 
-std::map> tt_devices; -std::map> tt_streams; - namespace torch_utils { -void register_device(int64_t dev_id) { - if (tt_devices.find(dev_id) != tt_devices.end()) - return; - triton::driver::device *device; - if (dev_id >= 0) { - CUdevice handle; - triton::driver::dispatch::cuDeviceGet(&handle, dev_id); - device = new triton::driver::cu_device(handle, false); - } else - device = new triton::driver::host_device(); - tt_devices[dev_id].reset(device); +uint64_t cu_device(int64_t dev_id) { + CUdevice handle; + triton::driver::dispatch::cuDeviceGet(&handle, dev_id); + return (uint64_t)handle; } -void register_stream(int64_t dev_id) { - if (tt_streams.find(dev_id) != tt_streams.end()) - return; - triton::driver::stream *stream; - if (dev_id >= 0) { - CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream(); - stream = new triton::driver::cu_stream(handle, false); - } else - stream = new triton::driver::host_stream(); - tt_streams[dev_id].reset(stream); -} - -void synchronize(int64_t dev_id) { - tt_streams[dev_id]->synchronize(); +uint64_t cu_stream(int64_t dev_id) { + return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream(); } void set_device(int64_t dev_id) { @@ -44,23 +22,11 @@ void set_device(int64_t dev_id) { C10_CUDA_CHECK(cudaSetDevice(dev_id)); } -torch::Tensor move_out_of_pool(torch::Tensor x) { - if (x.nbytes() == 0) - return torch::empty_like(x); - void *data; - cudaMalloc(&data, x.nbytes()); - auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options()); - ret.copy_(x); - return ret; -} - } // namespace torch_utils void init_torch_utils(pybind11::module &m) { pybind11::module subm = m.def_submodule("torch_utils"); - subm.def("register_device", &torch_utils::register_device); - subm.def("register_stream", &torch_utils::register_stream); + subm.def("cu_device", &torch_utils::cu_device); + subm.def("cu_stream", &torch_utils::cu_stream); subm.def("set_device", &torch_utils::set_device); - subm.def("synchronize", &torch_utils::synchronize); - subm.def("move_out_of_pool", &torch_utils::move_out_of_pool); } \ No newline at end of file diff --git a/python/src/triton.cc b/python/src/triton.cc index c208e37a7..6eea1717e 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,10 +1,4 @@ #include "triton/driver/stream.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/lang/code_gen.h" -#include "triton/lang/cpp.h" -#include "triton/lang/parser.h" -#include "triton/runtime/arg.h" #include "triton/runtime/function.h" #include #include @@ -13,72 +7,22 @@ #include #include +namespace py = pybind11; + using namespace triton; namespace rt = triton::runtime; namespace drv = triton::driver; -namespace lng = triton::lang; -std::unordered_map opt_cache_; -std::map> id_fn_map; -extern std::map> tt_devices; -extern std::map> tt_streams; +/*****************************************************************************/ +/* Python bindings for triton::tools */ +/*****************************************************************************/ -/* Function utilities */ - -void register_fn(int op_id, int dev_id, - const std::string &src, const rt::options_t &opt, - const rt::function::autotune_vals_t &autotune_vals, - const std::vector &autotune_key) { - if (id_fn_map.find(op_id) == id_fn_map.end()) { - id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key)); - } - for (const auto &k : id_fn_map[op_id]->get_kernels()) { - const rt::options_t 
*opt = &k.first; - pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference); - for (auto x : opt->defines) - if (std::all_of(x.second.begin(), x.second.end(), ::isdigit)) - obj.attr(x.first.c_str()) = std::stoi(x.second); - opt_cache_[&k.second->opt] = obj; - } -} - -void delete_fn(int op_id) { - id_fn_map.erase(op_id); -} - -void cleanup() { - id_fn_map.clear(); - opt_cache_.clear(); -} - -size_t make_op_id() { - return id_fn_map.size(); -} - -std::vector get_fn_signature(size_t op_id) { - return id_fn_map[op_id]->get_kernels()[0].second->get_sig(); -} - -// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments -// as a string constructed with struct.pack in python -void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) { - rt::function *fn = id_fn_map.at(op_id).get(); - (*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]); -} - -pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) { - rt::function *fn = id_fn_map.at(op_id).get(); - auto wrapper = [&grid](const rt::options_t &opt) { - pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference); - for (auto x : opt.defines) - if (std::all_of(x.second.begin(), x.second.end(), ::isdigit)) - obj.attr(x.first.c_str()) = std::stoi(x.second); - return grid(*obj.cast()); - }; - rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]); - return opt_cache_.at(&kernel->opt); -} +/*! + @brief Function for extracting kernels out of a given source-string + This can be important to enable pre-processor macros (or tunable parameters) that should only + be defined within the scope of a single kernel function +*/ std::string extract_kernels(const std::string &str, const std::vector &names) { if (names.empty()) return str; @@ -94,50 +38,82 @@ std::string extract_kernels(const std::string &str, const std::vectorstr(1); kernels.push_back(std::make_tuple(name, pos, len)); } - + // check that all the kernels provided actually exist for (const std::string &name : names) { - // check that str matches any string in kernels using std::any_of auto pred = [&name](const std::tuple &t) { return std::get<0>(t) == name; }; bool found = std::any_of(kernels.begin(), kernels.end(), pred); if (!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str); } - - // extract functions + // simple parsing logic to extract the declaration and body of each specified kernel std::string ret; for (const auto &k : kernels) { std::string name; int pos, len; std::tie(name, pos, len) = k; - if (std::find(names.begin(), names.end(), name) != names.end()) { - std::string def = str.substr(pos, str.size() - pos); - int count, pos; - // skip over declaration - count = 1; - pos = def.find('('); - while (!(def[pos++] == ')' && count == 0) && pos < def.size()) { - count += def[pos] == '('; - count -= def[pos] == ')'; - } - // skip over definition - count = 1; - pos = def.find('{', pos); - while (!(def[pos++] == '}' && count == 0) && pos < def.size()) { - count += def[pos] == '{'; - count -= def[pos] == '}'; - } - ret += def.substr(0, pos); - ret += '\n'; + if (std::find(names.begin(), names.end(), name) == names.end()) + continue; + std::string def = str.substr(pos, str.size() - pos); + // skip over declaration + // by finding matching ')' for first '(' + int count = 1; + pos = 
def.find('('); + while (!(def[pos++] == ')' && count == 0) && pos < def.size()) { + count += def[pos] == '('; + count -= def[pos] == ')'; } + // skip over definition + // by finding matching '{' for first '}' + count = 1; + pos = def.find('{', pos); + while (!(def[pos++] == '}' && count == 0) && pos < def.size()) { + count += def[pos] == '{'; + count -= def[pos] == '}'; + } + ret += def.substr(0, pos); + ret += '\n'; } - return ret; } -void init_triton(pybind11::module &m) { - pybind11::module subm = m.def_submodule("triton"); - // bindings for triton classes - pybind11::enum_(subm, "arg_type") +void init_triton_tools(py::module &&m) { + m.def("extract_kernels", &extract_kernels); +} + +/*****************************************************************************/ +/* Python bindings for triton::driver */ +/*****************************************************************************/ + +void init_triton_driver(py::module &&m) { + // base device + py::class_(m, "device"); + // cuda device + py::class_(m, "cu_device") + .def(py::init()); + // host device + py::class_(m, "host_device") + .def(py::init<>()); + + // base stream + py::class_(m, "stream"); + // host stream + py::class_(m, "host_stream") + .def(py::init<>()); + // cuda stream + py::class_(m, "cu_stream") + // py doesn't support opaque pointer (e.g., CUstream) so + // we assume it has been converted to uint64_t + .def(py::init([](uint64_t handle, bool take_ownership) { + return std::unique_ptr(new driver::cu_stream((CUstream)handle, take_ownership)); + })); +} + +/*****************************************************************************/ +/* Python bindings for triton::runtime */ +/*****************************************************************************/ +void init_triton_runtime(py::module &&m) { + // argument type + py::enum_(m, "arg_type") .value("int1", rt::INT1_T) .value("int8", rt::INT8_T) .value("int16", rt::INT16_T) @@ -147,23 +123,38 @@ void init_triton(pybind11::module &m) { .value("float", rt::FLOAT_T) .value("double", rt::DOUBLE_T) .value("buffer", rt::BUFFER_T); - - pybind11::enum_(subm, "asm_mode") + // assembly mode + py::enum_(m, "asm_mode") .value("ptx", rt::ASM_NV_PTX) .value("sass", rt::ASM_NV_SASS); - - pybind11::class_(subm, "options", pybind11::dynamic_attr()) - .def(pybind11::init<>()) + // compilation options + py::class_(m, "options", py::dynamic_attr()) + .def(py::init<>()) .def_readwrite("defines", &rt::options_t::defines) - .def_readwrite("num_warps", &rt::options_t::num_warps); + .def_readwrite("num_warps", &rt::options_t::num_warps) + .def("__getattr__", [](rt::options_t *opt, const std::string &name) { + return opt->D(name); + }); + // kernel + py::class_(m, "kernel") + .def("__call__", &rt::kernel::operator()) + .def_readonly("opt", &rt::kernel::opt); + // tune conf + py::class_(m, "config") + .def(py::init, int>(), + py::arg("defines") = std::map(), + py::arg("num_warps")); - // hooks into triton constructs since frameworks may not use pybind11 - subm.def("extract_kernels", &extract_kernels); - subm.def("get_fn_signature", &get_fn_signature); - subm.def("register_fn", ®ister_fn); - subm.def("delete_fn", &delete_fn); - subm.def("make_op_id", &make_op_id); - subm.def("cleanup", &cleanup); - subm.def("autotune", &autotune, pybind11::return_value_policy::reference); - subm.def("launch_kernel", &launch_kernel); + // function + py::class_(m, "function") + .def(py::init &, const std::vector &>()) + .def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal) + 
.def("signature", &rt::function::get_signature); +} + +void init_triton(py::module &m) { + py::module subm = m.def_submodule("triton"); + init_triton_driver(std::move(subm.def_submodule("driver"))); + init_triton_runtime(std::move(subm.def_submodule("runtime"))); + init_triton_tools(std::move(subm.def_submodule("tools"))); } diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index aaf738acc..a621d347d 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -50,8 +50,9 @@ import torch def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE): DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] torch.manual_seed(0) + defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)} triton.ops._matmul._kernels = dict() - triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)] + triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)] if M is None: M = TM if N is None: diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 206a2dd9f..1f83d2f9c 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -1,27 +1,21 @@ import os import struct from typing import Optional, Dict, List - import torch # C bindings import triton._C.libtriton.triton as _triton import triton._C.libtriton.torch_utils as _torch_utils -# Make sure internal C resources are cleaned up upon exit -import atexit - -@atexit.register -def cleanup(): - _triton.cleanup() codes = { - _triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q', - _triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P' + _triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I', + _triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f', + _triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P' } def th_to_triton(obj): tys = { - torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half', - torch.float32: 'float', torch.float64: 'double' + torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\ + torch.float16: 'half', torch.float32: 'float', torch.float64: 'double' } if isinstance(obj, torch.dtype): return tys[obj] @@ -30,69 +24,54 @@ def th_to_triton(obj): def cdiv(a, b): return (a + b - 1) // b -def synchronize(device): - dev_id = device.index - dev_id = -1 if dev_id is None else dev_id - _torch_utils.synchronize(dev_id) - -def read(path, kernel_names:Optional[List]=None): +def read(path, kernel_names: Optional[List] = None): if kernel_names is None: kernel_names = [] with open(path, 'r') as f: source = f.read() - source = _triton.extract_kernels(source, kernel_names) + source = _triton.tools.extract_kernels(source, kernel_names) return source -class kernel: - def __init__(self, - src, - device, - defines: Optional[Dict]=None, - num_warps:int=4, - autotune_vals:Optional[List]=None, - autotune_key:Optional[List]=None): +config = _triton.runtime.config +class kernel: + def __init__(self, src, device, defines: Optional[Dict] = None, num_warps: int = 4, + autotune_vals: Optional[List] = None, autotune_key: Optional[List] = None): if defines is None: defines = {} if autotune_vals is None: autotune_vals = [] if autotune_key is None: autotune_key = [] - - # check if src is empty if 
src == '': raise ValueError('Kernel source code is empty') self.src = src - self.opt = _triton.options() - self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()} - self.opt.num_warps = num_warps # device assert device.type in ['cuda', 'cpu'] if device.type == 'cuda': - self.device = torch.cuda.current_device() if device.index is None else device.index + self.device_id = torch.cuda.current_device() if device.index is None else device.index + self.device = _triton.driver.cu_device(_torch_utils.cu_device(self.device_id), False) + self.stream = _triton.driver.cu_stream(_torch_utils.cu_stream(self.device_id), False) if device.type == 'cpu': - self.device = -1 - _torch_utils.register_device(self.device) - _torch_utils.register_stream(self.device) - # C++ function wrapper - self.op_id = _triton.make_op_id() - _torch_utils.set_device(self.device) - _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key) - # debug mode - self.is_debug = 'TRITON_DEBUG' in os.environ - # signature - arg_types = _triton.get_fn_signature(self.op_id) - self.tys = ''.join([codes[x] for x in arg_types]) + self.device_id = -1 + self.device = _triton.driver.host_device() + self.device = _triton.driver.host_stream() + _torch_utils.set_device(self.device_id) + # function + self.opt = _triton.runtime.options() + self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()} + self.opt.num_warps = num_warps + # autotune_vals = [({}, 4)] + self.fn = _triton.runtime.function(self.src, self.opt, self.device, autotune_vals, autotune_key) + self.tys = ''.join([codes[x] for x in self.fn.signature()]) def __call__(self, *args, grid): - _torch_utils.set_device(self.device) + # make sure that the executing thread is on the right device + _torch_utils.set_device(self.device_id) # pack parameters into a byte buffer params = struct.pack(self.tys, *args) - opt = _triton.autotune(self.op_id, self.device, params, grid) + kernel = self.fn.autotune(params, grid, self.stream) # run kernel - grid = grid(opt) - grid_0 = grid[0] - grid_1 = 1 if len(grid) < 2 else grid[1] - grid_2 = 1 if len(grid) < 3 else grid[2] - _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2) + grid = grid(kernel.opt) + kernel(params, self.stream, grid) diff --git a/python/triton/ops/blocksparse/matmul.c b/python/triton/ops/blocksparse/matmul.c index a45d072b3..e3522ec29 100644 --- a/python/triton/ops/blocksparse/matmul.c +++ b/python/triton/ops/blocksparse/matmul.c @@ -1,17 +1,17 @@ -__global__ void NAME(TYPE *A __readonly __noalias __aligned(16), - TYPE *B __readonly __noalias __aligned(16), - TYPE *C __noalias __aligned(16), - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc __multipleof(8), - long stride_za __multipleof(8), - long stride_zb __multipleof(8), - long stride_zc __multipleof(8), - long stride_ha __multipleof(8), - long stride_hb __multipleof(8), - long stride_hc __multipleof(8), +__global__ void NAME(TYPE *A __readonly __noalias, + TYPE *B __readonly __noalias, + TYPE *C __noalias, + int lda, + int ldb, + int ldc, + long stride_za, + long stride_zb, + long stride_zc, + long stride_ha, + long stride_hb, + long stride_hc, int DS0, int DS1, - int SDD_K __multipleof(16), + int SDD_K, int SDD_off_width, int *lut, int *locks, int nlocks) { /* ---------------- */ diff --git a/python/triton/ops/blocksparse/softmax.c b/python/triton/ops/blocksparse/softmax.c index 625f4a6ac..8b8c9506a 100644 --- a/python/triton/ops/blocksparse/softmax.c +++ 
b/python/triton/ops/blocksparse/softmax.c @@ -1,17 +1,16 @@ -__global__ void forward(TYPE *X __readonly __noalias __aligned(16), +__global__ void forward(TYPE *X __readonly __noalias, float scale, - int *LUT __readonly __noalias __aligned(16), - TYPE *RPE __readonly __noalias __aligned(16), - TYPE *KP_M __readonly __noalias __aligned(16), - TYPE *ATTN_M __readonly __noalias __aligned(16), + int *LUT __readonly __noalias, + TYPE *RPE __readonly __noalias, + TYPE *KP_M __readonly __noalias, + TYPE *ATTN_M __readonly __noalias, int sizemax, - long stride_zx __multipleof(4), - long stride_zrpe __multipleof(BLOCK), - int stride_hrpe __multipleof(BLOCK), - int stride_srpe __multipleof(BLOCK), - int stride_zkpm __multipleof(BLOCK), - int stride_zattnm __multipleof(BLOCK)) -{ + long stride_zx, + long stride_zrpe, + int stride_hrpe, + int stride_srpe, + int stride_zkpm, + int stride_zattnm) { int pidhm = get_program_id(0); int pidz = get_program_id(1); // create index ranges @@ -97,14 +96,13 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16), *? (check)px = y / ysum; } -__global__ void backward(TYPE *X __readonly __noalias __aligned(16), +__global__ void backward(TYPE *X __readonly __noalias, float scale, - TYPE *DX __readonly __noalias __aligned(16), + TYPE *DX __readonly __noalias, int *LUT, int sizemax, - long stride_zx __multipleof(BLOCK), - long stride_zdx __multipleof(BLOCK)) -{ + long stride_zx, + long stride_zdx) { int pidhm = get_program_id(0); int pidz = get_program_id(1); // create index ranges diff --git a/python/triton/ops/conv.c b/python/triton/ops/conv.c index d115ff540..f2c9e899a 100644 --- a/python/triton/ops/conv.c +++ b/python/triton/ops/conv.c @@ -1,126 +1,131 @@ -__global__ void conv(TYPE *A __noalias __readonly __aligned(16), - TYPE *B __noalias __readonly __aligned(16), - TYPE *C __noalias __aligned(16), - float alpha, - // equivalent matmul - int M, int N, int K, - // convolution properties - int pad_h, int pad_w, int stride_h, int stride_w, - // pointer increment - int *ADELTA, - // memory strides - int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8), - int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8), - int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) { - // prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int ridz = get_program_id(2); - int gridx = M / TM; - int gridy = N / TN; - int rid = ridx + ridy * gridx; - ridx = rid / gridy; - ridy = rid % gridy; - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... TN; - // reduction splitting - K = K / TZ; - int rk[TK] = ridz * K + 0 ... TK; +__global__ void conv(TYPE *A __noalias __readonly, + TYPE *B __noalias __readonly, + TYPE *C __noalias, + float alpha, + // equivalent matmul + int M, int N, int K, + // convolution properties + int pad_h, int pad_w, int stride_h, int stride_w, + // pointer increment + int *ADELTA, + // memory strides + int lda_z, int lda_ci, int lda_h, int lda_w, + int ldb_ci, int ldb_r, int ldb_s, int ldb_co, + int ldc_z, int ldc_co, int ldc_p, int ldc_q) { + // prologue + int ridx = get_program_id(0); + int ridy = get_program_id(1); + int ridz = get_program_id(2); + int gridx = M / TM; + int gridy = N / TN; + int rid = ridx + ridy * gridx; + ridx = rid / gridy; + ridy = rid % gridy; + int rm[TM] = ridx * TM + 0 ... TM; + int rn[TN] = ridy * TN + 0 ... 
TN; + // reduction splitting + K = K / TZ; + int rk[TK] = ridz * K + 0 ... TK; - // unpack aggregate rows - // m = (z, p, q) - int rq[TM] = rm % QQ; - int rzp[TM] = rm / QQ; - int rp[TM] = rzp % PP; - int rz[TM] = rzp / PP; - // unpack aggregate reduction - // k = (ci, r, s) - int rs [TK] = rk % SS; - int rcir[TK] = rk / SS; - int rr [TK] = rcir % RR; - int rci [TK] = rcir / RR; + // unpack aggregate rows + // m = (z, p, q) + int rq[TM] = rm % QQ; + int rzp[TM] = rm / QQ; + int rp[TM] = rzp % PP; + int rz[TM] = rzp / PP; + // unpack aggregate reduction + // k = (ci, r, s) + int rs[TK] = rk % SS; + int rcir[TK] = rk / SS; + int rr[TK] = rcir % RR; + int rci[TK] = rcir / RR; - // padding / striding - int rh_0[TM] = rp * stride_h - pad_h; - int rw_0[TM] = rq * stride_w - pad_w; - int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :]; - int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :]; + // padding / striding + int rh_0[TM] = rp * stride_h - pad_h; + int rw_0[TM] = rq * stride_w - pad_w; + int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :]; + int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :]; - // pointers to lhs - int offa[TM, TK] = rz [:, newaxis] * lda_z + - rci[newaxis, :] * lda_ci + - rh * lda_h + - rw * 1; - TYPE* pa[TM, TK] = A + offa; - int* padelta[TK] = ADELTA + rk; - // pointers to rhs - int offb[TK, TN] = rci[:, newaxis] * ldb_ci + - rr [:, newaxis] * ldb_r + - rs [:, newaxis] * ldb_s + - rn [newaxis, :] * 1; - TYPE* pb[TK, TN] = B + offb; + // pointers to lhs + int offa[TM, TK] = rz[:, newaxis] * lda_z + + rci [newaxis, :] * lda_ci + + rh * lda_h + + rw * 1; + TYPE *pa[TM, TK] = A + offa; + int *padelta[TK] = ADELTA + rk; + // pointers to rhs + int offb[TK, TN] = rci[:, newaxis] * ldb_ci + + rr + [:, newaxis] * ldb_r + + rs + [:, newaxis] * ldb_s + + rn [newaxis, :] * 1; + TYPE *pb[TK, TN] = B + offb; - // prefetches operands - bool checkam[TM, TK] = rm[:, newaxis] < M; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - bool checkb[TK, TN] = rk[:, newaxis] < K; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - int total = 0; + // prefetches operands + bool checkam[TM, TK] = rm[:, newaxis] < M; + bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; + bool checkb[TK, TN] = rk[:, newaxis] < K; + TYPE a[TM, TK] = checka ? *pa : 0; + TYPE b[TK, TN] = checkb ? *pb : 0; + int total = 0; - // reduction loop - float acc[TM, TN] = 0; - for(int k = K; k > 0; k -= TK){ - acc += a @ b; - // increment A - int adelta[TK] = *padelta; - padelta += TK; - pa += adelta[newaxis, :]; - // bounds-checking A - rk += TK; - rs = rk % SS; - rcir = rk / SS; - rr = rcir % RR; - rh = rh_0[:, newaxis] + rr[newaxis, :]; - rw = rw_0[:, newaxis] + rs[newaxis, :]; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - // increment B - pb += TK * ldb_s; - // bounds-checking B - bool checkb[TK, TN] = k > TK; - a = checka ? 
*pa : 0; - b = *?(checkb)pb; - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; + // reduction loop + float acc[TM, TN] = 0; + for (int k = K; k > 0; k -= TK) { + acc += a @b; + // increment A + int adelta[TK] = *padelta; + padelta += TK; + pa += adelta [newaxis, :]; + // bounds-checking A + rk += TK; + rs = rk % SS; + rcir = rk / SS; + rr = rcir % RR; + rh = rh_0[:, newaxis] + rr [newaxis, :]; + rw = rw_0[:, newaxis] + rs [newaxis, :]; + bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; + // increment B + pb += TK * ldb_s; + // bounds-checking B + bool checkb[TK, TN] = k > TK; + a = checka ? *pa : 0; + b = *? (checkb)pb; + } + acc = acc * alpha; + TYPE c[TM, TN] = acc; - // epilogue - rm = ridx * TM + 0 ... TM; - rn = ridy * TN + 0 ... TN; - rq = rm % QQ; - rzp = rm / QQ; - rp = rzp % PP; - rz = rzp / PP; - int offc[TM, TN] = rz [:, newaxis] * ldc_z + - rn [newaxis, :] * ldc_co+ - rp [:, newaxis] * ldc_p + - rq [:, newaxis] * 1; - TYPE* pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N; + // epilogue + rm = ridx * TM + 0 ... TM; + rn = ridy * TN + 0 ... TN; + rq = rm % QQ; + rzp = rm / QQ; + rp = rzp % PP; + rz = rzp / PP; + int offc[TM, TN] = rz[:, newaxis] * ldc_z + + rn [newaxis, :] * ldc_co + + rp + [:, newaxis] * ldc_p + + rq + [:, newaxis] * 1; + TYPE *pc[TM, TN] = C + offc; + bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N; -#if (TZ==1) - *?(checkc) pc = c; +#if (TZ == 1) + *? (checkc)pc = c; #else - // accumulate partial result using spin-locks - int *plock = locks + rid; - int *pcount = plock + get_num_programs(0) * get_num_programs(1); - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % TZ); - atomic_xchg(plock, 0); + // accumulate partial result using spin-locks + int *plock = locks + rid; + int *pcount = plock + get_num_programs(0) * get_num_programs(1); + for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)) + ; + int count = *pcount; + if (count == 0) + *? (checkc)pc = c; + else + *? (checkc)pc = c + *? (checkc)pc; + atomic_xchg(pcount, (count + 1) % TZ); + atomic_xchg(plock, 0); #endif } \ No newline at end of file diff --git a/python/triton/ops/cross_entropy.c b/python/triton/ops/cross_entropy.c index 2767fae1a..2de793448 100644 --- a/python/triton/ops/cross_entropy.c +++ b/python/triton/ops/cross_entropy.c @@ -1,8 +1,4 @@ -__global__ void forward(TYPE *logit __aligned(16), - TYPE *modified_logit __aligned(16), - long *indices __readonly, - TYPE *result __aligned(16), - int n_cols __multipleof(N_COLS_MULT)) { +__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) { int row = get_program_id(0); bool check[TILE] = ((0 ... 
TILE) < n_cols); @@ -19,10 +15,7 @@ __global__ void forward(TYPE *logit __aligned(16), *(result + row) = *(modified_logit + (local_ind + n_cols * row)); } -__global__ void backward(TYPE *neg_logprobs __aligned(16), - long *indices __aligned(16), - TYPE *dneg_logprobs __aligned(16), - int n_cols __multipleof(N_COLS_MULT)) { +__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) { int row = get_program_id(0); // pointer arithmetic diff --git a/python/triton/ops/matmul.c b/python/triton/ops/matmul.c index 3410649c7..747340bea 100644 --- a/python/triton/ops/matmul.c +++ b/python/triton/ops/matmul.c @@ -1,16 +1,12 @@ #define STM 8 #define STN 8 -__global__ void matmul(TYPE *A __noalias __readonly __aligned(16), - TYPE *B __noalias __readonly __aligned(16), - TYPE *C __noalias __aligned(16), +__global__ void matmul(TYPE *A __noalias __readonly, + TYPE *B __noalias __readonly, + TYPE *C __noalias, float alpha, - int M, - int N, - int K __multipleof(16), - int lda __multipleof(LDA_POW2_DIV), - int ldb __multipleof(LDB_POW2_DIV), - int ldc __multipleof(LDC_POW2_DIV), + int M, int N, int K, + int lda, int ldb, int ldc, int *locks) { // prologue int pid = get_program_id(0); diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 6033ca780..a8925e34b 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -6,18 +6,18 @@ class _matmul(torch.autograd.Function): src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c")) _DEFAULT_CONFIGS = [ - ({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4), - ({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4), - ({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4), - ({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4), - ({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4), - ({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4), - ({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2), - ({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2), - # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4), - # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4), - # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4), - # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4), + triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4), + triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2), + triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2), + triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4), + triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4), ] _CONFIGS = _DEFAULT_CONFIGS diff --git a/python/triton/testing.py b/python/triton/testing.py index 0148eb6b9..510375b08 100644 --- a/python/triton/testing.py +++ 
b/python/triton/testing.py @@ -1,4 +1,5 @@ import torch +import os def sparsify_tensor(x, mask, block): ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) @@ -77,8 +78,12 @@ class Mark: df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv")) def run(self, result_path, with_plot): - for bench in self.benchmarks: - self._run(bench, result_path, with_plot) + with open(os.path.join(result_path, "results.html"), "w") as html: + html.write("\n") + for bench in self.benchmarks: + self._run(bench, result_path, with_plot) + html.write(f"\n") + html.write("\n") def perf_report(benchmarks): wrapper = lambda fn: Mark(fn, benchmarks) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py new file mode 100644 index 000000000..9163e4efd --- /dev/null +++ b/python/tutorials/01-vector-add.py @@ -0,0 +1,76 @@ +import torch +import triton + +# source-code for Triton compute kernel +# here we just copy-paste the above code without the extensive comments. +# you may prefer to store it in a .c file and load it from there instead. +_src = """ +__global__ void add(float* z, float* x, float* y, int N){ + // program id + int pid = get_program_id(0); + // create arrays of pointers + int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK; + float* pz[BLOCK] = z + offset; + float* px[BLOCK] = x + offset; + float* py[BLOCK] = y + offset; + // bounds checking + bool check[BLOCK] = offset < N; + // write-back + *?(check)pz = *?(check)px + *?(check)py; +} + """ +# This function returns a callable `triton.kernel` object +# created from the above source code. +# For portability, we maintain a cache of kernels for different `torch.device` +# We compile the kernel with -DBLOCK=1024 +_kernels = dict() + +def make_add_kernel(device): + if device not in _kernels: + defines = {'BLOCK': 1024} + autotune_vals = [({'BLOCK': '1024'}, 4), ({'BLOCK': '2048'}, 4)] + autotune_key = ["N"] + _kernels[device] = triton.kernel(_src, device=device, defines=defines, autotune_vals=autotune_vals, + autotune_key=autotune_key) + return _kernels[device] + +# This is a standard torch custom autograd Function +# The only difference is that we can now use the above kernel +# in the `forward` and `backward` functions.` +class _add(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + # constraints of the op + assert x.dtype == torch.float32 + # *allocate output* + z = torch.empty_like(x) + # *create launch grid*: + # this is a function which takes compilation parameters `opt` + # as input and returns a tuple of int (i.e., launch grid) for the kernel. + # triton.cdiv is a shortcut for ceil division: + # triton.cdiv(a, b) = (a + b - 1) // b + grid = lambda opt: (triton.cdiv(z.shape[0], opt.BLOCK), ) + # *launch kernel*: + # pointer to the data of torch tensors can be retrieved with + # the `.data_ptr()` method + kernel = make_add_kernel(z.device) + kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), z.shape[0], grid=grid) + return z + +# Just like we standard PyTorch ops +# We use the `.apply` method to create a +# callable object for our function +add = _add.apply + +torch.manual_seed(0) +x = torch.rand(32, device='cuda') +y = torch.rand(32, device='cuda') +za = x + y +zb = add(x, y) +print(za) +print(zb) +print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}') + +th_ms = triton.testing.do_bench(lambda: x + y) +tr_ms = triton.testing.do_bench(lambda: add(x, y)) +print(th_ms, tr_ms) \ No newline at end of file
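
As an illustration of the mechanism described in the commit message, here is a minimal Python sketch (not taken from the patch) of how the runtime specializes kernels on argument alignment: for every pointer and integer argument it computes the largest power-of-two divisor, capped at 16, and uses those values both to inject .aligned / .multipleof attributes into the Triton-IR and as part of the kernel-cache key. pow2_divisor mirrors the helper added in lib/runtime/function.cc; the names specialization_key, get_or_compile and compile_fn are hypothetical and stand in for the C++ logic in function::autotune.

# Minimal sketch of the auto-alignment specialization, assuming plain Python ints
# and pointer-like objects exposing .data_ptr() (e.g. torch tensors).

def pow2_divisor(n):
    # largest power-of-two divisor of n, capped at 16 (as in lib/runtime/function.cc)
    for d in (16, 8, 4, 2):
        if n % d == 0:
            return d
    return 1

def specialization_key(args):
    # pointer arguments contribute the alignment of their address (-> .aligned),
    # integer arguments contribute their power-of-two divisor (-> .multipleof)
    key = []
    for a in args:
        if hasattr(a, "data_ptr"):      # pointer-like argument
            key.append(pow2_divisor(a.data_ptr()))
        elif isinstance(a, int):
            key.append(pow2_divisor(a))
    return tuple(key)

# kernels are cached per (alignment key, autotune key), analogous to cache_ in function::autotune
_kernel_cache = {}

def get_or_compile(src, args, autotune_key, compile_fn):
    # compile_fn (hypothetical) would attach the .aligned / .multipleof attributes
    # derived from the alignment key before lowering src to a binary
    key = (specialization_key(args), autotune_key)
    if key not in _kernel_cache:
        _kernel_cache[key] = compile_fn(src, key)
    return _kernel_cache[key]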