diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 6db128198..0d9c625f1 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -10,6 +10,7 @@ #include #include "triton/ir/builder.h" #include "triton/ir/metadata.h" +#include "triton/ir/context.h" namespace triton{ @@ -60,7 +61,7 @@ private: void push_function(function *fn) { functions_.push_back(fn); } public: - module(const std::string &name, context &ctx); + module(const std::string &name); context& get_context(); builder& get_builder(); // Setters @@ -94,7 +95,7 @@ public: private: std::string name_; - context &context_; + context context_; builder builder_; std::map values_; std::map types_; diff --git a/include/triton/runtime/arg.h b/include/triton/runtime/arg.h index acee97ae8..7ba2f63d3 100644 --- a/include/triton/runtime/arg.h +++ b/include/triton/runtime/arg.h @@ -5,6 +5,7 @@ #include #include +#include namespace triton{ namespace ir{ @@ -17,73 +18,8 @@ namespace driver{ namespace runtime { -enum arg_type { - INT1_T, - INT8_T, - INT16_T, - INT32_T, - INT64_T, - HALF_T, - FLOAT_T, - DOUBLE_T, - BUFFER_T -}; - -arg_type convert(ir::type *ty); -inline size_t size_of(arg_type ty){ - switch(ty){ - case INT1_T: return 1; - case INT8_T: return 1; - case INT16_T: return 2; - case INT32_T: return 4; - case INT64_T: return 8; - case HALF_T: return 2; - case FLOAT_T: return 4; - case DOUBLE_T: return 8; - case BUFFER_T: return 8; - default: throw std::runtime_error("unknown type"); - } -} - -inline bool is_int_type(arg_type ty){ - return ty == INT1_T || ty == INT8_T || ty == INT16_T || - ty == INT32_T || ty == INT64_T; -} - -class arg { -public: - union value_t { - bool int1; - int8_t int8; - int16_t int16; - int32_t int32; - int64_t int64; - uint16_t fp16; - float fp32; - double fp64; - driver::buffer* buf; - }; - -public: - // construct from primitive types - arg(arg_type ty, value_t val): ty_(ty) { val_ = val; } - arg(int32_t x): ty_(INT32_T) { val_.int32 = x; } - arg(int64_t x): ty_(INT64_T) { val_.int64 = x; } - arg(float x): ty_(FLOAT_T) { val_.fp32 = x; } - arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; } - arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; } - // accessors - arg_type type() const { return ty_; } - void* data() const { return (void*)&val_; } - driver::buffer* buffer() const { return val_.buf; } - - -private: - arg_type ty_; - value_t val_; -}; } } diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 501bddd39..ddfe78776 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -11,6 +11,7 @@ #include #include // codegen +#include "triton/ir/function.h" #include "triton/ir/context.h" #include "triton/runtime/arg.h" #include "triton/runtime/error.h" @@ -37,63 +38,86 @@ class context; namespace triton{ namespace runtime{ -typedef std::vector grid_t; -typedef std::map params_t; -template inline T convert(const std::string& name); -template<> inline long convert(const std::string& name) { return std::stol(name); } -template<> inline int convert(const std::string& name) { return std::stoi(name); } + +/* ------------------------- */ +/* Compilation options */ +/* ------------------------- */ + +struct options_t { + template + T D(const std::string& name) const { + return std::stoi(defines.at(name)); + } + std::unordered_map defines; + int num_warps; +}; + +/* ------------------------- */ +/* Runtime arguments */ +/* ------------------------- */ + +enum arg_type { + INT1_T, + INT8_T, + INT16_T, + INT32_T, + 
INT64_T, + HALF_T, + FLOAT_T, + DOUBLE_T, + BUFFER_T +}; + +inline size_t size_of(arg_type ty){ + switch(ty){ + case INT1_T : return 1; + case INT8_T : return 1; + case INT16_T : return 2; + case INT32_T : return 4; + case INT64_T : return 8; + case HALF_T : return 2; + case FLOAT_T : return 4; + case DOUBLE_T: return 8; + case BUFFER_T: return 8; + default: throw std::runtime_error("unknown type"); + } +} template void add_arg(std::stringstream& ss, T arg) { ss.write((char*)&arg, sizeof(T)); } + +/* ------------------------- */ +/* ------------------------- */ + enum asm_mode_t { ASM_LLIR, ASM_NV_PTX, ASM_NV_SASS }; -struct options_t { - template - T D(const std::string& name) const { - return convert(defines.at(name)); - } - std::unordered_map defines; - int num_warps; -}; - - -/* ------------------------- */ - class kernel{ -private: - static std::string preheader(); - static arg_type convert(ir::type *ty); +public: + typedef std::vector grid_t; public: - kernel(const std::string& src, const options_t& opt, driver::device *device); - void operator()(void* args, size_t args_size, driver::stream *stream, const std::vector& grid) const; - // getters - const std::vector& get_sig() const { return sig_; } - const std::vector& get_arg_names() const { return arg_names_; } - std::string get_asm(asm_mode_t mode); + static std::shared_ptr src_to_ir(const std::string& src, const options_t& opt); + static std::tuple, + std::shared_ptr, + size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt); -private: - void init_ir (const std::string &src); - void init_ker(); - void init_sig(); +public: + kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map &attrs = {}); + void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const; + std::string get_asm(asm_mode_t mode); public: const options_t opt; private: driver::device* dev_; - // signature - std::vector sig_; - std::vector arg_names_; - // triton context for parsing - ir::context ctx_; // handles std::shared_ptr ir_; std::shared_ptr mod_; @@ -102,36 +126,37 @@ private: size_t shared_mem_; }; +struct config { + std::map defines; + int num_warps; +}; + class function { public: - typedef std::function grid_fn_ty; + typedef std::function grid_fn_ty; typedef std::pair> kernel_pair_t; typedef std::map, kernel*> cache_t; - typedef std::vector, int>> autotune_vals_t; + typedef std::vector autotune_confs_t; -private: - static void do_loop_nest(std::vector const & ranges, - std::function const &)> const & f); public: function(const std::string& src, const options_t& opt, driver::device *device, - const autotune_vals_t& autotune_vals = {}, const std::vector &autotune_key = {}); - void operator()(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); - void operator()(void* args, size_t args_size, const grid_t& grid, driver::stream *stream); - // auto-tuning - cache_t::iterator find_in_cache(void* args, size_t args_size); - kernel* autotune(void* args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); - // getters - const std::vector get_kernels() { return kernels_; } + const std::vector& tune_confs = {}, const std::vector &tune_key = {}); + kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream); + void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream); + const std::vector get_signature() { return sig_; } private: - void init_kernels(const std::string& src, const 
options_t& opt, const autotune_vals_t& autotune_vals, driver::device *device); - -private: - std::vector kernels_; + std::map, std::vector>> kernels_; std::map, kernel*> cache_; + std::vector sig_; + std::vector align_idxs_; + std::vector int_idxs_; std::vector key_idxs_; std::vector arg_size_; std::vector arg_off_; + std::vector opts_; + std::string src_; + driver::device* device_; }; } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 28edb4e4f..d31f329c0 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -10,8 +10,8 @@ namespace triton{ namespace ir{ /* Module */ -module::module(const std::string &name, context &ctx) - : name_(name), context_(ctx), builder_(ctx) { +module::module(const std::string &name) + : name_(name), builder_(context_) { sealed_blocks_.insert(nullptr); } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 5e2caa062..69799e557 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -40,7 +40,6 @@ #include #include -std::mutex mut; namespace triton{ namespace runtime { @@ -49,22 +48,9 @@ namespace runtime { /* --------------------------------- */ /* --------------------------------- */ -arg_type kernel::convert(ir::type *ty) { - if(ty->is_integer_ty(1)) return INT1_T; - if(ty->is_integer_ty(8)) return INT8_T; - if(ty->is_integer_ty(16)) return INT16_T; - if(ty->is_integer_ty(32)) return INT32_T; - if(ty->is_integer_ty(64)) return INT64_T; - if(ty->is_half_ty()) return HALF_T; - if(ty->is_float_ty()) return FLOAT_T; - if(ty->is_double_ty()) return DOUBLE_T; - if(ty->is_pointer_ty()) return BUFFER_T; - throw std::runtime_error("unknown type"); -} - - -std::string kernel::preheader() { - return R"( +std::shared_ptr kernel::src_to_ir(const std::string& _src, const options_t& opt) { + std::string src = +R"( #define bool _Bool #define true 1 #define false 0 @@ -116,9 +102,7 @@ typedef short int16; typedef int int32; typedef long int64; )"; -} - -void kernel::init_ir(const std::string& src) { + src += _src; // pre-process TokenSequence tokens; Preprocessor cpp(&src, true); @@ -129,21 +113,21 @@ void kernel::init_ir(const std::string& src) { Parser parser(tokens); parser.Parse(); // ast -> triton-ir - ir::module* module = new ir::module("", ctx_); + auto ret = std::make_shared(""); Generator gen(&parser); - gen.Gen(module); - ir_.reset(module); + gen.Gen(&*ret); + return ret; } -void kernel::init_ker(){ - // triton-ir -> binary - std::unique_ptr bin; - std::unique_ptr target = dev_->make_target(); +std::tuple, + std::shared_ptr, + size_t> kernel::ir_to_bin(ir::module &ir, driver::device* dev, const options_t& opt) { // generate llvm code llvm::LLVMContext ctx; - std::string name = ir_->get_function_list()[0]->get_name(); + std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); // optimizations + std::unique_ptr target = dev->make_target(); bool cts_use_async = target->as_nvidia()->sm() >= 80; // create passes codegen::analysis::align align; @@ -162,73 +146,61 @@ void kernel::init_ker(){ codegen::transform::coalesce coalesce(&align, &layouts); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps); // run passes - dce.run(*ir_); - pipeline.run(*ir_); - dce.run(*ir_); - disassociate.run(*ir_); - dce.run(*ir_); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - peephole.run(*ir_); - dce.run(*ir_); + dce.run(ir); + pipeline.run(ir); + dce.run(ir); + disassociate.run(ir); + dce.run(ir); + align.run(ir); + axes.run(ir); + 
layouts.run(ir); + peephole.run(ir); + dce.run(ir); if(target->is_gpu()) - cts.run(*ir_); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - coalesce.run(*ir_); - dce.run(*ir_); - align.run(*ir_); - dce.run(*ir_); + cts.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + coalesce.run(ir); + dce.run(ir); + align.run(ir); + dce.run(ir); if(target->is_gpu()){ - reassociate.run(*ir_); - cts.run(*ir_); + reassociate.run(ir); + cts.run(ir); } - dce.run(*ir_); -// ir::print(*ir_, std::cout); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - peephole.run(*ir_); - dce.run(*ir_); - align.run(*ir_); - axes.run(*ir_); - layouts.run(*ir_); - swizzle.run(*ir_); - liveness.run(*ir_); - allocation.run(*ir_); - shared_mem_ = allocation.allocated_size(); -// if(allocation.allocated_size() > dev_->max_shared_memory()) -// throw exception::out_of_shared_memory(); - barriers.run(*ir_); - isel.visit(*ir_, *llvm); - //if(res->spilled() > 256) - // throw exception::out_of_registers(); - mod_.reset(driver::module::create(dev_, std::move(llvm))); - ker_.reset(driver::kernel::create(&*mod_, name.c_str())); + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + peephole.run(ir); + dce.run(ir); + align.run(ir); + axes.run(ir); + layouts.run(ir); + swizzle.run(ir); + liveness.run(ir); + allocation.run(ir); + barriers.run(ir); + isel.visit(ir, *llvm); + std::shared_ptr mod(driver::module::create(dev, std::move(llvm))); + std::shared_ptr ker(driver::kernel::create(&*mod, name.c_str())); + size_t shared_mem = allocation.allocated_size(); + return std::make_tuple(mod, ker, shared_mem); } -void kernel::init_sig() { - ir::function* fn = ir_->get_function_list()[0]; - ir::function_type* ty = fn->get_fn_type(); - for(size_t i = 0; i < ty->get_num_params(); i++){ - sig_.push_back(convert(ty->get_param_ty(i))); - if(!fn->has_attr(i+1)) - continue; - } -} - -kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev): +kernel::kernel(const std::string& src, const options_t& opt, driver::device *dev, const std::map &attrs): opt(opt), dev_(dev) { - init_ir(preheader() + src); - init_ker(); - init_sig(); - for(auto arg: ir_->get_function_list()[0]->args()) - arg_names_.push_back(arg->get_name()); + // compile to Triton IR + ir_ = src_to_ir(src, opt); + // add attributes + for(const auto&x: attrs) + ir_->get_function_list()[0]->add_attr(x.first, x.second); + // compile to binary + std::tie(mod_, ker_, shared_mem_) = ir_to_bin(*ir_, dev, opt); } -void kernel::operator()(void *args, size_t args_size, driver::stream *stream, const std::vector& _grid) const{ +void kernel::operator()(const std::string& args, driver::stream *stream, const std::vector& _grid) const{ // set grid if(_grid.size() > 3) throw std::runtime_error("grid size must be no greater than 3"); @@ -236,7 +208,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co for(size_t i = 0; i < 3; i++) grid[i] = (i < _grid.size()) ? 
_grid[i] : 1; // enqueue - stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_); + stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, (void*)args.data(), args.size(), shared_mem_); } std::string kernel::get_asm(asm_mode_t mode) { @@ -282,124 +254,124 @@ std::string kernel::get_asm(asm_mode_t mode) { /* --------------------------------- */ /* --------------------------------- */ -void function::do_loop_nest(std::vector const & ranges, - std::function const &)> const & f){ - size_t D = ranges.size(); - std::vector values(D, 0); - size_t i = D - 1; - while(true){ - f(values); - while(values[i]++ == ranges[i] - 1){ - if(i == 0) - return; - values[i--] = 0; - } - i = D - 1; options_t opt; - - } -} -void function::init_kernels(const std::string& src, const options_t& opt, - const autotune_vals_t& confs, driver::device *device) { - // list of all possible configs - // just augment `opt` with each define of `confs` - // and override warp count - size_t num_opts = std::max(confs.size(), (size_t)1); - std::vector opts(num_opts, opt); - for(size_t i = 0; i < confs.size(); i++){ - opts[i].defines.insert(confs[i].first.begin(), confs[i].first.end()); - opts[i].num_warps = confs[i].second; - } - // compile all possible configs - // compilation errors (e.g., too much shared mem) - // will populate `err` - std::vector> err; - for(const options_t& opt: opts) { - try{ - kernels_.push_back({opt, std::make_shared(src, opt, device)}); - }catch(const exception::base& e){ - err.push_back({opt, e.what()}); - } - } - // throw an exception if `err` is not empty - if(kernels_.empty()){ - std::ostringstream dbg; - dbg << "Auto-Tuner could not find any valid configuration:" << std::endl; - for(auto x: err){ - dbg << "[ "; - dbg << x.first.num_warps << ", "; - dbg << "{ "; - for(const auto& y: x.first.defines) - dbg << '"' << y.first << "\"= \"" << y.second << "\", "; - dbg << " } ] -> " << x.second << std::endl; - } - throw exception::no_valid_configuration(dbg.str()); - } -} - -kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream* stream) { - // fast path -- no autotuning necessary - if(kernels_.size() == 1) - return &*kernels_.begin()->second; - // auto-tuning key - std::vector key(key_idxs_.size()); - for(size_t i = 0; i < key.size(); i++){ - int idx = key_idxs_[i]; - std::memcpy((void*)&key[i], (void*)((char*)args + arg_off_[idx]), arg_size_[idx]); - } - auto it = cache_.find(key); - if(it != cache_.end()) - return it->second; - // run auto-tuner - double best_ts = INFINITY; - kernel* ret = nullptr; - for(auto &x : kernels_){ - kernel* current = &*x.second; - auto grid = grid_fn(x.first); - while(grid.size() < 3) - grid.push_back(1); - double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); }, - stream, 5, 20); - ret = (ts < best_ts) ? 
current : ret; - best_ts = std::min(ts, best_ts); - } - stream->synchronize(); - it = cache_.insert({key, ret}).first; - return it->second; -} function::function(const std::string& src, const options_t &opt, driver::device *device, - const autotune_vals_t& autotune_vals, const std::vector& autotune_key) { - // pre-compile all kernels - init_kernels(src, opt, autotune_vals, device); - // find indices of autotune keys - auto arg_names = kernels_.at(0).second->get_arg_names(); - for(const std::string& name: autotune_key){ - auto it = std::find(arg_names.begin(), arg_names.end(), name); - if(it == arg_names.end()) - throw std::runtime_error(name + " is not a valid argument name"); - key_idxs_.push_back(std::distance(arg_names.begin(), it)); + const std::vector &tune_confs, const std::vector& tune_key) + : src_(src), device_(device) { + // kernel options + size_t num_opts = std::max(tune_confs.size(), (size_t)1); + opts_ = std::vector(num_opts, opt); + for(size_t i = 0; i < tune_confs.size(); i++){ + opts_[i].defines.insert(tune_confs[i].defines.begin(), tune_confs[i].defines.end()); + opts_[i].num_warps = tune_confs[i].num_warps; } + std::shared_ptr ir = kernel::src_to_ir(src, opts_[0]); + std::vector args = ir->get_function_list()[0]->args(); + // signature + auto convert = [](ir::type *ty) { + if(ty->is_integer_ty(1)) return INT1_T; + if(ty->is_integer_ty(8)) return INT8_T; + if(ty->is_integer_ty(16)) return INT16_T; + if(ty->is_integer_ty(32)) return INT32_T; + if(ty->is_integer_ty(64)) return INT64_T; + if(ty->is_half_ty()) return HALF_T; + if(ty->is_float_ty()) return FLOAT_T; + if(ty->is_double_ty()) return DOUBLE_T; + if(ty->is_pointer_ty()) return BUFFER_T; + throw std::runtime_error("unknown type"); + }; + for(ir::argument* arg: args) + sig_.push_back(convert(arg->get_type())); + // find indices of autotune keys + for(const std::string& name: tune_key){ + auto pred = [&](ir::argument* arg) { return arg->get_name() == name; }; + auto it = std::find_if(args.begin(), args.end(), pred); + if(it == args.end()) + throw std::runtime_error(name + " is not a valid argument name"); + key_idxs_.push_back(std::distance(args.begin(), it)); + } + // find indices of pointer + for(size_t i = 0; i < args.size(); i++) + if(args[i]->get_type()->is_pointer_ty() || + args[i]->get_type()->is_integer_ty()) + align_idxs_.push_back(i); // argument size and offset - auto tys = kernels_.at(0).second->get_sig(); size_t curr = 0; - for(arg_type ty: tys){ + for(arg_type ty: sig_){ arg_size_.push_back(size_of(ty)); arg_off_.push_back(curr); curr += arg_size_.back(); } - - } -void function::operator()(void* args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) { - runtime::kernel* fn = autotune(args, args_size, grid_fn, stream); - (*fn)(args, args_size, stream, grid_fn(fn->opt)); +uint64_t pow2_divisor(uint64_t N){ + if(N % 16 == 0) return 16; + if(N % 8 == 0) return 8; + if(N % 4 == 0) return 4; + if(N % 2 == 0) return 2; + return 1; } -void function::operator()(void* args, size_t args_size, const grid_t& grid, driver::stream* stream) { - return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream); +kernel* function::autotune(const std::string &args, const grid_fn_ty& grid_fn, driver::stream* stream) { + // align key + std::vector rt_key(align_idxs_.size(), 0); + for(size_t i = 0; i < align_idxs_.size(); i++){ + int idx = align_idxs_[i]; + uint64_t tmp = 0; + std::memcpy((void*)&tmp, (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]); + rt_key[i] = 
pow2_divisor(tmp); + } + // auto-tuning key + std::vector at_key(key_idxs_.size(), 0); + for(size_t i = 0; i < at_key.size(); i++){ + int idx = key_idxs_[i]; + std::memcpy((void*)&at_key[i], (void*)((char*)args.data() + arg_off_[idx]), arg_size_[idx]); + } + // cache key + std::vector cache_key; + cache_key.reserve(rt_key.size() + at_key.size()); + cache_key.insert(cache_key.end(), rt_key.begin(), rt_key.end()); + cache_key.insert(cache_key.end(), at_key.begin(), at_key.end()); + auto it = cache_.find(cache_key); + if(it != cache_.end()) + return it->second; + // compile kernels + if(kernels_.find(rt_key) == kernels_.end()){ + std::map attrs; + for(size_t i = 0; i < align_idxs_.size(); i++){ + bool is_ptr = sig_[align_idxs_[i]] == BUFFER_T; + attrs.insert({align_idxs_[i] + 1, ir::attribute(is_ptr ? ir::aligned : ir::multiple_of, rt_key[i])}); + } + for(const options_t& opt: opts_) + kernels_[rt_key].emplace_back(new kernel(src_, opt, device_, attrs)); + } + // run auto-tuner + double best_ts = INFINITY; + auto& kernels = kernels_.at(rt_key); + kernel* ret = nullptr; + if(kernels.size() == 1) + ret = &*kernels.back(); + else{ + for(auto ¤t : kernels_.at(rt_key)){ + auto grid = grid_fn(current->opt); + while(grid.size() < 3) + grid.push_back(1); + double ts = tools::bench([&]() { (*current)(args, stream, grid); }, + stream, 5, 20); + ret = (ts < best_ts) ? &*current : ret; + best_ts = std::min(ts, best_ts); + } + stream->synchronize(); + } + it = cache_.insert({cache_key, ret}).first; + return it->second; +} + +void function::operator()(const std::string& args, const grid_fn_ty& grid_fn, driver::stream *stream) { + runtime::kernel* fn = autotune(args, grid_fn, stream); + (*fn)(args, stream, grid_fn(fn->opt)); } diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index 25da0f68e..b79030c40 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -1,12 +1,19 @@ import triton import torch +import os -# square benchmarks +def rounded_linspace(low, high, steps, div): + ret = torch.linspace(low, high, steps) + ret = (ret.int() + div - 1) // div * div + ret = torch.unique(ret) + return list(map(int, ret)) + +# Square benchmarks nt = {False: "n", True: "t"} square_confs = [ triton.testing.Benchmark( x_names=["M", "N", "K"], - x_vals=[512 * i for i in range(1, 16)], + x_vals=rounded_linspace(512, 8192, 17, 128), y_name="provider", y_vals=["torch", "triton", "cutlass"], y_lines=["Torch", "Triton", "CUTLASS"], @@ -17,16 +24,29 @@ square_confs = [ ) for AT in [False, True] for BT in [False, True] ] -@triton.testing.perf_report(square_confs) -def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): - import os +# Transformer training benchmarks +transformer_confs = [ + triton.testing.Benchmark( + x_names=[x], + x_vals = rounded_linspace(NK//16, NK, 33, 128), + y_name="provider", + y_vals=["torch", "triton", "cutlass"], + y_lines=["Torch", "Triton", "CUTLASS"], + ylabel="TFLOPS", + loglog=False, + plot_name=f"matmul-M{M}-{'NK'.replace(x, '')}{NK}", + args= {"M": M, 'NK'.replace(x,''): NK, "AT": False, "BT": False, "dtype": torch.float16} + ) for NK in [8192]\ + for i, x in enumerate(["N", "K"])\ + for M in [2048] +] +@triton.testing.perf_report(square_confs) +def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40): a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) - if AT: - a = a.t() - if BT: - b = b.t() + if AT: a = a.t() + if BT: b = b.t() num_flops = 2 * 
M * N * K if provider == "torch": torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep) @@ -40,7 +60,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): import subprocess import tempfile import pandas as pd - # run program specified by CUTLASS_PROFILER env variable layout_a = "column" if AT else "row" layout_b = "column" if BT else "row" @@ -61,6 +80,7 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): f"--warmup-iterations={warmup}", f"--profiling-iterations={rep}", f"--output={fname}", + "--dist=uniform,min:0,max:1,scale:-1", "--verbose=false", ] # run cmd @@ -70,6 +90,3 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): cutlass_tflops = max(df_c["GFLOPs"]) / 1e3 return cutlass_tflops return None - -if __name__ == "__main__": - bench_op.run() diff --git a/python/bench/run.py b/python/bench/run.py index 17784947a..df52a07c4 100644 --- a/python/bench/run.py +++ b/python/bench/run.py @@ -38,4 +38,4 @@ def main(args): run_all(args.result_dir, args.with_plots, args.names) if __name__ == '__main__': - main(sys.argv[1:]) \ No newline at end of file + main(sys.argv[1:]) diff --git a/python/src/torch/utils.cc b/python/src/torch/utils.cc index 8deb8d4f2..c7bf74105 100644 --- a/python/src/torch/utils.cc +++ b/python/src/torch/utils.cc @@ -5,38 +5,16 @@ #include #include -std::map> tt_devices; -std::map> tt_streams; - namespace torch_utils { -void register_device(int64_t dev_id) { - if (tt_devices.find(dev_id) != tt_devices.end()) - return; - triton::driver::device *device; - if (dev_id >= 0) { - CUdevice handle; - triton::driver::dispatch::cuDeviceGet(&handle, dev_id); - device = new triton::driver::cu_device(handle, false); - } else - device = new triton::driver::host_device(); - tt_devices[dev_id].reset(device); +uint64_t cu_device(int64_t dev_id) { + CUdevice handle; + triton::driver::dispatch::cuDeviceGet(&handle, dev_id); + return (uint64_t)handle; } -void register_stream(int64_t dev_id) { - if (tt_streams.find(dev_id) != tt_streams.end()) - return; - triton::driver::stream *stream; - if (dev_id >= 0) { - CUstream handle = (CUstream)c10::cuda::getCurrentCUDAStream(dev_id).stream(); - stream = new triton::driver::cu_stream(handle, false); - } else - stream = new triton::driver::host_stream(); - tt_streams[dev_id].reset(stream); -} - -void synchronize(int64_t dev_id) { - tt_streams[dev_id]->synchronize(); +uint64_t cu_stream(int64_t dev_id) { + return (uint64_t)c10::cuda::getCurrentCUDAStream(dev_id).stream(); } void set_device(int64_t dev_id) { @@ -44,23 +22,11 @@ void set_device(int64_t dev_id) { C10_CUDA_CHECK(cudaSetDevice(dev_id)); } -torch::Tensor move_out_of_pool(torch::Tensor x) { - if (x.nbytes() == 0) - return torch::empty_like(x); - void *data; - cudaMalloc(&data, x.nbytes()); - auto ret = torch::from_blob((void *)data, x.sizes(), x.strides(), [data](void *ptr) { cudaFree(data); }, x.options()); - ret.copy_(x); - return ret; -} - } // namespace torch_utils void init_torch_utils(pybind11::module &m) { pybind11::module subm = m.def_submodule("torch_utils"); - subm.def("register_device", &torch_utils::register_device); - subm.def("register_stream", &torch_utils::register_stream); + subm.def("cu_device", &torch_utils::cu_device); + subm.def("cu_stream", &torch_utils::cu_stream); subm.def("set_device", &torch_utils::set_device); - subm.def("synchronize", &torch_utils::synchronize); - subm.def("move_out_of_pool", &torch_utils::move_out_of_pool); } \ No newline at end of file diff --git 
a/python/src/triton.cc b/python/src/triton.cc index c208e37a7..6eea1717e 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,10 +1,4 @@ #include "triton/driver/stream.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/lang/code_gen.h" -#include "triton/lang/cpp.h" -#include "triton/lang/parser.h" -#include "triton/runtime/arg.h" #include "triton/runtime/function.h" #include #include @@ -13,72 +7,22 @@ #include #include +namespace py = pybind11; + using namespace triton; namespace rt = triton::runtime; namespace drv = triton::driver; -namespace lng = triton::lang; -std::unordered_map opt_cache_; -std::map> id_fn_map; -extern std::map> tt_devices; -extern std::map> tt_streams; +/*****************************************************************************/ +/* Python bindings for triton::tools */ +/*****************************************************************************/ -/* Function utilities */ - -void register_fn(int op_id, int dev_id, - const std::string &src, const rt::options_t &opt, - const rt::function::autotune_vals_t &autotune_vals, - const std::vector &autotune_key) { - if (id_fn_map.find(op_id) == id_fn_map.end()) { - id_fn_map[op_id].reset(new rt::function(src, opt, &*tt_devices[dev_id], autotune_vals, autotune_key)); - } - for (const auto &k : id_fn_map[op_id]->get_kernels()) { - const rt::options_t *opt = &k.first; - pybind11::object obj = pybind11::cast(opt, pybind11::return_value_policy::reference); - for (auto x : opt->defines) - if (std::all_of(x.second.begin(), x.second.end(), ::isdigit)) - obj.attr(x.first.c_str()) = std::stoi(x.second); - opt_cache_[&k.second->opt] = obj; - } -} - -void delete_fn(int op_id) { - id_fn_map.erase(op_id); -} - -void cleanup() { - id_fn_map.clear(); - opt_cache_.clear(); -} - -size_t make_op_id() { - return id_fn_map.size(); -} - -std::vector get_fn_signature(size_t op_id) { - return id_fn_map[op_id]->get_kernels()[0].second->get_sig(); -} - -// Thanks to Scott Gray (OpenAI) for the idea to pass the arguments -// as a string constructed with struct.pack in python -void launch_kernel(int64_t op_id, int64_t dev_id, const std::string &args, size_t grid_0, size_t grid_1, size_t grid_2) { - rt::function *fn = id_fn_map.at(op_id).get(); - (*fn)((void **)args.c_str(), args.size(), {grid_0, grid_1, grid_2}, &*tt_streams[dev_id]); -} - -pybind11::object autotune(int64_t op_id, int64_t dev_id, const std::string &args, const rt::function::grid_fn_ty &grid) { - rt::function *fn = id_fn_map.at(op_id).get(); - auto wrapper = [&grid](const rt::options_t &opt) { - pybind11::object obj = pybind11::cast(&opt, pybind11::return_value_policy::reference); - for (auto x : opt.defines) - if (std::all_of(x.second.begin(), x.second.end(), ::isdigit)) - obj.attr(x.first.c_str()) = std::stoi(x.second); - return grid(*obj.cast()); - }; - rt::kernel *kernel = fn->autotune((void **)args.c_str(), args.size(), wrapper, &*tt_streams[dev_id]); - return opt_cache_.at(&kernel->opt); -} +/*! 
+ @brief Function for extracting kernels out of a given source-string + This can be important to enable pre-processor macros (or tunable parameters) that should only + be defined within the scope of a single kernel function +*/ std::string extract_kernels(const std::string &str, const std::vector &names) { if (names.empty()) return str; @@ -94,50 +38,82 @@ std::string extract_kernels(const std::string &str, const std::vectorstr(1); kernels.push_back(std::make_tuple(name, pos, len)); } - + // check that all the kernels provided actually exist for (const std::string &name : names) { - // check that str matches any string in kernels using std::any_of auto pred = [&name](const std::tuple &t) { return std::get<0>(t) == name; }; bool found = std::any_of(kernels.begin(), kernels.end(), pred); if (!found) throw std::runtime_error("Unable to find kernel `" + name + "` in provided source code:\n" + str); } - - // extract functions + // simple parsing logic to extract the declaration and body of each specified kernel std::string ret; for (const auto &k : kernels) { std::string name; int pos, len; std::tie(name, pos, len) = k; - if (std::find(names.begin(), names.end(), name) != names.end()) { - std::string def = str.substr(pos, str.size() - pos); - int count, pos; - // skip over declaration - count = 1; - pos = def.find('('); - while (!(def[pos++] == ')' && count == 0) && pos < def.size()) { - count += def[pos] == '('; - count -= def[pos] == ')'; - } - // skip over definition - count = 1; - pos = def.find('{', pos); - while (!(def[pos++] == '}' && count == 0) && pos < def.size()) { - count += def[pos] == '{'; - count -= def[pos] == '}'; - } - ret += def.substr(0, pos); - ret += '\n'; + if (std::find(names.begin(), names.end(), name) == names.end()) + continue; + std::string def = str.substr(pos, str.size() - pos); + // skip over declaration + // by finding matching ')' for first '(' + int count = 1; + pos = def.find('('); + while (!(def[pos++] == ')' && count == 0) && pos < def.size()) { + count += def[pos] == '('; + count -= def[pos] == ')'; } + // skip over definition + // by finding matching '{' for first '}' + count = 1; + pos = def.find('{', pos); + while (!(def[pos++] == '}' && count == 0) && pos < def.size()) { + count += def[pos] == '{'; + count -= def[pos] == '}'; + } + ret += def.substr(0, pos); + ret += '\n'; } - return ret; } -void init_triton(pybind11::module &m) { - pybind11::module subm = m.def_submodule("triton"); - // bindings for triton classes - pybind11::enum_(subm, "arg_type") +void init_triton_tools(py::module &&m) { + m.def("extract_kernels", &extract_kernels); +} + +/*****************************************************************************/ +/* Python bindings for triton::driver */ +/*****************************************************************************/ + +void init_triton_driver(py::module &&m) { + // base device + py::class_(m, "device"); + // cuda device + py::class_(m, "cu_device") + .def(py::init()); + // host device + py::class_(m, "host_device") + .def(py::init<>()); + + // base stream + py::class_(m, "stream"); + // host stream + py::class_(m, "host_stream") + .def(py::init<>()); + // cuda stream + py::class_(m, "cu_stream") + // py doesn't support opaque pointer (e.g., CUstream) so + // we assume it has been converted to uint64_t + .def(py::init([](uint64_t handle, bool take_ownership) { + return std::unique_ptr(new driver::cu_stream((CUstream)handle, take_ownership)); + })); +} + 
+/*****************************************************************************/ +/* Python bindings for triton::runtime */ +/*****************************************************************************/ +void init_triton_runtime(py::module &&m) { + // argument type + py::enum_(m, "arg_type") .value("int1", rt::INT1_T) .value("int8", rt::INT8_T) .value("int16", rt::INT16_T) @@ -147,23 +123,38 @@ void init_triton(pybind11::module &m) { .value("float", rt::FLOAT_T) .value("double", rt::DOUBLE_T) .value("buffer", rt::BUFFER_T); - - pybind11::enum_(subm, "asm_mode") + // assembly mode + py::enum_(m, "asm_mode") .value("ptx", rt::ASM_NV_PTX) .value("sass", rt::ASM_NV_SASS); - - pybind11::class_(subm, "options", pybind11::dynamic_attr()) - .def(pybind11::init<>()) + // compilation options + py::class_(m, "options", py::dynamic_attr()) + .def(py::init<>()) .def_readwrite("defines", &rt::options_t::defines) - .def_readwrite("num_warps", &rt::options_t::num_warps); + .def_readwrite("num_warps", &rt::options_t::num_warps) + .def("__getattr__", [](rt::options_t *opt, const std::string &name) { + return opt->D(name); + }); + // kernel + py::class_(m, "kernel") + .def("__call__", &rt::kernel::operator()) + .def_readonly("opt", &rt::kernel::opt); + // tune conf + py::class_(m, "config") + .def(py::init, int>(), + py::arg("defines") = std::map(), + py::arg("num_warps")); - // hooks into triton constructs since frameworks may not use pybind11 - subm.def("extract_kernels", &extract_kernels); - subm.def("get_fn_signature", &get_fn_signature); - subm.def("register_fn", ®ister_fn); - subm.def("delete_fn", &delete_fn); - subm.def("make_op_id", &make_op_id); - subm.def("cleanup", &cleanup); - subm.def("autotune", &autotune, pybind11::return_value_policy::reference); - subm.def("launch_kernel", &launch_kernel); + // function + py::class_(m, "function") + .def(py::init &, const std::vector &>()) + .def("autotune", &rt::function::autotune, py::return_value_policy::reference_internal) + .def("signature", &rt::function::get_signature); +} + +void init_triton(py::module &m) { + py::module subm = m.def_submodule("triton"); + init_triton_driver(std::move(subm.def_submodule("driver"))); + init_triton_runtime(std::move(subm.def_submodule("runtime"))); + init_triton_tools(std::move(subm.def_submodule("tools"))); } diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index aaf738acc..a621d347d 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -50,8 +50,9 @@ import torch def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE): DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] torch.manual_seed(0) + defines = {"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)} triton.ops._matmul._kernels = dict() - triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)] + triton.ops._matmul._CONFIGS = [triton.config(defines=defines, num_warps=NWARP)] if M is None: M = TM if N is None: diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 206a2dd9f..1f83d2f9c 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -1,27 +1,21 @@ import os import struct from typing import Optional, Dict, List - import torch # C bindings import triton._C.libtriton.triton as _triton import triton._C.libtriton.torch_utils as _torch_utils -# Make sure internal C resources are cleaned up upon exit -import atexit - -@atexit.register -def cleanup(): - _triton.cleanup() codes = { - 
_triton.arg_type.int1: 'B', _triton.arg_type.int8: 'B', _triton.arg_type.int32: 'I', _triton.arg_type.int64: 'Q', - _triton.arg_type.half: 'H', _triton.arg_type.float: 'f', _triton.arg_type.double: 'd', _triton.arg_type.buffer: 'P' + _triton.runtime.arg_type.int1: 'B', _triton.runtime.arg_type.int8: 'B', _triton.runtime.arg_type.int32: 'I', + _triton.runtime.arg_type.int64: 'Q', _triton.runtime.arg_type.half: 'H', _triton.runtime.arg_type.float: 'f', + _triton.runtime.arg_type.double: 'd', _triton.runtime.arg_type.buffer: 'P' } def th_to_triton(obj): tys = { - torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long', torch.float16: 'half', - torch.float32: 'float', torch.float64: 'double' + torch.int8: 'char', torch.int16: 'short', torch.int32: 'int', torch.int64: 'long',\ + torch.float16: 'half', torch.float32: 'float', torch.float64: 'double' } if isinstance(obj, torch.dtype): return tys[obj] @@ -30,69 +24,54 @@ def th_to_triton(obj): def cdiv(a, b): return (a + b - 1) // b -def synchronize(device): - dev_id = device.index - dev_id = -1 if dev_id is None else dev_id - _torch_utils.synchronize(dev_id) - -def read(path, kernel_names:Optional[List]=None): +def read(path, kernel_names: Optional[List] = None): if kernel_names is None: kernel_names = [] with open(path, 'r') as f: source = f.read() - source = _triton.extract_kernels(source, kernel_names) + source = _triton.tools.extract_kernels(source, kernel_names) return source -class kernel: - def __init__(self, - src, - device, - defines: Optional[Dict]=None, - num_warps:int=4, - autotune_vals:Optional[List]=None, - autotune_key:Optional[List]=None): +config = _triton.runtime.config +class kernel: + def __init__(self, src, device, defines: Optional[Dict] = None, num_warps: int = 4, + autotune_vals: Optional[List] = None, autotune_key: Optional[List] = None): if defines is None: defines = {} if autotune_vals is None: autotune_vals = [] if autotune_key is None: autotune_key = [] - - # check if src is empty if src == '': raise ValueError('Kernel source code is empty') self.src = src - self.opt = _triton.options() - self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()} - self.opt.num_warps = num_warps # device assert device.type in ['cuda', 'cpu'] if device.type == 'cuda': - self.device = torch.cuda.current_device() if device.index is None else device.index + self.device_id = torch.cuda.current_device() if device.index is None else device.index + self.device = _triton.driver.cu_device(_torch_utils.cu_device(self.device_id), False) + self.stream = _triton.driver.cu_stream(_torch_utils.cu_stream(self.device_id), False) if device.type == 'cpu': - self.device = -1 - _torch_utils.register_device(self.device) - _torch_utils.register_stream(self.device) - # C++ function wrapper - self.op_id = _triton.make_op_id() - _torch_utils.set_device(self.device) - _triton.register_fn(self.op_id, self.device, self.src, self.opt, autotune_vals, autotune_key) - # debug mode - self.is_debug = 'TRITON_DEBUG' in os.environ - # signature - arg_types = _triton.get_fn_signature(self.op_id) - self.tys = ''.join([codes[x] for x in arg_types]) + self.device_id = -1 + self.device = _triton.driver.host_device() + self.device = _triton.driver.host_stream() + _torch_utils.set_device(self.device_id) + # function + self.opt = _triton.runtime.options() + self.opt.defines = {k: th_to_triton(v) for k, v in defines.items()} + self.opt.num_warps = num_warps + # autotune_vals = [({}, 4)] + self.fn = _triton.runtime.function(self.src, self.opt, 
self.device, autotune_vals, autotune_key) + self.tys = ''.join([codes[x] for x in self.fn.signature()]) def __call__(self, *args, grid): - _torch_utils.set_device(self.device) + # make sure that the executing thread is on the right device + _torch_utils.set_device(self.device_id) # pack parameters into a byte buffer params = struct.pack(self.tys, *args) - opt = _triton.autotune(self.op_id, self.device, params, grid) + kernel = self.fn.autotune(params, grid, self.stream) # run kernel - grid = grid(opt) - grid_0 = grid[0] - grid_1 = 1 if len(grid) < 2 else grid[1] - grid_2 = 1 if len(grid) < 3 else grid[2] - _triton.launch_kernel(self.op_id, self.device, params, grid_0, grid_1, grid_2) + grid = grid(kernel.opt) + kernel(params, self.stream, grid) diff --git a/python/triton/ops/blocksparse/matmul.c b/python/triton/ops/blocksparse/matmul.c index a45d072b3..e3522ec29 100644 --- a/python/triton/ops/blocksparse/matmul.c +++ b/python/triton/ops/blocksparse/matmul.c @@ -1,17 +1,17 @@ -__global__ void NAME(TYPE *A __readonly __noalias __aligned(16), - TYPE *B __readonly __noalias __aligned(16), - TYPE *C __noalias __aligned(16), - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc __multipleof(8), - long stride_za __multipleof(8), - long stride_zb __multipleof(8), - long stride_zc __multipleof(8), - long stride_ha __multipleof(8), - long stride_hb __multipleof(8), - long stride_hc __multipleof(8), +__global__ void NAME(TYPE *A __readonly __noalias, + TYPE *B __readonly __noalias, + TYPE *C __noalias, + int lda, + int ldb, + int ldc, + long stride_za, + long stride_zb, + long stride_zc, + long stride_ha, + long stride_hb, + long stride_hc, int DS0, int DS1, - int SDD_K __multipleof(16), + int SDD_K, int SDD_off_width, int *lut, int *locks, int nlocks) { /* ---------------- */ diff --git a/python/triton/ops/blocksparse/softmax.c b/python/triton/ops/blocksparse/softmax.c index 625f4a6ac..8b8c9506a 100644 --- a/python/triton/ops/blocksparse/softmax.c +++ b/python/triton/ops/blocksparse/softmax.c @@ -1,17 +1,16 @@ -__global__ void forward(TYPE *X __readonly __noalias __aligned(16), +__global__ void forward(TYPE *X __readonly __noalias, float scale, - int *LUT __readonly __noalias __aligned(16), - TYPE *RPE __readonly __noalias __aligned(16), - TYPE *KP_M __readonly __noalias __aligned(16), - TYPE *ATTN_M __readonly __noalias __aligned(16), + int *LUT __readonly __noalias, + TYPE *RPE __readonly __noalias, + TYPE *KP_M __readonly __noalias, + TYPE *ATTN_M __readonly __noalias, int sizemax, - long stride_zx __multipleof(4), - long stride_zrpe __multipleof(BLOCK), - int stride_hrpe __multipleof(BLOCK), - int stride_srpe __multipleof(BLOCK), - int stride_zkpm __multipleof(BLOCK), - int stride_zattnm __multipleof(BLOCK)) -{ + long stride_zx, + long stride_zrpe, + int stride_hrpe, + int stride_srpe, + int stride_zkpm, + int stride_zattnm) { int pidhm = get_program_id(0); int pidz = get_program_id(1); // create index ranges @@ -97,14 +96,13 @@ __global__ void forward(TYPE *X __readonly __noalias __aligned(16), *? 
(check)px = y / ysum; } -__global__ void backward(TYPE *X __readonly __noalias __aligned(16), +__global__ void backward(TYPE *X __readonly __noalias, float scale, - TYPE *DX __readonly __noalias __aligned(16), + TYPE *DX __readonly __noalias, int *LUT, int sizemax, - long stride_zx __multipleof(BLOCK), - long stride_zdx __multipleof(BLOCK)) -{ + long stride_zx, + long stride_zdx) { int pidhm = get_program_id(0); int pidz = get_program_id(1); // create index ranges diff --git a/python/triton/ops/conv.c b/python/triton/ops/conv.c index d115ff540..f2c9e899a 100644 --- a/python/triton/ops/conv.c +++ b/python/triton/ops/conv.c @@ -1,126 +1,131 @@ -__global__ void conv(TYPE *A __noalias __readonly __aligned(16), - TYPE *B __noalias __readonly __aligned(16), - TYPE *C __noalias __aligned(16), - float alpha, - // equivalent matmul - int M, int N, int K, - // convolution properties - int pad_h, int pad_w, int stride_h, int stride_w, - // pointer increment - int *ADELTA, - // memory strides - int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8), - int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8), - int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) { - // prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int ridz = get_program_id(2); - int gridx = M / TM; - int gridy = N / TN; - int rid = ridx + ridy * gridx; - ridx = rid / gridy; - ridy = rid % gridy; - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... TN; - // reduction splitting - K = K / TZ; - int rk[TK] = ridz * K + 0 ... TK; +__global__ void conv(TYPE *A __noalias __readonly, + TYPE *B __noalias __readonly, + TYPE *C __noalias, + float alpha, + // equivalent matmul + int M, int N, int K, + // convolution properties + int pad_h, int pad_w, int stride_h, int stride_w, + // pointer increment + int *ADELTA, + // memory strides + int lda_z, int lda_ci, int lda_h, int lda_w, + int ldb_ci, int ldb_r, int ldb_s, int ldb_co, + int ldc_z, int ldc_co, int ldc_p, int ldc_q) { + // prologue + int ridx = get_program_id(0); + int ridy = get_program_id(1); + int ridz = get_program_id(2); + int gridx = M / TM; + int gridy = N / TN; + int rid = ridx + ridy * gridx; + ridx = rid / gridy; + ridy = rid % gridy; + int rm[TM] = ridx * TM + 0 ... TM; + int rn[TN] = ridy * TN + 0 ... TN; + // reduction splitting + K = K / TZ; + int rk[TK] = ridz * K + 0 ... 
TK; - // unpack aggregate rows - // m = (z, p, q) - int rq[TM] = rm % QQ; - int rzp[TM] = rm / QQ; - int rp[TM] = rzp % PP; - int rz[TM] = rzp / PP; - // unpack aggregate reduction - // k = (ci, r, s) - int rs [TK] = rk % SS; - int rcir[TK] = rk / SS; - int rr [TK] = rcir % RR; - int rci [TK] = rcir / RR; + // unpack aggregate rows + // m = (z, p, q) + int rq[TM] = rm % QQ; + int rzp[TM] = rm / QQ; + int rp[TM] = rzp % PP; + int rz[TM] = rzp / PP; + // unpack aggregate reduction + // k = (ci, r, s) + int rs[TK] = rk % SS; + int rcir[TK] = rk / SS; + int rr[TK] = rcir % RR; + int rci[TK] = rcir / RR; - // padding / striding - int rh_0[TM] = rp * stride_h - pad_h; - int rw_0[TM] = rq * stride_w - pad_w; - int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :]; - int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :]; + // padding / striding + int rh_0[TM] = rp * stride_h - pad_h; + int rw_0[TM] = rq * stride_w - pad_w; + int rh[TM, TK] = rh_0[:, newaxis] + rr [newaxis, :]; + int rw[TM, TK] = rw_0[:, newaxis] + rs [newaxis, :]; - // pointers to lhs - int offa[TM, TK] = rz [:, newaxis] * lda_z + - rci[newaxis, :] * lda_ci + - rh * lda_h + - rw * 1; - TYPE* pa[TM, TK] = A + offa; - int* padelta[TK] = ADELTA + rk; - // pointers to rhs - int offb[TK, TN] = rci[:, newaxis] * ldb_ci + - rr [:, newaxis] * ldb_r + - rs [:, newaxis] * ldb_s + - rn [newaxis, :] * 1; - TYPE* pb[TK, TN] = B + offb; + // pointers to lhs + int offa[TM, TK] = rz[:, newaxis] * lda_z + + rci [newaxis, :] * lda_ci + + rh * lda_h + + rw * 1; + TYPE *pa[TM, TK] = A + offa; + int *padelta[TK] = ADELTA + rk; + // pointers to rhs + int offb[TK, TN] = rci[:, newaxis] * ldb_ci + + rr + [:, newaxis] * ldb_r + + rs + [:, newaxis] * ldb_s + + rn [newaxis, :] * 1; + TYPE *pb[TK, TN] = B + offb; - // prefetches operands - bool checkam[TM, TK] = rm[:, newaxis] < M; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - bool checkb[TK, TN] = rk[:, newaxis] < K; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - int total = 0; + // prefetches operands + bool checkam[TM, TK] = rm[:, newaxis] < M; + bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; + bool checkb[TK, TN] = rk[:, newaxis] < K; + TYPE a[TM, TK] = checka ? *pa : 0; + TYPE b[TK, TN] = checkb ? *pb : 0; + int total = 0; - // reduction loop - float acc[TM, TN] = 0; - for(int k = K; k > 0; k -= TK){ - acc += a @ b; - // increment A - int adelta[TK] = *padelta; - padelta += TK; - pa += adelta[newaxis, :]; - // bounds-checking A - rk += TK; - rs = rk % SS; - rcir = rk / SS; - rr = rcir % RR; - rh = rh_0[:, newaxis] + rr[newaxis, :]; - rw = rw_0[:, newaxis] + rs[newaxis, :]; - bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; - // increment B - pb += TK * ldb_s; - // bounds-checking B - bool checkb[TK, TN] = k > TK; - a = checka ? *pa : 0; - b = *?(checkb)pb; - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; + // reduction loop + float acc[TM, TN] = 0; + for (int k = K; k > 0; k -= TK) { + acc += a @b; + // increment A + int adelta[TK] = *padelta; + padelta += TK; + pa += adelta [newaxis, :]; + // bounds-checking A + rk += TK; + rs = rk % SS; + rcir = rk / SS; + rr = rcir % RR; + rh = rh_0[:, newaxis] + rr [newaxis, :]; + rw = rw_0[:, newaxis] + rs [newaxis, :]; + bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW; + // increment B + pb += TK * ldb_s; + // bounds-checking B + bool checkb[TK, TN] = k > TK; + a = checka ? *pa : 0; + b = *? 
(checkb)pb; + } + acc = acc * alpha; + TYPE c[TM, TN] = acc; - // epilogue - rm = ridx * TM + 0 ... TM; - rn = ridy * TN + 0 ... TN; - rq = rm % QQ; - rzp = rm / QQ; - rp = rzp % PP; - rz = rzp / PP; - int offc[TM, TN] = rz [:, newaxis] * ldc_z + - rn [newaxis, :] * ldc_co+ - rp [:, newaxis] * ldc_p + - rq [:, newaxis] * 1; - TYPE* pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N; + // epilogue + rm = ridx * TM + 0 ... TM; + rn = ridy * TN + 0 ... TN; + rq = rm % QQ; + rzp = rm / QQ; + rp = rzp % PP; + rz = rzp / PP; + int offc[TM, TN] = rz[:, newaxis] * ldc_z + + rn [newaxis, :] * ldc_co + + rp + [:, newaxis] * ldc_p + + rq + [:, newaxis] * 1; + TYPE *pc[TM, TN] = C + offc; + bool checkc[TM, TN] = rm[:, newaxis] < M && rn [newaxis, :] < N; -#if (TZ==1) - *?(checkc) pc = c; +#if (TZ == 1) + *? (checkc)pc = c; #else - // accumulate partial result using spin-locks - int *plock = locks + rid; - int *pcount = plock + get_num_programs(0) * get_num_programs(1); - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % TZ); - atomic_xchg(plock, 0); + // accumulate partial result using spin-locks + int *plock = locks + rid; + int *pcount = plock + get_num_programs(0) * get_num_programs(1); + for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)) + ; + int count = *pcount; + if (count == 0) + *? (checkc)pc = c; + else + *? (checkc)pc = c + *? (checkc)pc; + atomic_xchg(pcount, (count + 1) % TZ); + atomic_xchg(plock, 0); #endif } \ No newline at end of file diff --git a/python/triton/ops/cross_entropy.c b/python/triton/ops/cross_entropy.c index 2767fae1a..2de793448 100644 --- a/python/triton/ops/cross_entropy.c +++ b/python/triton/ops/cross_entropy.c @@ -1,8 +1,4 @@ -__global__ void forward(TYPE *logit __aligned(16), - TYPE *modified_logit __aligned(16), - long *indices __readonly, - TYPE *result __aligned(16), - int n_cols __multipleof(N_COLS_MULT)) { +__global__ void forward(TYPE *logit, TYPE *modified_logit, long *indices, TYPE *result, int n_cols) { int row = get_program_id(0); bool check[TILE] = ((0 ... 
TILE) < n_cols); @@ -19,10 +15,7 @@ __global__ void forward(TYPE *logit __aligned(16), *(result + row) = *(modified_logit + (local_ind + n_cols * row)); } -__global__ void backward(TYPE *neg_logprobs __aligned(16), - long *indices __aligned(16), - TYPE *dneg_logprobs __aligned(16), - int n_cols __multipleof(N_COLS_MULT)) { +__global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs, int n_cols) { int row = get_program_id(0); // pointer arithmetic diff --git a/python/triton/ops/matmul.c b/python/triton/ops/matmul.c index 3410649c7..747340bea 100644 --- a/python/triton/ops/matmul.c +++ b/python/triton/ops/matmul.c @@ -1,16 +1,12 @@ #define STM 8 #define STN 8 -__global__ void matmul(TYPE *A __noalias __readonly __aligned(16), - TYPE *B __noalias __readonly __aligned(16), - TYPE *C __noalias __aligned(16), +__global__ void matmul(TYPE *A __noalias __readonly, + TYPE *B __noalias __readonly, + TYPE *C __noalias, float alpha, - int M, - int N, - int K __multipleof(16), - int lda __multipleof(LDA_POW2_DIV), - int ldb __multipleof(LDB_POW2_DIV), - int ldc __multipleof(LDC_POW2_DIV), + int M, int N, int K, + int lda, int ldb, int ldc, int *locks) { // prologue int pid = get_program_id(0); diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 6033ca780..a8925e34b 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -6,18 +6,18 @@ class _matmul(torch.autograd.Function): src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c")) _DEFAULT_CONFIGS = [ - ({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4), - ({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4), - ({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4), - ({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4), - ({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4), - ({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4), - ({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2), - ({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2), - # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4), - # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4), - # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4), - # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4), + triton.config(defines={"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, num_warps=4), + triton.config(defines={'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=4), + triton.config(defines={'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, num_warps=2), + triton.config(defines={'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, num_warps=2), + triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4), + triton.config(defines={'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4), + triton.config(defines={'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, num_warps=4), ] _CONFIGS = _DEFAULT_CONFIGS diff --git a/python/triton/testing.py b/python/triton/testing.py index 0148eb6b9..510375b08 100644 --- a/python/triton/testing.py +++ 
b/python/triton/testing.py @@ -1,4 +1,5 @@ import torch +import os def sparsify_tensor(x, mask, block): ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) @@ -77,8 +78,12 @@ class Mark: df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv")) def run(self, result_path, with_plot): - for bench in self.benchmarks: - self._run(bench, result_path, with_plot) + with open(os.path.join(result_path, "results.html"), "w") as html: + html.write("\n") + for bench in self.benchmarks: + self._run(bench, result_path, with_plot) + html.write(f"\n") + html.write("\n") def perf_report(benchmarks): wrapper = lambda fn: Mark(fn, benchmarks) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py new file mode 100644 index 000000000..9163e4efd --- /dev/null +++ b/python/tutorials/01-vector-add.py @@ -0,0 +1,76 @@ +import torch +import triton + +# source-code for Triton compute kernel +# here we just copy-paste the above code without the extensive comments. +# you may prefer to store it in a .c file and load it from there instead. +_src = """ +__global__ void add(float* z, float* x, float* y, int N){ + // program id + int pid = get_program_id(0); + // create arrays of pointers + int offset[BLOCK] = pid * BLOCK + 0 ... BLOCK; + float* pz[BLOCK] = z + offset; + float* px[BLOCK] = x + offset; + float* py[BLOCK] = y + offset; + // bounds checking + bool check[BLOCK] = offset < N; + // write-back + *?(check)pz = *?(check)px + *?(check)py; +} + """ +# This function returns a callable `triton.kernel` object +# created from the above source code. +# For portability, we maintain a cache of kernels for different `torch.device` +# We compile the kernel with -DBLOCK=1024 +_kernels = dict() + +def make_add_kernel(device): + if device not in _kernels: + defines = {'BLOCK': 1024} + autotune_vals = [({'BLOCK': '1024'}, 4), ({'BLOCK': '2048'}, 4)] + autotune_key = ["N"] + _kernels[device] = triton.kernel(_src, device=device, defines=defines, autotune_vals=autotune_vals, + autotune_key=autotune_key) + return _kernels[device] + +# This is a standard torch custom autograd Function +# The only difference is that we can now use the above kernel +# in the `forward` and `backward` functions.` +class _add(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + # constraints of the op + assert x.dtype == torch.float32 + # *allocate output* + z = torch.empty_like(x) + # *create launch grid*: + # this is a function which takes compilation parameters `opt` + # as input and returns a tuple of int (i.e., launch grid) for the kernel. + # triton.cdiv is a shortcut for ceil division: + # triton.cdiv(a, b) = (a + b - 1) // b + grid = lambda opt: (triton.cdiv(z.shape[0], opt.BLOCK), ) + # *launch kernel*: + # pointer to the data of torch tensors can be retrieved with + # the `.data_ptr()` method + kernel = make_add_kernel(z.device) + kernel(z.data_ptr(), x.data_ptr(), y.data_ptr(), z.shape[0], grid=grid) + return z + +# Just like we standard PyTorch ops +# We use the `.apply` method to create a +# callable object for our function +add = _add.apply + +torch.manual_seed(0) +x = torch.rand(32, device='cuda') +y = torch.rand(32, device='cuda') +za = x + y +zb = add(x, y) +print(za) +print(zb) +print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}') + +th_ms = triton.testing.do_bench(lambda: x + y) +tr_ms = triton.testing.do_bench(lambda: add(x, y)) +print(th_ms, tr_ms) \ No newline at end of file
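
A note on the new calling convention introduced above, with a minimal Python sketch (not part of the patch): kernel arguments are now serialized into a byte string with `struct.pack` using the format characters from the `codes` table in `python/triton/kernel.py`, and `runtime::function::autotune` reads the pointer/integer values back out of that buffer to compute a power-of-two alignment key (see `pow2_divisor` in `lib/runtime/function.cc`), which replaces the hard-coded `__aligned(16)` / `__multipleof(...)` annotations removed from the `.c` kernels. The helper names and tensor sizes below are made up for illustration, and no kernel is actually launched.

    import struct
    import torch

    # Format characters as in the `codes` table of python/triton/kernel.py:
    # buffers are packed as native pointers ('P'), int32 as 'I', etc.
    def pack_args(z, x, y, N):
        # signature of the tutorial kernel: float* z, float* x, float* y, int N -> 'PPPI'
        return struct.pack('PPPI', z.data_ptr(), x.data_ptr(), y.data_ptr(), N)

    # Same rule as runtime::pow2_divisor in lib/runtime/function.cc:
    # the largest power of two (capped at 16) that divides the value.
    def pow2_divisor(n):
        for div in (16, 8, 4, 2):
            if n % div == 0:
                return div
        return 1

    # dummy CPU tensors -- only the raw addresses and sizes matter here
    x = torch.rand(4096)
    y = torch.rand(4096)
    z = torch.empty_like(x)
    args = pack_args(z, x, y, x.numel())
    print(len(args), "bytes packed")

    # function::autotune() recovers each pointer/integer argument from `args`
    # (using per-argument offsets derived from the signature) and uses its
    # pow2_divisor as part of the kernel-cache key, so kernels compiled with
    # `aligned` / `multiple_of` attributes are only reused for inputs that
    # actually satisfy those divisibility guarantees.
    print([pow2_divisor(v) for v in (z.data_ptr(), x.data_ptr(), y.data_ptr(), x.numel())])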