diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h index 1c2878d64..2b3054f5b 100644 --- a/include/triton/codegen/transform/membar.h +++ b/include/triton/codegen/transform/membar.h @@ -2,6 +2,9 @@ #define TDL_INCLUDE_CODEGEN_BARRIERS_H #include +#include +#include +#include namespace triton { @@ -9,6 +12,7 @@ namespace ir { class module; class basic_block; class instruction; + class masked_load_async_inst; class value; class builder; } @@ -29,18 +33,15 @@ namespace transform{ class membar { private: typedef std::pair interval_t; - typedef std::vector interval_vec_t; + typedef std::set val_set_t; + typedef std::vector val_vec_t; private: - interval_vec_t join(const std::vector& intervals); - void insert_barrier(ir::instruction *instr, std::pair type, ir::builder &builder); - bool intersect(const interval_vec_t &X, interval_t x); - bool intersect(const interval_vec_t &X, const interval_vec_t &Y); - void add_reference(ir::value *v, interval_vec_t &res); - void get_read_intervals(ir::instruction *i, interval_vec_t &res); - void get_written_intervals(ir::instruction *i, interval_vec_t &res); - std::pair transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, - std::map > &insert_loc, std::set &safe_war, std::vector &to_sync); + bool intersect(const val_set_t &X, const val_set_t &Y); + int group_of(triton::ir::value *i, std::vector &async_write); + val_set_t intersect_with(const val_set_t& as, const val_set_t& bs); + void transfer(ir::basic_block *block, val_vec_t &async_write, val_set_t &sync_write, val_set_t &sync_read, + std::set &safe_war, bool &inserted, ir::builder &builder); public: membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc): diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h index 19dba70a9..d8a21e6cc 100644 --- a/include/triton/codegen/transform/peephole.h +++ b/include/triton/codegen/transform/peephole.h @@ -16,6 +16,10 @@ namespace ir { } namespace codegen{ +namespace analysis{ +class layouts; +} + namespace transform{ class peephole { @@ -33,11 +37,12 @@ private: private: public: - peephole(target* tgt): tgt_(tgt) {} + peephole(target* tgt, analysis::layouts* layouts): tgt_(tgt), layouts_(layouts) {} void run(ir::module &mod); private: target* tgt_; + analysis::layouts* layouts_; }; diff --git a/include/triton/codegen/transform/pipeline.h b/include/triton/codegen/transform/pipeline.h new file mode 100644 index 000000000..4d0650529 --- /dev/null +++ b/include/triton/codegen/transform/pipeline.h @@ -0,0 +1,28 @@ +#ifndef TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H +#define TRITON_INCLUDE_IR_CODEGEN_PIPELINE_H + +// forward declaration +namespace triton { +namespace ir { +class module; +} +} // namespace triton + +namespace triton { +namespace codegen { +namespace transform { + +class pipeline { +public: + pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {} + void run(ir::module &module); + +private: + bool has_copy_async_; +}; + +} // namespace transform +} // namespace codegen +} // namespace triton + +#endif diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h index 6184e7364..b5915a365 100755 --- a/include/triton/driver/stream.h +++ b/include/triton/driver/stream.h @@ -29,7 +29,7 @@ public: static driver::stream* create(backend_t backend); // methods virtual void synchronize() = 0; - virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args = NULL, size_t args_size = 0) = 0; + virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size, size_t shared_mem = 0) = 0; virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0; virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0; // template helpers @@ -44,7 +44,7 @@ class host_stream: public stream { public: host_stream(); void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size, size_t shared_mem); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; @@ -55,7 +55,7 @@ public: cu_stream(CUstream str, bool take_ownership); cu_stream(); void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size, size_t shared_mem); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 0a5c2c8d6..99c639cc0 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -35,6 +35,7 @@ public: basic_block* get_insert_block() { return block_; } iterator get_insert_point() { return insert_point_;} // Constants + value *get_int1(bool val); value *get_int32(int32_t val); value *get_int64(int64_t val); // Types @@ -149,7 +150,7 @@ public: value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = ""); value *create_copy_from_shared(value *arg, const std::string &name = ""); value *create_barrier(const std::string &name = ""); - value *create_async_wait(); + value *create_async_wait(int N); private: context &ctx_; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 7eebe8755..4718e7d9f 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -92,6 +92,7 @@ private: public: void set_incoming_value(unsigned i, value *v); void set_incoming_block(unsigned i, basic_block *block); + value *get_value_for_block(basic_block *block); value *get_incoming_value(unsigned i) { return get_operand(i); } basic_block *get_incoming_block(unsigned i) { return blocks_[i]; } unsigned get_num_incoming() { return get_num_operands(); } @@ -803,14 +804,18 @@ public: class async_wait_inst: public instruction{ private: - async_wait_inst(context &ctx, const std::string &name, instruction *next); - std::string repr_impl() const { return "async_wait"; } + async_wait_inst(context &ctx, int N, const std::string &name, instruction *next); + std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; } _TRITON_DEFINE_CLONE(async_wait_inst) _TRITON_DEFINE_ACCEPT(async_wait_inst) public: - static async_wait_inst* create(context &ctx, const std::string &name = "", - instruction *next = nullptr); + static async_wait_inst* create(context &ctx, int N, + const std::string &name = "", instruction *next = nullptr); + int get_N() { return N_; } + +private: + int N_; }; // On NVIDIA, implementation is such that diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index bc21059c4..501bddd39 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -98,6 +98,8 @@ private: std::shared_ptr ir_; std::shared_ptr mod_; std::shared_ptr ker_; + // shared mem + size_t shared_mem_; }; class function { diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index 58f8acecd..c0dbd5061 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -30,11 +30,8 @@ private: high_resolution_clock::time_point _start; }; -inline double bench(std::function const & op, driver::stream * stream, bool normalize = false) +inline double bench(std::function const & op, driver::stream * stream, size_t warmup = 10, size_t repeat = 200) { -// const driver::device * device = stream->context()->device(); - size_t warmup = 10; - size_t repeat = 50; timer tmr; std::vector times; double total_time = 0; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index a40257bdb..ed1d5c881 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -312,7 +312,6 @@ std::vector align::populate_max_contiguous_gep(ir::getelementptr_inst* if(rhs_cst_info[d].num_cst) rvalue = lhs_max_contiguous[d]; result[d] = std::max(lvalue, rvalue); -// std::cout << "max contiguous: " << x->get_name() << " " << d << " " << result[d] << std::endl; } return add_to_cache(x, result, max_contiguous_); } @@ -527,8 +526,7 @@ void align::run(ir::module &mod) { ir::for_each_value(mod, [this](ir::value* v) { populate(v); } ); // ir::for_each_value(mod, [this](ir::value* v) { // if(dynamic_cast(v) || dynamic_cast(v)) -// std::cout << "ALIGN: " << v->get_name() << " " << starting_multiple_.at(v)[0] << " " << max_contiguous_.at(v)[0] -// << " " << starting_multiple_.at(v)[1] << " " << max_contiguous_.at(v)[1] << std::endl; +// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl; // }); } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 77e39e4f5..cd27da12f 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -118,15 +118,6 @@ data_layout::data_layout(id_t id, // std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl; // std::cout << order_[0] << " " << order_[1] << std::endl; } - if(is_recoalesce){ - if(ptr.size() > 0){ -// std::cout << "recoalesce: " << order_[0] << " " << order_[1] << " " << ptr.size() << std::endl; -// std::cout << max_contiguous[0] << " " << max_contiguous[1] << std::endl; -// if(order_[0] == 0) -// exit(1); - } - } -// std::cout << "---" << std::endl; } int data_layout::find_axis(int to_find) const { @@ -213,14 +204,16 @@ scanline_layout::scanline_layout(size_t num_warps, ir::value *ptr = nullptr; for(ir::value *v: values) for(ir::user *usr: v->get_users()) - if(auto *st = dynamic_cast(usr)) - ptr = st->get_pointer_operand(); + if(auto *io = dynamic_cast(usr)){ + if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank()) + ptr = io->get_pointer_operand(); + } unsigned i = order_[0]; int contiguous = 1; if(ptr){ int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits(); - contiguous = std::min(align->contiguous(ptr)[i], 128 / nbits); + contiguous = std::min(align->get(ptr, i), 128 / nbits); } nts_[i] = clamp(size / num_threads, 1, std::min(contiguous, shape_[i])); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 8c900df70..d8150a60d 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { } void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ - unsigned vector = 1; - ir::value *ptrs = x->get_pointer_operand(); - ir::value *msks = x->get_mask_operand(); + unsigned in_vec = 1; + ir::value *arg = x->get_pointer_operand(); analysis::shared_layout* out_layout = layouts_->get(x)->to_shared(); - analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline(); + analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); auto out_order = out_layout->get_order(); auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) - vector = in_layout->nts(in_order[0]); + in_vec = in_layout->nts(in_order[0]); + int out_vec = swizzle_->get_vec(out_layout); + int min_vec = std::min(out_vec, in_vec); + int s = std::max(out_vec / in_vec, 1); // - int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; - int num_per_phase = std::max(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1); - Value *max_phase = i32(8 / num_per_phase); + int per_phase = swizzle_->get_per_phase(out_layout); + int max_phase = swizzle_->get_max_phase(out_layout); // + int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); + int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); + int n_shared_0 = std::max(in_vec / out_vec, 1); auto shapes = x->get_type()->get_tile_shapes(); - // - int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); - int n_shared = std::max(8 / in_layout->mts(in_order[1]), 1); - std::vector shared; - for(size_t i = 0; i < n_shared; i++){ - indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld); - // phase - Value* phase = udiv(idx[in_order[1]], i32(num_per_phase)); - phase = urem(phase, max_phase); - // off - Value* off_0 = idx[in_order[0]]; - off_0 = udiv(off_0, i32(vector)); - off_0 = xor_(off_0, phase); - off_0 = mul(off_0 , i32(vector)); - Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); - Value* off = add(off_0, off_1); - // - shared.push_back(gep(shmems_[x], {off})); - } - // - for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){ - auto idx = idxs_[ptrs][i]; + BasicBlock* CurrBB = builder_->GetInsertBlock(); + BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); + std::map, Value*> tmp; + std::vector> shared; + for(int i = 0; i < idxs_.at(arg).size(); i++){ + unsigned id = i / min_vec; // input ptr info - GetElementPtrInst *in_gep = dyn_cast(vals_[ptrs][idx]); - Value *in_base = in_gep->getPointerOperand(); - size_t in_off = dyn_cast(in_gep->idx_begin())->getValue().getSExtValue()*2*vector; - Value* out_base = shared[(i / per_thread_ld) % n_shared]; - int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]); - int out_off_1 = i % per_thread_ld; - int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2; - // asm - FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false); - std::string mod = (vector*2 == 16) ? ".cg" : ".ca"; - std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";"; - InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true); - call(iasm, {vals_[msks][idx], out_base, in_base}); + int id_0 = id % (in_ld/min_vec); + int id_1 = id / (in_ld/min_vec); + int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); + int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); + int off = (off_1*shapes[in_order[0]] + off_0); + std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; + if(tmp.find(key) == tmp.end()){ + if(CurrBB != FirstBB) + builder_->SetInsertPoint(FirstBB->getTerminator()); + indices_t idx = idxs_.at(arg).at(key.first*in_ld); + Value* phase = udiv(idx[in_order[1]], i32(per_phase)); + phase = urem(phase, i32(max_phase)); + Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); + Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); + off_0 = udiv(off_0, i32(min_vec)); + off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s))); + off_0 = mul(off_0 , i32(min_vec)); + Value* off = add(off_0, off_1); + if(CurrBB != FirstBB) + builder_->SetInsertPoint(CurrBB); + tmp[key] = gep(shmems_[x], {off}); + } + shared.push_back({tmp[key], off}); } + + for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){ + auto idx = idxs_[arg][i]; + // input ptr info + GetElementPtrInst *in_gep = dyn_cast(vals_[arg][idx]); + Value *in_base = in_gep->getPointerOperand(); + ConstantInt* cst = dyn_cast(in_gep->idx_begin()); + size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0; + in_base = cst ? in_base : in_gep; + // output ptr info + Value* out_base = shared[i].first; + int out_off = shared[i].second*2; + // asm + FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false); + std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca"; + std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";"; + InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true); + call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base}); + } + + std::string asm_str = "cp.async.commit_group;"; + InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true); + call(iasm); } void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { @@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); auto shapes = cts->get_type()->get_tile_shapes(); - // default implementation + // store to shared Value *current = nullptr; std::map, Value*> ptrs; for(int i = 0; i < idxs_.at(arg).size(); i++){ @@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) { add_barrier(); } -void generator::visit_async_wait_inst(ir::async_wait_inst*) { - std::string asm_str = "cp.async.wait_all;"; +void generator::visit_async_wait_inst(ir::async_wait_inst* i) { + std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";"; InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true); call(iasm); - add_barrier(); } void generator::visit_make_range_dyn(ir::make_range_dyn* x) { @@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) { if(unsigned alloc_size = alloc_->allocated_size()){ Type *int_8_ty = Type::getInt8Ty(*ctx_); Type *int_32_ty = Type::getInt32Ty(*ctx_); - ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4); + ArrayType *array_ty = ArrayType::get(int_32_ty, 0); Type *ptr_ty = ptr_ty(int_8_ty, 3); GlobalVariable *sh_mem_array = - new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage, + new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage, nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); shmem_ = bit_cast(sh_mem_array, ptr_ty); } diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 2972ed6ca..bea371c44 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -15,114 +15,105 @@ namespace triton { namespace codegen{ namespace transform{ -bool membar::intersect(const interval_vec_t &X, interval_t x) { - return std::any_of(X.begin(), X.end(), [&](const interval_t &y){ - bool left_intersect = y.first <= x.first && x.first < y.second; - bool right_intersect = y.first <= x.second && x.second < y.second; - return left_intersect || right_intersect; - }); -} -bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) { - return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){ - return intersect(X, y); - }); -} -void membar::add_reference(ir::value *v, interval_vec_t &res){ - auto *i = dynamic_cast(v); - if(!i) - return; - if(!i->get_type()->is_tile_ty()) - return; - analysis::shared_layout* layout = layouts_->get(v)->to_shared(); - if(!layout) - return; - if(alloc_->has_offset(layout)){ - unsigned offset = alloc_->offset(layout); - res.push_back(interval_t(offset, offset + layout->get_size())); +int membar::group_of(ir::value* v, std::vector &async_write) { + if(ir::phi_node* phi = dynamic_cast(v)){ + analysis::shared_layout* layout = layouts_->get(v)->to_shared(); + analysis::double_buffer_info_t* info = layout->get_double_buffer(); + if(info) + return group_of(info->first, async_write); + std::vector groups(phi->get_num_operands()); + std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);}); + return *std::max_element(groups.begin(), groups.end()); + } + else{ + auto it = std::find(async_write.begin(), async_write.end(), v); + return std::distance(async_write.begin(), it); } } -void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){ - for(ir::value *op: i->ops()) - add_reference(op, res); + +membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& bs) { + val_set_t ret; + for(ir::value* a: as){ + if(!a->get_type()->is_tile_ty()) + continue; + analysis::shared_layout* a_layout = layouts_->get(a)->to_shared(); + if(!a_layout) + continue; + int a_start = alloc_->offset(a_layout); + int a_end = a_start + a_layout->get_size(); + for(ir::value* b: bs){ + if(!b->get_type()->is_tile_ty()) + continue; + analysis::shared_layout* b_layout = layouts_->get(b)->to_shared(); + if(!b_layout) + continue; + int b_start = alloc_->offset(b_layout); + int b_end = b_start + b_layout->get_size(); + if(a_start < b_end || b_start < a_end) + ret.insert(b); + } + } + return ret; } -void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){ - if(!dynamic_cast(i) && !dynamic_cast(i)) - add_reference(i, res); -} - -void membar::insert_barrier(ir::instruction *instr, std::pair type, ir::builder &builder) { - if(auto *phi = dynamic_cast(instr)) { - std::set incoming; - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::instruction *inc_val = dynamic_cast(phi->get_incoming_value(n)); - assert(inc_val); - if(incoming.insert(inc_val).second){ - ir::basic_block *block = inc_val->get_parent(); - builder.set_insert_point(block->get_inst_list().back()); - if(type.first) - builder.create_async_wait(); - if(type.second) - builder.create_barrier(); +void membar::transfer(ir::basic_block *block, + val_vec_t& async_write, + val_set_t& sync_write, + val_set_t& sync_read, + std::set& safe_war, + bool& inserted, ir::builder& builder) { + ir::basic_block::inst_list_t instructions = block->get_inst_list(); + for(ir::instruction *i: instructions){ + if(dynamic_cast(i)) + continue; + if(std::find(async_write.begin(), async_write.end(), i) == async_write.end() && + dynamic_cast(i)){ + async_write.push_back(i); + } + if(dynamic_cast(i)) + sync_write.insert(i); + ir::barrier_inst* barrier = dynamic_cast(i); + ir::async_wait_inst* async_wait = dynamic_cast(i); + // Get shared memory reads + std::set read; + std::copy_if(i->op_begin(), i->op_end(), std::inserter(read, read.begin()), + [&](ir::value* i){ return i->get_type()->is_tile_ty() && layouts_->get(i)->to_shared();}); + // RAW (async) + val_set_t tmp; + std::copy(async_write.begin(), async_write.end(), std::inserter(tmp, tmp.begin())); + if(intersect_with(read, tmp).size()){ + std::vector groups(read.size()); + std::transform(read.begin(), read.end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);}); + int N = *std::max_element(groups.begin(), groups.end()); + if(N < async_write.size()){ + builder.set_insert_point(i); + async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N); + barrier = (ir::barrier_inst*)builder.create_barrier(); + inserted = true; } } - } - else { - builder.set_insert_point(instr); - builder.create_barrier(); - } -} - -membar::interval_vec_t membar::join(const std::vector& intervals) { - membar::interval_vec_t result; - for(auto x: intervals) - for(interval_t i: x) - result.push_back(i); - return result; -} - -std::pair membar::transfer(ir::basic_block *block, - const interval_vec_t &written_to, - const interval_vec_t &read_from, - std::map>& insert_loc, - std::set& safe_war, - std::vector& to_sync) { - ir::basic_block::inst_list_t instructions = block->get_inst_list(); - interval_vec_t new_written_to = written_to; - interval_vec_t new_read_from = read_from; - - for(ir::instruction *i: instructions){ - interval_vec_t read, written; - get_read_intervals(i, read); - get_written_intervals(i, written); - if(written.size()) - to_sync.push_back(i); - bool read_after_write = intersect(new_written_to, read); - bool write_after_read = intersect(new_read_from, written); - // double buffering - if(safe_war.find(i) != safe_war.end()){ - write_after_read = false; - read_after_write = false; + // RAW, WAR + if(intersect_with(read, sync_write).size() || intersect_with({i}, sync_read).size()){ + builder.set_insert_point(i); + barrier = (ir::barrier_inst*)builder.create_barrier(); + inserted = true; } - // record hazards - if(read_after_write || write_after_read) { - auto is_load_async = [&](ir::instruction *i){ return dynamic_cast(i);}; - auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast(i);}; - bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async); - bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared); - insert_loc.insert({i, {copy_async_wait, barrier}}); - new_written_to.clear(); - new_read_from.clear(); - to_sync.clear(); + // update state of asynchronous copies + if(async_wait){ + int N = async_write.size() - async_wait->get_N(); + async_write.erase(async_write.begin(), async_write.begin() + N); } - std::copy(written.begin(), written.end(), std::back_inserter(new_written_to)); - std::copy(read.begin(), read.end(), std::back_inserter(new_read_from)); + // all the copy_to_shared and read from shared are synchronized after barrier + if(barrier){ + sync_write.clear(); + sync_read.clear(); + } + sync_read.insert(read.begin(), read.end()); + } - return std::make_pair(new_written_to, new_read_from); } void membar::run(ir::module &mod) { @@ -143,35 +134,33 @@ void membar::run(ir::module &mod) { for(ir::function *fn: mod.get_function_list()){ std::vector rpo = ir::cfg::reverse_post_order(fn); - std::map written_to; - std::map read_from; - std::vector to_sync; - std::map> insert_locs; - size_t n_inserted_im1 = 0; - bool done = false; + std::map async_writes; + std::map sync_writes; + std::map sync_reads; + std::list pipelined; + bool inserted; do{ + inserted = false; // find barrier location for(ir::basic_block *block: rpo){ - // written to - std::vector pred_written_to; - for(ir::basic_block* pred: block->get_predecessors()) - pred_written_to.push_back(written_to[pred]); - // read from - std::vector pred_read_from; - for(ir::basic_block* pred: block->get_predecessors()) - pred_read_from.push_back(read_from[pred]); - // apply transfer function - auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync); - written_to[block] = result.first; - read_from[block] = result.second; + // join inputs + val_vec_t async_write; + val_set_t sync_write; + val_set_t sync_read; + val_set_t tmp; + for(ir::basic_block* pred: block->get_predecessors()){ + for(ir::value* v: async_writes[pred]) + if(tmp.insert(v).second) + async_write.push_back(v); + sync_write.insert(sync_writes[pred].begin(), sync_writes[pred].end()); + sync_read.insert(sync_reads[pred].begin(), sync_reads[pred].end()); + } + transfer(block, async_write, sync_write, sync_read, safe_war, inserted, builder); + async_writes[block] = async_write; + sync_writes[block] = sync_write; + sync_reads[block] = sync_read; } - size_t n_inserted_i = insert_locs.size(); - done = (n_inserted_im1 == n_inserted_i); - n_inserted_im1 = n_inserted_i; - }while(!done); - for(auto x: insert_locs){ - insert_barrier(x.first, x.second, builder); - } + }while(inserted); } } diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index 8caa1b0bc..2855674a9 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -1,7 +1,9 @@ #include +#include #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/codegen/transform/peephole.h" +#include "triton/codegen/analysis/layout.h" namespace triton { namespace codegen{ @@ -109,9 +111,18 @@ bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& build ir::value *ptr = ld->get_pointer_operand(); ir::value *msk = ld->get_mask_operand(); ir::value *val = ld->get_false_value_operand(); - ir::value* new_load = builder.create_masked_load_async(ptr, msk, val); - copy_to_shared->replace_all_uses_with(new_load); - return true; + analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline(); + int nts = layout->nts(layout->get_order()[0]); + int dtsize = value->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + if(nts*dtsize >= 4){ + ir::value* new_load = builder.create_masked_load_async(ptr, msk, val); + copy_to_shared->replace_all_uses_with(new_load); + return true; + } + return false; +// analysis::scanline_layout* layout = layouts_->get(ptr)->to_scanline(); +// std::cout << layout->nts(layout->get_order(0)) << std::endl; +// return true; } @@ -216,11 +227,11 @@ void peephole::run(ir::module &mod) { bool was_modified = false; was_modified = was_modified || rewrite_mult(i, builder); // was_modified = was_modified || rewrite_cts_cfs(i, builder); - was_modified = was_modified || rewrite_trans_phi(i, builder); +// was_modified = was_modified || rewrite_trans_phi(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); -// if(tgt_->as_nvidia()->sm() >= 80) -// was_modified = was_modified || rewrite_load_to_shared(i, builder); + if(tgt_->as_nvidia()->sm() >= 80) + was_modified = was_modified || rewrite_load_to_shared(i, builder); if(was_modified) seen.insert(i); } diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc new file mode 100644 index 000000000..32af28463 --- /dev/null +++ b/lib/codegen/transform/pipeline.cc @@ -0,0 +1,116 @@ +#include +#include +#include "triton/codegen/transform/pipeline.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/instructions.h" +#include "triton/ir/utils.h" + +namespace triton { +namespace codegen{ +namespace transform{ + + +void recursive_deps(ir::value* v, ir::basic_block* block, std::vector& ret){ + ir::instruction* i = dynamic_cast(v); + if(!i || i->get_parent() != block) + return; + if(i->get_id()==ir::INST_PHI) + return; + ret.push_back(i); + for(ir::user* u: i->get_users()) + recursive_deps(u, block, ret); +} + +void pipeline::run(ir::module &mod) { + // *Very* conservative heuristics for pre-fetching. + // A load instruction can be pipelined if: + // - the pointer is a phi node that references a value + // in its basic block (i.e., pointer induction variable) + // - the load has only a single use in a dot instruction + // As more use cases become apparent, this pass will be improved + std::vector> to_pipeline; + ir::for_each_instruction(mod, [&](ir::instruction *i){ + if(auto* load = dynamic_cast(i)){ + ir::phi_node* ptr = dynamic_cast(load->get_pointer_operand()); + auto users = load->get_users(); + if(ptr && ptr->get_incoming_block(1) == ptr->get_parent() + && users.size() == 1 && dynamic_cast(*users.begin())) + to_pipeline.push_back({load, ptr}); + }}); + // do the pipelining + std::vector new_loads; + ir::builder &builder = mod.get_builder(); + for(auto info: to_pipeline){ + ir::load_inst* load = info.first; + ir::phi_node* ptr = info.second; + ir::basic_block* block = load->get_parent(); + ir::basic_block* header = block->get_predecessors()[0]; + auto* block_br = dynamic_cast(block->get_inst_list().back()); + auto* header_br = dynamic_cast(header->get_inst_list().back()); + assert(block_br); + assert(header_br); + ir::type* ty = load->get_type(); + // pre-fetch first iteration + builder.set_insert_point(header->get_inst_list().back()); + ir::value* first_ptr = ptr->get_value_for_block(header); + ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_tile_shapes()); + ir::value* false_value; + if(auto* masked_load = dynamic_cast(load)){ + first_mask = builder.create_and(first_mask, masked_load->get_mask_operand()); + false_value = masked_load->get_false_value_operand(); + } + else + false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_tile_shapes()); + ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value); + // pre-fetch next iteration + builder.set_insert_point(block->get_inst_list().back()); + ir::value* next_ptr = ptr->get_value_for_block(block); + ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_tile_shapes()); + if(auto* masked_load = dynamic_cast(load)) + next_mask = builder.create_and(next_mask, masked_load->get_mask_operand()); + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value); + // phi node + builder.set_insert_point(block->get_first_non_phi()); + ir::phi_node* new_load = builder.create_phi(ty, 2); + new_load->add_incoming(first_load, header); + new_load->add_incoming(next_load, block); + load->replace_all_uses_with(new_load); + new_loads.push_back(new_load); + } + + + // try to move dot_inst after loads + // for better overlap of io and compute + struct move_config_t{ + std::vector insts; + ir::load_inst* dst; + }; + std::map to_move; + + if(has_copy_async_){ + for(ir::function* fn: mod.get_function_list()) + for(ir::basic_block* bb: fn->blocks()) + for(ir::instruction* inst: bb->get_inst_list()){ + if(auto* i = dynamic_cast(inst)) + recursive_deps(i, bb, to_move[bb].insts); + if(auto* i = dynamic_cast(inst)) + to_move[bb].dst = i; + } + + for(auto& x: to_move){ + builder.set_insert_point_after(x.second.dst); + for(ir::instruction* i: x.second.insts){ + x.first->erase(i); + builder.insert(i); + } + } + } + + +} + +} +} +} diff --git a/lib/codegen/transform/reassociate.cc b/lib/codegen/transform/reassociate.cc index 20241e70e..01293e1a5 100644 --- a/lib/codegen/transform/reassociate.cc +++ b/lib/codegen/transform/reassociate.cc @@ -22,6 +22,8 @@ inline ir::instruction* reassociate::is_bin_add(ir::value *x) { inline bool is_cst(ir::value *x) { if(dynamic_cast(x)) return true; + if(dynamic_cast(x)) + return true; if(auto *v = dynamic_cast(x)) return is_cst(v->get_operand(0)); return false; diff --git a/lib/driver/kernel.cc b/lib/driver/kernel.cc index 2b340bc4a..5c57e01a6 100755 --- a/lib/driver/kernel.cc +++ b/lib/driver/kernel.cc @@ -70,7 +70,21 @@ host_kernel::host_kernel(driver::module* program, const char *name): kernel(prog cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(program, CUfunction(), true) { dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name); -// dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED); + dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED); + // properties + int shared_total, shared_optin, shared_static; + int n_spills, n_reg; + CUdevice dev; + dispatch::cuCtxGetDevice(&dev); + dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev); + dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev); + dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_); + dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_); + dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_); + if (shared_optin > 49152){ +// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl; + dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static); + } } } diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 7c7286d3e..67c08edc1 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -282,6 +282,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, void cu_module::init_from_ptx(const std::string& ptx) { // JIT compile source-code +// std::cout << ptx << std::endl; try{ // // compile ptx with ptxas diff --git a/lib/driver/stream.cc b/lib/driver/stream.cc index ff349c9ce..c49911e4c 100755 --- a/lib/driver/stream.cc +++ b/lib/driver/stream.cc @@ -76,7 +76,7 @@ void host_stream::synchronize() { hst_->args.clear(); } -void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size) { +void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size, size_t) { auto hst = kernel->module()->hst(); hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]); char* params = new char[args_size]; @@ -113,13 +113,13 @@ void cu_stream::synchronize() { dispatch::cuStreamSynchronize(*cu_); } -void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size) { +void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, void* args, size_t args_size, size_t shared_mem) { void *config[] = { CU_LAUNCH_PARAM_BUFFER_POINTER, args, CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size, CU_LAUNCH_PARAM_END }; - dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config); + dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], shared_mem, *cu_, nullptr, config); } void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) { diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 2122b57bd..92db216ff 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -45,6 +45,9 @@ void builder::set_insert_point(basic_block *block){ // convenience functions //===----------------------------------------------------------------------===// +value *builder::get_int1(bool val) +{ return constant_int::get(type::get_int1_ty(ctx_), val); } + value *builder::get_int32(int32_t val) { return constant_int::get(type::get_int32_ty(ctx_), val);} @@ -372,8 +375,8 @@ value *builder::create_barrier(const std::string &name) { return insert(barrier_inst::create(ctx_, name)); } -value *builder::create_async_wait() { - return insert(async_wait_inst::create(ctx_)); +value *builder::create_async_wait(int N) { + return insert(async_wait_inst::create(ctx_, N)); } } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 5e15e83c7..8e197648d 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -45,6 +45,12 @@ phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, ins blocks_.reserve(num_reserved); } +value* phi_node::get_value_for_block(basic_block * block) { + auto it = std::find(blocks_.begin(), blocks_.end(), block); + size_t n = std::distance(blocks_.begin(), it); + return get_incoming_value(n); +} + // Set incoming value void phi_node::set_incoming_value(unsigned i, value *v){ assert(v && "PHI node got a null value!"); @@ -818,12 +824,11 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru return new barrier_inst(ctx, name, next); } -async_wait_inst::async_wait_inst(context &ctx, const std::string &name, - instruction *next) - : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { } +async_wait_inst::async_wait_inst(context &ctx, int N, const std::string &name, instruction *next) + : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next), N_(N) { } -async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) { - return new async_wait_inst(ctx, name, next); +async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string &name, instruction *next) { + return new async_wait_inst(ctx, N, name, next); } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 59effcef1..5e2caa062 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -15,10 +15,10 @@ #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/reassociate.h" -#include "triton/codegen/transform/reorder.h" #include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/disassociate.h" #include "triton/codegen/selection/generator.h" +#include "triton/codegen/transform/pipeline.h" #include "triton/runtime/function.h" #include "triton/lang/cpp.h" #include "triton/lang/parser.h" @@ -149,6 +149,7 @@ void kernel::init_ker(){ codegen::analysis::align align; codegen::analysis::axes axes; codegen::transform::cts cts(cts_use_async); + codegen::transform::pipeline pipeline(cts_use_async); codegen::transform::disassociate disassociate; codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get()); codegen::analysis::liveness liveness(&layouts); @@ -156,19 +157,24 @@ void kernel::init_ker(){ codegen::analysis::allocation allocation(&liveness); codegen::transform::membar barriers(&liveness, &layouts, &allocation); codegen::transform::dce dce; - codegen::transform::peephole peephole(target.get()); + codegen::transform::peephole peephole(target.get(), &layouts); codegen::transform::reassociate reassociate; codegen::transform::coalesce coalesce(&align, &layouts); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps); // run passes dce.run(*ir_); + pipeline.run(*ir_); + dce.run(*ir_); disassociate.run(*ir_); dce.run(*ir_); + align.run(*ir_); + axes.run(*ir_); + layouts.run(*ir_); peephole.run(*ir_); dce.run(*ir_); - align.run(*ir_); if(target->is_gpu()) cts.run(*ir_); + align.run(*ir_); axes.run(*ir_); layouts.run(*ir_); coalesce.run(*ir_); @@ -179,6 +185,11 @@ void kernel::init_ker(){ reassociate.run(*ir_); cts.run(*ir_); } + dce.run(*ir_); +// ir::print(*ir_, std::cout); + align.run(*ir_); + axes.run(*ir_); + layouts.run(*ir_); peephole.run(*ir_); dce.run(*ir_); align.run(*ir_); @@ -187,8 +198,9 @@ void kernel::init_ker(){ swizzle.run(*ir_); liveness.run(*ir_); allocation.run(*ir_); - if(allocation.allocated_size() > dev_->max_shared_memory()) - throw exception::out_of_shared_memory(); + shared_mem_ = allocation.allocated_size(); +// if(allocation.allocated_size() > dev_->max_shared_memory()) +// throw exception::out_of_shared_memory(); barriers.run(*ir_); isel.visit(*ir_, *llvm); //if(res->spilled() > 256) @@ -224,7 +236,7 @@ void kernel::operator()(void *args, size_t args_size, driver::stream *stream, co for(size_t i = 0; i < 3; i++) grid[i] = (i < _grid.size()) ? _grid[i] : 1; // enqueue - stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size); + stream->enqueue(&*ker_, grid, {(size_t)opt.num_warps * 32, 1, 1}, args, args_size, shared_mem_); } std::string kernel::get_asm(asm_mode_t mode) { @@ -348,7 +360,7 @@ kernel* function::autotune(void* args, size_t args_size, const grid_fn_ty& grid_ while(grid.size() < 3) grid.push_back(1); double ts = tools::bench([&]() { (*current)(args, args_size, stream, grid); }, - stream, true); + stream, 5, 20); ret = (ts < best_ts) ? current : ret; best_ts = std::min(ts, best_ts); } diff --git a/python/bench/bench_matmul.py b/python/bench/bench_matmul.py index 457ebf4a6..25da0f68e 100644 --- a/python/bench/bench_matmul.py +++ b/python/bench/bench_matmul.py @@ -2,58 +2,74 @@ import triton import torch # square benchmarks -nt = {False: 'n', True: 't'} +nt = {False: "n", True: "t"} square_confs = [ triton.testing.Benchmark( - x_names = ['M', 'N', 'K'], - x_vals = [128, 256, 512, 1024, 2048, 3072, 4096, 6144], - y_name = 'provider', - y_vals = ['torch', 'triton', 'cutlass'], - y_lines = ['Torch', 'Triton', 'CUTLASS'], - ylabel = 'TFLOPS', - loglog = False, - plot_name = f'matmul-square-{nt[AT]}{nt[BT]}', - args = {'AT': False, 'BT': False, 'dtype': torch.float16} - )\ - for AT in [False, True] for BT in [False, True] + x_names=["M", "N", "K"], + x_vals=[512 * i for i in range(1, 16)], + y_name="provider", + y_vals=["torch", "triton", "cutlass"], + y_lines=["Torch", "Triton", "CUTLASS"], + ylabel="TFLOPS", + loglog=False, + plot_name=f"matmul-square-{nt[AT]}{nt[BT]}", + args={"AT": AT, "BT": BT, "dtype": torch.float16}, + ) for AT in [False, True] for BT in [False, True] ] @triton.testing.perf_report(square_confs) -def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=5): +def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=20): import os - a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=dtype) / K**.5 - b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=dtype) / K**.5 - if AT: a = a.t() - if BT: b = b.t() + + a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype) + b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype) + if AT: + a = a.t() + if BT: + b = b.t() num_flops = 2 * M * N * K - if provider == 'torch': + if provider == "torch": torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep) torch_tflops = num_flops / torch_ms * 1e-9 return torch_tflops - if provider == 'triton': + if provider == "triton": triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep) triton_tflops = num_flops / triton_ms * 1e-9 return triton_tflops - if provider == 'cutlass' and 'CUTLASS_PROFILER' in os.environ: + if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ: import subprocess import tempfile import pandas as pd + # run program specified by CUTLASS_PROFILER env variable - layout_a = 'column' if AT else 'row' - layout_b = 'column' if BT else 'row' + layout_a = "column" if AT else "row" + layout_b = "column" if BT else "row" # create temporary file name fd, fname = tempfile.mkstemp() # run program and gets its output - cmd = [os.environ['CUTLASS_PROFILER'], f'--m={M}', f'--n={N}', f'--k={K}', f'--A=f16:{layout_a}', f'--B=f16:{layout_b}', \ - '--C=f16:column', '--accum=f32', '--operation=gemm', '--verification-enabled=false', f'--warmup-iterations={warmup}', \ - f'--profiling-iterations={rep}', f'--output={fname}', '--verbose=false'] + cmd = [ + os.environ["CUTLASS_PROFILER"], + f"--m={M}", + f"--n={N}", + f"--k={K}", + f"--A=f16:{layout_a}", + f"--B=f16:{layout_b}", + "--C=f16:column", + "--accum=f32", + "--operation=gemm", + "--verification-enabled=false", + f"--warmup-iterations={warmup}", + f"--profiling-iterations={rep}", + f"--output={fname}", + "--verbose=false", + ] # run cmd subprocess.run(cmd, stdout=subprocess.PIPE) # read CSV output - df_c = pd.read_csv(f'{fname}.gemm.csv') - cutlass_tflops = max(df_c['GFLOPs']) / 1e3 + df_c = pd.read_csv(f"{fname}.gemm.csv") + cutlass_tflops = max(df_c["GFLOPs"]) / 1e3 return cutlass_tflops return None -if __name__ == '__main__': +if __name__ == "__main__": bench_op.run() diff --git a/python/setup.py b/python/setup.py index a64c08f61..4d4bfaee3 100644 --- a/python/setup.py +++ b/python/setup.py @@ -15,21 +15,21 @@ import distutils.spawn import torch def find_llvm(): - versions = ['-10', '-10.0', ''] - supported = ['llvm-config{v}'.format(v=v) for v in versions] + versions = ["-10", "-10.0", ""] + supported = ["llvm-config{v}".format(v=v) for v in versions] paths = [distutils.spawn.find_executable(cfg) for cfg in supported] paths = [p for p in paths if p is not None] if paths: return paths[0] - config = distutils.spawn.find_executable('llvm-config') - instructions = 'Please install llvm-10-dev' + config = distutils.spawn.find_executable("llvm-config") + instructions = "Please install llvm-10-dev" if config is None: - raise RuntimeError('Could not find llvm-config. ' + instructions) - version = os.popen('{config} --version'.format(config=config)).read() - raise RuntimeError('Version {v} not supported. '.format(v=version) + instructions) + raise RuntimeError("Could not find llvm-config. " + instructions) + version = os.popen("{config} --version".format(config=config)).read() + raise RuntimeError("Version {v} not supported. ".format(v=version) + instructions) class CMakeExtension(Extension): - def __init__(self, name, path, sourcedir=''): + def __init__(self, name, path, sourcedir=""): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) self.path = path @@ -37,84 +37,84 @@ class CMakeExtension(Extension): class CMakeBuild(build_ext): def run(self): try: - out = subprocess.check_output(['cmake', '--version']) + out = subprocess.check_output(["cmake", "--version"]) except OSError: raise RuntimeError("CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions)) if platform.system() == "Windows": - cmake_version = LooseVersion(re.search(r'version\s*([\d.]+)', out.decode()).group(1)) - if cmake_version < '3.1.0': + cmake_version = LooseVersion(re.search(r"version\s*([\d.]+)", out.decode()).group(1)) + if cmake_version < "3.1.0": raise RuntimeError("CMake >= 3.1.0 is required on Windows") for ext in self.extensions: self.build_extension(ext) def build_extension(self, ext): - #self.debug = True + # self.debug = True + self.debug = False extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path))) # python directories python_include_dirs = distutils.sysconfig.get_python_inc() - python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR') + python_lib_dirs = distutils.sysconfig.get_config_var("LIBDIR") torch_include_dirs = include_paths(True) torch_library_dirs = library_paths(True) cxx11abi = str(int(torch._C._GLIBCXX_USE_CXX11_ABI)) cmake_args = [ - '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, - '-DBUILD_TUTORIALS=OFF', - '-DBUILD_PYTHON_MODULE=ON', + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, + "-DBUILD_TUTORIALS=OFF", + "-DBUILD_PYTHON_MODULE=ON", #'-DPYTHON_EXECUTABLE=' + sys.executable, #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, - '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)), - '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)), - '-DTORCH_CXX11_ABI=' + cxx11abi, - '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton', - '-DLLVM_CONFIG=' + find_llvm() + "-DPYTHON_INCLUDE_DIRS=" + ";".join([python_include_dirs] + include_paths(True)), + "-DPYTHON_LINK_DIRS=" + ";".join(library_paths(True)), + "-DTORCH_CXX11_ABI=" + cxx11abi, + "-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton", + "-DLLVM_CONFIG=" + find_llvm(), ] # configuration - cfg = 'Debug' if self.debug else 'Release' - cfg = 'Release' - build_args = ['--config', cfg] + cfg = "Debug" if self.debug else "Release" + build_args = ["--config", cfg] if platform.system() == "Windows": - cmake_args += ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}'.format(cfg.upper(), extdir)] + cmake_args += ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir)] if sys.maxsize > 2**32: - cmake_args += ['-A', 'x64'] - build_args += ['--', '/m'] + cmake_args += ["-A", "x64"] + build_args += ["--", "/m"] else: - cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg] - build_args += ['--', '-j4'] + cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] + build_args += ["--", "-j4"] env = os.environ.copy() if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) - subprocess.check_call(['cmake', sourcedir] + cmake_args, cwd=self.build_temp, env=env) - subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp) + sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "src")) + subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=self.build_temp, env=env) + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp) setup( - name='triton', - version='1.0.0', - author='Philippe Tillet', - author_email='phil@openai.com', - description='A language and compiler for custom Deep Learning operations', - long_description='', - packages=['triton', 'triton/_C', 'triton/ops', 'triton/ops/blocksparse'], - install_requires=['numpy', 'torch'], - package_data={'triton/ops': ['*.c'], 'triton/ops/blocksparse': ['*.c']}, + name="triton", + version="1.0.0", + author="Philippe Tillet", + author_email="phil@openai.com", + description="A language and compiler for custom Deep Learning operations", + long_description="", + packages=["triton", "triton/_C", "triton/ops", "triton/ops/blocksparse"], + install_requires=["numpy", "torch"], + package_data={"triton/ops": ["*.c"], "triton/ops/blocksparse": ["*.c"]}, include_package_data=True, - ext_modules=[CMakeExtension('triton', 'triton/_C/')], - cmdclass={'build_ext': CMakeBuild}, + ext_modules=[CMakeExtension("triton", "triton/_C/")], + cmdclass={"build_ext": CMakeBuild}, zip_safe=False, # for PyPI - keywords=['Compiler', 'Deep Learning'], - url='https://github.com/ptillet/triton/', - download_url='https://github.com/ptillet/triton/archive/v0.1.tar.gz', + keywords=["Compiler", "Deep Learning"], + url="https://github.com/ptillet/triton/", + download_url="https://github.com/ptillet/triton/archive/v0.1.tar.gz", classifiers=[ - 'Development Status :: 3 - Alpha', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package - 'Intended Audience :: Developers', # Define that your audience are developers - 'Topic :: Software Development :: Build Tools', - 'License :: OSI Approved :: MIT License', # Again, pick a license - 'Programming Language :: Python :: 3.6', + "Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package + "Intended Audience :: Developers", # Define that your audience are developers + "Topic :: Software Development :: Build Tools", + "License :: OSI Approved :: MIT License", # Again, pick a license + "Programming Language :: Python :: 3.6", ], ) diff --git a/python/test/test_blocksparse.py b/python/test/test_blocksparse.py index 6b12371eb..70ae1b8ce 100644 --- a/python/test/test_blocksparse.py +++ b/python/test/test_blocksparse.py @@ -2,29 +2,17 @@ import torch import triton import pytest - @pytest.mark.parametrize( "MODE, TRANS_A, TRANS_B, BLOCK", - [ - (mode, at, bt, block) - for mode in ["sdd", "dsd", "dds"] - for at in [False, True] - for bt in [False, True] - for block in [16, 32, 64] - ], + [(mode, at, bt, block) for mode in ["sdd", "dsd", "dds"] for at in [False, True] for bt in [False, True] + for block in [16, 32, 64]], ) -def test_matmul( - MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384 -): +def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE=torch.float16, Z=3, H=2, M=128, N=256, K=384): # set seed torch.random.manual_seed(0) # create inputs - a = torch.randn( - (Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda" - ) - b = torch.randn( - (Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda" - ) + a = torch.randn((Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda") + b = torch.randn((Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda") shape = { "sdd": (M, N), "dsd": (a.shape[2], a.shape[3]), @@ -32,9 +20,7 @@ def test_matmul( }[MODE] layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK)) # triton result - op = triton.ops.blocksparse.matmul( - layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B - ) + op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b rc = op(ra, rb) @@ -49,7 +35,6 @@ def test_matmul( # compare assert triton.testing.allclose(rc, tc) - @pytest.mark.parametrize( "BLOCK, WIDTH", [(block, width) for block in [32] for width in [256, 576, 1024, 1792]], @@ -62,12 +47,8 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16): # create inputs layout = torch.randint(2, (H, M // BLOCK, N // BLOCK)) x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda") - at_mask = torch.randint( - low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda" - ) - kp_mask = torch.randint( - low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda" - ) + at_mask = torch.randint(low=0, high=2, size=(N, N), dtype=torch.bool, requires_grad=False, device="cuda") + kp_mask = torch.randint(low=0, high=2, size=(Z, N), dtype=DTYPE, requires_grad=False, device="cuda") kp_mask[kp_mask == 1.0] = float("-inf") # triton result op = triton.ops.blocksparse.softmax(layout, BLOCK) @@ -94,7 +75,6 @@ def test_softmax(BLOCK, WIDTH, DTYPE=torch.float16): # compare assert triton.testing.allclose(ry, ty) - def test_attention_fwd_bwd( input_scale=1.0, tol=2e-2, @@ -108,10 +88,7 @@ def test_attention_fwd_bwd( # inputs qkv_shape = (batch_size, n_heads, n_ctx, 64) qkvs = [ - torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True) - .to(dtype) - .cuda() - for _ in range(3) + torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True).to(dtype).cuda() for _ in range(3) ] attn_mask = torch.tril( torch.ones( @@ -129,11 +106,9 @@ def test_attention_fwd_bwd( query.retain_grad() key.retain_grad() value.retain_grad() - attn_out = triton_attention( - layout, block, attn_mask, query=query, key=key, value=value, scale=scale - ) + attn_out = triton_attention(layout, block, attn_mask, query=query, key=key, value=value, scale=scale) # ad hoc loss - loss = (attn_out ** 2).mean() + loss = (attn_out**2).mean() loss.backward() grads = [query.grad, key.grad, value.grad] @@ -148,17 +123,16 @@ def test_attention_fwd_bwd( probs = torch.softmax(scores, dim=-1) torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) # ad hoc loss - torch_loss = (torch_attn_out ** 2).mean() + torch_loss = (torch_attn_out**2).mean() torch_loss.backward() torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] # comparison - print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") + # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") torch.testing.assert_allclose(loss, torch_loss, rtol=tol, atol=tol) for g1, g2 in zip(grads, torch_grads): torch.testing.assert_allclose(g1, g2, rtol=tol, atol=tol) - def triton_attention( layout, block: int, @@ -168,12 +142,8 @@ def triton_attention( value: torch.Tensor, scale: float, ): - sparse_dot_sdd_nt = triton.ops.blocksparse.matmul( - layout, block, "sdd", trans_a=False, trans_b=True - ) - sparse_dot_dsd_nn = triton.ops.blocksparse.matmul( - layout, block, "dsd", trans_a=False, trans_b=False - ) + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False) sparse_softmax = triton.ops.blocksparse.softmax( layout, block, diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index 194e3d422..b3ee58370 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -4,7 +4,7 @@ import triton import torch @pytest.mark.parametrize( - "TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE", + "TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE", itertools.chain(*[ [ # 1 warp @@ -17,14 +17,14 @@ import torch (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE), - # 2 warp + # # 2 warp (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE), (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE), (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE), (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE), (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE), (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE), - # 4 warp + # # 4 warp (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE), (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE), (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE), @@ -40,24 +40,28 @@ import torch (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE), (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE), # variable input - (128, 128, 32, 1, 4, 256, 256, 256, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE), (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE), (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE) - ] for DTYPE in ['float16'] for AT in [False, True] for BT in [False, True] - ])) -def test_op(TM, TN, TK, TZ, NWARP, M, N, K, AT, BT, DTYPE): - DTYPE = {'float16': torch.float16, 'float32': torch.float32}[DTYPE] + (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE), + ] for DTYPE in ["float16"] for AT in [False, True] for BT in [False, True] + ]), +) +def test_op(TM, TN, TK, SPLITK, NWARP, M, N, K, AT, BT, DTYPE): + DTYPE = {"float16": torch.float16, "float32": torch.float32}[DTYPE] torch.manual_seed(0) triton.ops._matmul._kernels = dict() - triton.ops._matmul._CONFIGS = [({'TM': str(TM), 'TN': str(TN), 'TK': str(TK), 'TZ': str(TZ)}, NWARP)] - if M is None: M = TM - if N is None: N = TN - if K is None: K = TK * TZ - a = torch.randn((K, M) if AT else (M, K), device='cuda', dtype=DTYPE) / K**.5 - b = torch.randn((N, K) if BT else (K, N), device='cuda', dtype=DTYPE) / K**.5 + triton.ops._matmul._CONFIGS = [({"TM": str(TM), "TN": str(TN), "TK": str(TK), "SPLITK": str(SPLITK)}, NWARP)] + if M is None: + M = TM + if N is None: + N = TN + if K is None: + K = TK * SPLITK + a = torch.randn((K, M) if AT else (M, K), device="cuda", dtype=DTYPE) + b = torch.randn((N, K) if BT else (K, N), device="cuda", dtype=DTYPE) a = a.t() if AT else a b = b.t() if BT else b th_c = torch.matmul(a, b) tt_c = triton.ops.matmul(a, b) - assert triton.testing.allclose(th_c, tt_c) \ No newline at end of file + assert triton.testing.allclose(th_c, tt_c) diff --git a/python/triton/ops/blocksparse/matmul.c b/python/triton/ops/blocksparse/matmul.c index e520b81e6..a45d072b3 100644 --- a/python/triton/ops/blocksparse/matmul.c +++ b/python/triton/ops/blocksparse/matmul.c @@ -1,198 +1,199 @@ - __global__ void NAME (TYPE* A __readonly __noalias __aligned(16), - TYPE* B __readonly __noalias __aligned(16), - TYPE* C __noalias __aligned(16), - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc __multipleof(8), - long stride_za __multipleof(8), - long stride_zb __multipleof(8), - long stride_zc __multipleof(8), - long stride_ha __multipleof(8), - long stride_hb __multipleof(8), - long stride_hc __multipleof(8), - int DS0, int DS1, - int SDD_K __multipleof(16), - int SDD_off_width, - int* lut, int* locks, int nlocks) { - /* ---------------- */ - /* Prologue */ - /* ---------------- */ - // program ids - int pid0 = get_program_id(0); - int pid1 = get_program_id(1); - int pidz = get_program_id(2); +__global__ void NAME(TYPE *A __readonly __noalias __aligned(16), + TYPE *B __readonly __noalias __aligned(16), + TYPE *C __noalias __aligned(16), + int lda __multipleof(8), + int ldb __multipleof(8), + int ldc __multipleof(8), + long stride_za __multipleof(8), + long stride_zb __multipleof(8), + long stride_zc __multipleof(8), + long stride_ha __multipleof(8), + long stride_hb __multipleof(8), + long stride_hc __multipleof(8), + int DS0, int DS1, + int SDD_K __multipleof(16), + int SDD_off_width, + int *lut, int *locks, int nlocks) { + /* ---------------- */ + /* Prologue */ + /* ---------------- */ + // program ids + int pid0 = get_program_id(0); + int pid1 = get_program_id(1); + int pidz = get_program_id(2); #ifdef SDD - // load LUT header - pid1 = pid1 + SDD_off_width; - int blockidm[TM] = (0 ... TM) / BLOCK; - int blockidn[TN] = (0 ... TN) / BLOCK; - int offlutm[TM] = blockidm*(TN/BLOCK)*4; - int offlutn[TN] = blockidn*4; - int *header = lut + pid1 * (TM/BLOCK) * (TN/BLOCK) * 4; - int z = *(header + 0); - int i[TM] = *(header + 1 + offlutm); - int j[TN] = *(header + 2 + offlutn); - int AS1 = SDD_K / TZ; - int lockid = select(TZ > 1, 1, 0); - int offka = pid0 * AS1; - int offkb = pid0 * AS1; - int offmc = 0; - int offnc = 0; - int offpa = 0; - int offpb = 0; - int maxid = TZ; - int offhc = 0; - int offha = z; - int offhb = z; - int ram[TM] = i*BLOCK + ((0 ... TM) % BLOCK); - int rbn[TN] = j*BLOCK + ((0 ... TN) % BLOCK); + // load LUT header + pid1 = pid1 + SDD_off_width; + int blockidm[TM] = (0 ... TM) / BLOCK; + int blockidn[TN] = (0 ... TN) / BLOCK; + int offlutm[TM] = blockidm * (TN / BLOCK) * 4; + int offlutn[TN] = blockidn * 4; + int *header = lut + pid1 * (TM / BLOCK) * (TN / BLOCK) * 4; + int z = *(header + 0); + int i[TM] = *(header + 1 + offlutm); + int j[TN] = *(header + 2 + offlutn); + int AS1 = SDD_K / TZ; + int lockid = select(TZ > 1, 1, 0); + int offka = pid0 * AS1; + int offkb = pid0 * AS1; + int offmc = 0; + int offnc = 0; + int offpa = 0; + int offpb = 0; + int maxid = TZ; + int offhc = 0; + int offha = z; + int offhb = z; + int ram[TM] = i * BLOCK + ((0 ... TM) % BLOCK); + int rbn[TN] = j * BLOCK + ((0 ... TN) % BLOCK); #else - // load LUT header - int *header = lut + pid0 * 6; - int offset = *(header + 0); - int AS1 = *(header + 1); - int column = *(header + 2); - int depth = *(header + 3); - int lockid = *(header + 4); - int maxid = *(header + 5); - int *pinc = lut + offset; - int offhc = depth; + // load LUT header + int *header = lut + pid0 * 6; + int offset = *(header + 0); + int AS1 = *(header + 1); + int column = *(header + 2); + int depth = *(header + 3); + int lockid = *(header + 4); + int maxid = *(header + 5); + int *pinc = lut + offset; + int offhc = depth; #ifdef DSD - // output offset - int offnc = pid1 * TN; - int offmc = column * TM; - int offpc = 0; - // dense input offset - int offnb = pid1 * TN; - int offkb __multipleof(8) = *pinc; - int offpb = 0; - // sparse input offset - int offma = 0; - int offka = 0; - long offpa __multipleof(8) = *(pinc + 1); - offpa = offpa * BLOCK * BLOCK; - int offha = 0; - int offhb = depth; + // output offset + int offnc = pid1 * TN; + int offmc = column * TM; + int offpc = 0; + // dense input offset + int offnb = pid1 * TN; + int offkb __multipleof(8) = *pinc; + int offpb = 0; + // sparse input offset + int offma = 0; + int offka = 0; + long offpa __multipleof(8) = *(pinc + 1); + offpa = offpa * BLOCK * BLOCK; + int offha = 0; + int offhb = depth; #endif #ifdef DDS - // output offset - int offmc = pid1 * TM; - int offnc = column * TN; - int offpc = 0; - // dense input offset - int offma = pid1 * TM; - int offka __multipleof(8) = *pinc; - int offpa = 0; - // sparse input offset - int offnb = 0; - int offkb = 0; - long offpb __multipleof(8) = *(pinc + 1); - offpb = offpb * BLOCK * BLOCK; - int offha = depth; - int offhb = 0; + // output offset + int offmc = pid1 * TM; + int offnc = column * TN; + int offpc = 0; + // dense input offset + int offma = pid1 * TM; + int offka __multipleof(8) = *pinc; + int offpa = 0; + // sparse input offset + int offnb = 0; + int offkb = 0; + long offpb __multipleof(8) = *(pinc + 1); + offpb = offpb * BLOCK * BLOCK; + int offha = depth; + int offhb = 0; #endif - int ram[TM] = offma + 0 ... TM; - int rbn[TN] = offnb + 0 ... TN; + int ram[TM] = offma + 0 ... TM; + int rbn[TN] = offnb + 0 ... TN; #endif - // initialize a, b pointers - int rka[TK] = offka + 0 ... TK; - int rkb[TK] = offkb + 0 ... TK; - TYPE* pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka[newaxis, :] * STRIDE_AK; - TYPE* pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn[newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK; + // initialize a, b pointers + int rka[TK] = offka + 0 ... TK; + int rkb[TK] = offkb + 0 ... TK; + TYPE *pa[TM, TK] = A + pidz * stride_za + offha * stride_ha + offpa + ram[:, newaxis] * STRIDE_AM + rka [newaxis, :] * STRIDE_AK; + TYPE *pb[TK, TN] = B + pidz * stride_zb + offhb * stride_hb + offpb + rbn [newaxis, :] * STRIDE_BN + rkb[:, newaxis] * STRIDE_BK; + // pre-fetch +#ifdef DDS + bool checkam[TM, TK] = ram[:, newaxis] < DS0; +#else + bool checkam[TM, TK] = AS1 > 0; +#endif +#ifdef DSD + bool checkbn[TK, TN] = rbn [newaxis, :] < DS0; +#else + bool checkbn[TK, TN] = AS1 > 0; +#endif + TYPE a[TM, TK] = checkam ? *pa : 0; + TYPE b[TK, TN] = checkbn ? *pb : 0; + + /* ---------------- */ + /* Inner Loop */ + /* ---------------- */ + // create result tile + float acc[TM, TN] = 0; + int step = TK; + for (int k = AS1; k > 0; k -= step) { + acc += a @b; + // update pointers +#ifdef SDD + int inc_a = TK * STRIDE_AK; + int inc_b = TK * STRIDE_BK; +#else + pinc += 2; +#ifdef DSD + int inc_b __multipleof(8) = *pinc; + int inc_a __multipleof(8) = *(pinc + 1); + inc_b = inc_b * STRIDE_BK; +#endif +#ifdef DDS + int inc_a __multipleof(8) = *pinc; + int inc_b __multipleof(8) = *(pinc + 1); + inc_a = inc_a * STRIDE_AK; +#endif +#endif + pa += inc_a; + pb += inc_b; // pre-fetch -#ifdef DDS - bool checkam[TM, TK] = ram[:, newaxis] < DS0; -#else - bool checkam[TM, TK] = AS1 > 0; -#endif -#ifdef DSD - bool checkbn[TK, TN] = rbn[newaxis, :] < DS0; -#else - bool checkbn[TK, TN] = AS1 > 0; -#endif - TYPE a[TM, TK] = checkam ? *pa : 0; - TYPE b[TK, TN] = checkbn ? *pb : 0; + bool checkak[TM, TK] = k > TK; + bool checkbk[TK, TN] = k > TK; + bool checka[TM, TK] = checkam && checkak; + bool checkb[TK, TN] = checkbk && checkbn; + a = *? (checka)pa; + b = *? (checkb)pb; + } + TYPE c[TM, TN] = acc; - /* ---------------- */ - /* Inner Loop */ - /* ---------------- */ - // create result tile - float acc[TM, TN] = 0; - int step = TK; - for(int k = AS1; k > 0; k -= step) { - acc += a @ b; - // update pointers + /* ---------------- */ + /* Epilogue */ + /* ---------------- */ + // initialize c pointers #ifdef SDD - int inc_a = TK * STRIDE_AK; - int inc_b = TK * STRIDE_BK; + bool checkc[TM, TN] = 1; + // rematerialize + int rr_blockidm[TM] = (0 ... TM) / BLOCK; + int rr_blockidn[TN] = (0 ... TN) / BLOCK; + int rr_offlutm[TM] = rr_blockidm * (TN / BLOCK) * 4; + int rr_offlutn[TN] = rr_blockidn * 4; + int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn [newaxis, :]; + int bkid[TM, TN] = *(header + off_bkid); + long offpc[TM, TN] = bkid * BLOCK * BLOCK; + // range within blocks + int rcm[TM] = (0 ... TM) % BLOCK; + int rcn[TN] = (0 ... TN) % BLOCK; #else - pinc += 2; + int rcm[TM] = offmc + 0 ... TM; + int rcn[TN] = offnc + 0 ... TN; #ifdef DSD - int inc_b __multipleof(8) = *pinc; - int inc_a __multipleof(8) = *(pinc + 1); - inc_b = inc_b * STRIDE_BK; + bool checkc[TM, TN] = rcn [newaxis, :] < DS0; #endif #ifdef DDS - int inc_a __multipleof(8) = *pinc; - int inc_b __multipleof(8) = *(pinc + 1); - inc_a = inc_a * STRIDE_AK; + bool checkc[TM, TN] = rcm[:, newaxis] < DS0; #endif #endif - pa += inc_a; - pb += inc_b; - // pre-fetch - bool checkak[TM, TK] = k > TK; - bool checkbk[TK, TN] = k > TK; - bool checka[TM, TK] = checkam && checkak; - bool checkb[TK, TN] = checkbk && checkbn; - a = *?(checka)pa; - b = *?(checkb)pb; - } - TYPE c[TM, TN] = acc; - - /* ---------------- */ - /* Epilogue */ - /* ---------------- */ - // initialize c pointers -#ifdef SDD - bool checkc[TM, TN] = 1; - // rematerialize - int rr_blockidm[TM] = (0 ... TM) / BLOCK; - int rr_blockidn[TN] = (0 ... TN) / BLOCK; - int rr_offlutm[TM] = rr_blockidm*(TN/BLOCK)*4; - int rr_offlutn[TN] = rr_blockidn*4; - int off_bkid[TM, TN] = 3 + rr_offlutm[:, newaxis] + rr_offlutn[newaxis, :]; - int bkid[TM, TN] = *(header + off_bkid); - long offpc[TM, TN] = bkid * BLOCK * BLOCK; - // range within blocks - int rcm[TM] = (0 ... TM) % BLOCK; - int rcn[TN] = (0 ... TN) % BLOCK; -#else - int rcm[TM] = offmc + 0 ... TM; - int rcn[TN] = offnc + 0 ... TN; -#ifdef DSD - bool checkc[TM, TN] = rcn[newaxis, :] < DS0; -#endif -#ifdef DDS - bool checkc[TM, TN] = rcm[:, newaxis] < DS0; -#endif -#endif - TYPE* pc[TM, TN] = C + offpc + offhc*stride_hc + pidz*stride_zc + rcm[:, newaxis]*STRIDE_CM + rcn[newaxis, :]*STRIDE_CN; - // write-back directly - if(lockid == 0) { - *?(checkc) pc = c; - } - // accumulate partial result using spin-locks - else { - int *plock = locks + get_program_id(2)*nlocks*get_num_programs(1) + get_program_id(1)*nlocks + lockid - 1; - int *pcount = plock + get_num_programs(2)*get_num_programs(1)*nlocks; - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % maxid); - atomic_xchg(plock, 0); - } - } \ No newline at end of file + TYPE *pc[TM, TN] = C + offpc + offhc * stride_hc + pidz * stride_zc + rcm[:, newaxis] * STRIDE_CM + rcn [newaxis, :] * STRIDE_CN; + // write-back directly + if (lockid == 0) { + *? (checkc)pc = c; + } + // accumulate partial result using spin-locks + else { + int *plock = locks + get_program_id(2) * nlocks * get_num_programs(1) + get_program_id(1) * nlocks + lockid - 1; + int *pcount = plock + get_num_programs(2) * get_num_programs(1) * nlocks; + for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)) + ; + int count = *pcount; + if (count == 0) + *? (checkc)pc = c; + else + *? (checkc)pc = c + *? (checkc)pc; + atomic_xchg(pcount, (count + 1) % maxid); + atomic_xchg(plock, 0); + } +} \ No newline at end of file diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 03dc32f21..e874eee89 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -10,454 +10,416 @@ src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) # MAIN API # ############## class _matmul(torch.autograd.Function): - - sdd_cache = dict() - dsd_cache = dict() - dds_cache = dict() - locks = dict() - # Given an array sizes representing reduction size for each - # column of a block-mode matrix multiplication, - # performs load-balancing to achieve more smaller reductions - # between `seg_size` elements - @staticmethod - def load_balance(sizes, block): - # segment size - # heuristics taken from OpenAI blocksparse code - # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95 - max_size = sizes.max() - min_size = sizes[sizes != 0].min() - #if max_size > min_size * 2.0: - # seg_max = max(triton.cdiv(max_size, 4), min_size*2) - #else: - # seg_max = max_size - seg_max = max_size - seg_min = max(triton.cdiv(seg_max, 4), 4) - # split reduction into segments - div = sizes // seg_max - rem = sizes % seg_max - packs = div + (sizes < seg_min).long() + (rem >= seg_min).long() - width = packs.sum() - segments = torch.empty(width, dtype=sizes.dtype) - column = torch.empty_like(segments) - lockid = torch.zeros_like(segments) - maxid = torch.zeros_like(segments) - nlocks = 0 - current = 0 - col_idx = 0 - for i in range(len(sizes)): - d, r = div[i], rem[i] - isempty = sizes[i] < seg_min - last = current + d + (r >= seg_min) + isempty - # column id - column[current:last] = col_idx - # lock id - if d > 1 or (d == 1 and r >= seg_min): - nlocks += 1 - lockid[current:last] = nlocks - maxid[current:last] = last - current - # segment size - segments[current:current+d] = seg_max - if r < seg_min and not isempty: - segments[current+d-1] += r - if r >= seg_min or isempty: - segments[current+d] = r - current = last - col_idx += 1 - offsets = torch.zeros_like(segments) - offsets[1:] = torch.cumsum(segments[:-1], dim=0) - return segments, column, lockid, maxid, offsets - - @staticmethod - def get_locks(size, dev): - if dev not in _matmul.locks or \ - size > _matmul.locks[dev].size(0): - _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev) - return _matmul.locks[dev] + sdd_cache = dict() + dsd_cache = dict() + dds_cache = dict() + locks = dict() - ########################## - # SPARSE = DENSE x DENSE # - ########################## + # Given an array sizes representing reduction size for each + # column of a block-mode matrix multiplication, + # performs load-balancing to achieve more smaller reductions + # between `seg_size` elements + @staticmethod + def load_balance(sizes, block): + # segment size + # heuristics taken from OpenAI blocksparse code + # https://github.com/openai/blocksparse/blob/master/blocksparse/matmul.py#L95 + max_size = sizes.max() + min_size = sizes[sizes != 0].min() + #if max_size > min_size * 2.0: + # seg_max = max(triton.cdiv(max_size, 4), min_size*2) + #else: + # seg_max = max_size + seg_max = max_size + seg_min = max(triton.cdiv(seg_max, 4), 4) + # split reduction into segments + div = sizes // seg_max + rem = sizes % seg_max + packs = div + (sizes < seg_min).long() + (rem >= seg_min).long() + width = packs.sum() + segments = torch.empty(width, dtype=sizes.dtype) + column = torch.empty_like(segments) + lockid = torch.zeros_like(segments) + maxid = torch.zeros_like(segments) + nlocks = 0 + current = 0 + col_idx = 0 + for i in range(len(sizes)): + d, r = div[i], rem[i] + isempty = sizes[i] < seg_min + last = current + d + (r >= seg_min) + isempty + # column id + column[current:last] = col_idx + # lock id + if d > 1 or (d == 1 and r >= seg_min): + nlocks += 1 + lockid[current:last] = nlocks + maxid[current:last] = last - current + # segment size + segments[current:current + d] = seg_max + if r < seg_min and not isempty: + segments[current + d - 1] += r + if r >= seg_min or isempty: + segments[current + d] = r + current = last + col_idx += 1 + offsets = torch.zeros_like(segments) + offsets[1:] = torch.cumsum(segments[:-1], dim=0) + return segments, column, lockid, maxid, offsets - @staticmethod - def make_sdd_lut(layout, block, dtype, device): - start_width = 128 // block - superblocks = libtriton.superblock(layout.type(torch.int32), start_width) - luts, widths, packs = [], [], [] - for size, nnz in superblocks: - width = nnz.shape[0] // (size*size) - h = nnz[:, 0] - i = nnz[:, 1] - j = nnz[:, 2] - b = nnz[:, 3] - lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous() - luts.append(lut.type(torch.int32).to(device)) - widths.append(width) - packs.append(size) - # create locks - return luts, None, widths, packs + @staticmethod + def get_locks(size, dev): + if dev not in _matmul.locks or \ + size > _matmul.locks[dev].size(0): + _matmul.locks[dev] = torch.zeros(size, dtype=torch.int32, device=dev) + return _matmul.locks[dev] - @staticmethod - def _sdd_matmul(a, b, trans_a, trans_b, trans_c, - spdims, block, luts, num_locks, widths, packs): - - if trans_c: - a, b = b, a - trans_a, trans_b = not trans_b, not trans_a - AS0 = a.size(0) - AS1 = a.size(1) - AS2 = a.size(3 if trans_a else 2) - AS3 = a.size(2 if trans_a else 3) - BS0 = b.size(0) - BS1 = b.size(1) - BS2 = b.size(3 if trans_b else 2) - BS3 = b.size(2 if trans_b else 3) - dtype = a.dtype - device = a.device - is_16_multiple = AS3 % 16 == 0 - is_32_multiple = AS3 % 32 == 0 - is_64_multiple = AS3 % 64 == 0 - if not is_16_multiple: - raise ValueError('Reduction size for SDD must be a multiple of 16') - # create kernel - total_width = sum([width*pack*pack for width,pack in zip(widths, packs)]) - c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device) - for lut, width, pack in zip(luts, widths, packs): - num_lock = 1 - key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple) - if key not in _matmul.sdd_cache: - defines = {'TM': block*pack, 'TN': block*pack, - 'TMN': block*block*pack*pack, - 'BLOCK': block, - 'TK': 32, - 'TYPE': dtype, - 'STRIDE_AM': '1' if trans_a else 'lda', - 'STRIDE_AK': 'lda' if trans_a else '1', - 'STRIDE_BN': 'ldb' if trans_b else '1', - 'STRIDE_BK': '1' if trans_b else 'ldb', - 'STRIDE_CM': 'ldc', 'STRIDE_CN': '1', - 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel'} - _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines) + ########################## + # SPARSE = DENSE x DENSE # + ########################## - kernel = _matmul.sdd_cache[key] - # create output - locks = _matmul.get_locks(2*width*AS0*num_lock, a.device) - # maximum grid size is 65535 - # so operation might be decomposed into multiple - # kernel calls - max_width = 49152 - for off_width in range(0, width, max_width): - kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), - a.stride(2), b.stride(2), block, - a.stride(0), b.stride(0), c.stride(0), - a.stride(1), b.stride(1), c.stride(0), - AS2, AS2, AS3, off_width, lut.data_ptr(), locks.data_ptr(), num_lock, - grid = lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]) - # save for backward pass - return c + @staticmethod + def make_sdd_lut(layout, block, dtype, device): + start_width = 128 // block + superblocks = libtriton.superblock(layout.type(torch.int32), start_width) + luts, widths, packs = [], [], [] + for size, nnz in superblocks: + width = nnz.shape[0] // (size * size) + h = nnz[:, 0] + i = nnz[:, 1] + j = nnz[:, 2] + b = nnz[:, 3] + lut = torch.stack((h, i, j, b), dim=1).view(-1).contiguous() + luts.append(lut.type(torch.int32).to(device)) + widths.append(width) + packs.append(size) + # create locks + return luts, None, widths, packs - ########################## - # DENSE = DENSE x SPARSE # - # DENSE = SPARSE x DENSE # - ########################## - - # Given a binary layout of 0s and 1s, - # Construct look-up table for efficient execution on GPUs - @staticmethod - def make_dxx_lut(layout, block, step, trans, device, transform = lambda idx: idx): - # load-balancing - _empty = torch.tensor([], dtype=torch.int64, device=layout.device) - segments = _empty.clone() - column = _empty.clone() - depth = _empty.clone() - lockid = _empty.clone() - maxid = _empty.clone() - offsets = _empty.clone() - current_offset = 0 - current_maxid = 0 - for z in range(layout.size(0)): - if trans: - sizes = torch.sum(layout[z, :, :], 1) - else: - sizes = torch.sum(layout[z, :, :], 0) - z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block) - z_depth = z * torch.ones_like(z_segments) - z_lockid[z_lockid > 0] += current_maxid - current_maxid = z_lockid.max() - # concatenate depth - segments = torch.cat((segments, z_segments)) - column = torch.cat((column, z_column)) - depth = torch.cat((depth, z_depth)) - maxid = torch.cat((maxid, z_maxid)) - offsets = torch.cat((offsets, current_offset + z_offsets)) - lockid = torch.cat((lockid, z_lockid)) - current_offset += layout[z, :, :].sum() - segments *= step - # pointer increments - if trans: - nnz = layout.nonzero(as_tuple=False) - else: - nnz = layout.transpose(1, 2).nonzero(as_tuple=False) - num_blocks = nnz.size(0) - offsets = torch.min(offsets, (num_blocks - 1)*torch.ones_like(offsets)) - idx = transform(nnz[:, 2]*block) - xincs = idx.clone() - xincs[1:] -= idx[:-1] - # divide block into multiple steps - div = block // step - xincs = xincs.view(-1, 1).repeat(1, div) - xincs[:, 1:] = step - xincs[:, 0 ] -= (div-1)*step - # first increment for each reduction is actually the offset - xincs[offsets[segments>0], 0] = idx[offsets[segments>0]] - xincs = xincs.view(-1) - # block-mode input increments - if trans: - widx = torch.arange(num_blocks) - else: - widx = _empty.clone() - current_offset = 0 - for z in range(layout.size(0)): - layoutw = layout[z, :, :].clone() - msum = layoutw.sum() - layoutw[layoutw > 0] = 1 + torch.arange(msum) - widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1)) - current_offset += msum - widx = widx - wincs = widx*block*block - wincs[1:] -= widx[:-1]*block*block - wincs = wincs.view(-1, 1).repeat(1, div) - if trans: - wincs[:, 1:] = step - wincs[:, 0] -= (div-1)*step - else: - wincs[:, 1:] = step*block - wincs[:, 0] -= (div - 1)*step*block - wincs[offsets[segments>0], 0] = widx[offsets[segments>0]] - wincs = wincs.view(-1) - # adjust offset and segment size - offsets *= 2*div - segments *= div - # create header - width = column.size(0) - offsets += 6*width - header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous() - incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous() - incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) - # create lut - lut = torch.cat((header, incs)) - lut = lut.type(torch.int32).to(device) - # create locks - num_locks = max(1, lockid.max()) - return lut, num_locks, width, None + @staticmethod + def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs): - @staticmethod - def _dds_matmul(a, b, trans_a, trans_b, trans_c, - spdims, block, lut, num_locks, width, packs): - # shapes / dtypes - AS0 = a.size(0) - AS1 = a.size(1) - AS2 = a.size(3 if trans_a else 2) - AS3 = a.size(2 if trans_a else 3) - BS0 = spdims[0] - BS1 = block * spdims[2 if trans_b else 1] - BS2 = block * spdims[1 if trans_b else 2] - dtype = a.dtype - # kernel - key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) - if key not in _matmul.dds_cache: - defines = {'TM': 128, - 'TN': block, - 'TK': 16, - 'BLOCK': block, - 'TYPE': dtype, - 'STRIDE_AM': 1 if trans_a else 'lda', - 'STRIDE_AK': 'lda' if trans_a else 1, - 'STRIDE_BN': block if trans_b else 1, - 'STRIDE_BK': 1 if trans_b else block, - 'STRIDE_CM': '1' if trans_c else 'ldc', - 'STRIDE_CN': 'ldc' if trans_c else '1', - 'NAME': 'dds_kernel', - 'DDS': True} - _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines) - kernel = _matmul.dds_cache[key] - # output - CS0 = AS0 - CS1 = AS1 - CS2 = BS2 if trans_c else AS2 - CS3 = AS2 if trans_c else BS2 - locks = _matmul.get_locks(2*AS0*AS2//32*num_locks, a.device) - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), - a.stride(2), block, c.stride(2), - a.stride(0), b.stride(0), c.stride(0), - a.stride(1), b.stride(1), c.stride(1), - AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks, - grid = lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]) - return c - - @staticmethod - def _dsd_matmul(a, b, trans_a, trans_b, trans_c, - spdims, block, lut, num_locks, width, packs): - # shapes / dtypes - AS0 = spdims[0] - AS1 = block * spdims[2 if trans_a else 1] - AS2 = block * spdims[1 if trans_a else 2] - BS0 = b.size(0) - BS1 = b.size(1) - BS2 = b.size(3 if trans_b else 2) - BS3 = b.size(2 if trans_b else 3) - dtype = a.dtype - # kernel - key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) - if key not in _matmul.dsd_cache: - defines = {'TM': block, - 'TN': 128, - 'TK': 16, - 'BLOCK': block, - 'TYPE': dtype, - 'STRIDE_AM': 1 if trans_a else block, - 'STRIDE_AK': block if trans_a else 1, - 'STRIDE_BN': 'ldb' if trans_b else '1', - 'STRIDE_BK': '1' if trans_b else 'ldb', - 'STRIDE_CM': '1' if trans_c else 'ldc', - 'STRIDE_CN': 'ldc' if trans_c else '1', - 'NAME': 'dsd_kernel', - 'DSD': True} - _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines) - kernel = _matmul.dsd_cache[key] - # output - CS0 = BS0 - CS1 = BS1 - CS2 = BS3 if trans_c else AS1 - CS3 = AS1 if trans_c else BS3 - locks = _matmul.get_locks(2*BS0*BS3//32*num_locks, a.device) - c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) - kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), - block, b.stride(2), c.stride(2), - a.stride(0), b.stride(0), c.stride(0), - a.stride(1), b.stride(1), c.stride(1), - BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), num_locks, - grid = lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]) - return c + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + AS0 = a.size(0) + AS1 = a.size(1) + AS2 = a.size(3 if trans_a else 2) + AS3 = a.size(2 if trans_a else 3) + BS0 = b.size(0) + BS1 = b.size(1) + BS2 = b.size(3 if trans_b else 2) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + device = a.device + is_16_multiple = AS3 % 16 == 0 + is_32_multiple = AS3 % 32 == 0 + is_64_multiple = AS3 % 64 == 0 + if not is_16_multiple: + raise ValueError('Reduction size for SDD must be a multiple of 16') + # create kernel + total_width = sum([width * pack * pack for width, pack in zip(widths, packs)]) + c = torch.empty((AS0, total_width, block, block), dtype=dtype, device=device) + for lut, width, pack in zip(luts, widths, packs): + num_lock = 1 + key = (block, device, a.dtype, b.dtype, trans_a, trans_b, trans_c, pack, is_32_multiple, is_64_multiple) + if key not in _matmul.sdd_cache: + defines = { + 'TM': block * pack, 'TN': block * pack, 'TMN': block * block * pack * pack, 'BLOCK': block, 'TK': + 32, 'TYPE': dtype, 'STRIDE_AM': '1' if trans_a else 'lda', 'STRIDE_AK': 'lda' if trans_a else '1', + 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': '1' if trans_b else 'ldb', 'STRIDE_CM': 'ldc', + 'STRIDE_CN': '1', 'SDD': True, 'TZ': 1, 'NAME': 'sdd_kernel' + } + _matmul.sdd_cache[key] = triton.kernel(src, device=device, defines=defines) - fn = {'sdd': _sdd_matmul.__get__(object), - 'dsd': _dsd_matmul.__get__(object), - 'dds': _dds_matmul.__get__(object)} + kernel = _matmul.sdd_cache[key] + # create output + locks = _matmul.get_locks(2 * width * AS0 * num_lock, a.device) + # maximum grid size is 65535 + # so operation might be decomposed into multiple + # kernel calls + max_width = 49152 + for off_width in range(0, width, max_width): + kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), b.stride(2), block, a.stride(0), + b.stride(0), c.stride(0), a.stride(1), b.stride(1), c.stride(0), AS2, AS2, AS3, off_width, + lut.data_ptr(), locks.data_ptr(), num_lock, + grid=lambda opt: [opt.TZ, min(max_width, width - off_width), AS0]) + # save for backward pass + return c - @staticmethod - def forward(ctx, a, b, trans_a, trans_b, trans_c, - mode, spdims, block, - c_lut, c_num_locks, c_width, c_packs, - da_lut, da_num_locks, da_width, da_packs, - db_lut, db_num_locks, db_width, db_packs): - c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, - c_lut, c_num_locks, c_width, c_packs) - # save for backward - ctx.save_for_backward(a, b) - ctx.da_num_locks = da_num_locks - ctx.da_lut = da_lut - ctx.da_width = da_width - ctx.da_packs = da_packs - ctx.db_lut = db_lut - ctx.db_num_locks = db_num_locks - ctx.db_width = db_width - ctx.db_packs = db_packs - ctx.mode = mode - ctx.spdims = spdims - ctx.block = block - ctx.trans_a = trans_a - ctx.trans_b = trans_b - return c + ########################## + # DENSE = DENSE x SPARSE # + # DENSE = SPARSE x DENSE # + ########################## - @staticmethod - def backward(ctx, dc): - # saved for backward - a, b = ctx.saved_tensors - mode = ctx.mode - # gradients w.r.t. a - if ctx.needs_input_grad[0]: - mode_da = mode[1] + mode[0] + mode[2] - da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, - ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs) - # gradients w.r.t. b - if ctx.needs_input_grad[1]: - mode_db = mode[2] + mode[1] + mode[0] - db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, - ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs) - return da, db, None, None, None,\ - None, None, None, None,\ - None, None, None, None, None, None,\ - None, None, None, None, None, None,\ - None, None, None, None, None, None + # Given a binary layout of 0s and 1s, + # Construct look-up table for efficient execution on GPUs + @staticmethod + def make_dxx_lut(layout, block, step, trans, device, transform=lambda idx: idx): + # load-balancing + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + segments = _empty.clone() + column = _empty.clone() + depth = _empty.clone() + lockid = _empty.clone() + maxid = _empty.clone() + offsets = _empty.clone() + current_offset = 0 + current_maxid = 0 + for z in range(layout.size(0)): + if trans: + sizes = torch.sum(layout[z, :, :], 1) + else: + sizes = torch.sum(layout[z, :, :], 0) + z_segments, z_column, z_lockid, z_maxid, z_offsets = _matmul.load_balance(sizes, block) + z_depth = z * torch.ones_like(z_segments) + z_lockid[z_lockid > 0] += current_maxid + current_maxid = z_lockid.max() + # concatenate depth + segments = torch.cat((segments, z_segments)) + column = torch.cat((column, z_column)) + depth = torch.cat((depth, z_depth)) + maxid = torch.cat((maxid, z_maxid)) + offsets = torch.cat((offsets, current_offset + z_offsets)) + lockid = torch.cat((lockid, z_lockid)) + current_offset += layout[z, :, :].sum() + segments *= step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) + else: + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) + idx = transform(nnz[:, 2] * block) + xincs = idx.clone() + xincs[1:] -= idx[:-1] + # divide block into multiple steps + div = block // step + xincs = xincs.view(-1, 1).repeat(1, div) + xincs[:, 1:] = step + xincs[:, 0] -= (div - 1) * step + # first increment for each reduction is actually the offset + xincs[offsets[segments > 0], 0] = idx[offsets[segments > 0]] + xincs = xincs.view(-1) + # block-mode input increments + if trans: + widx = torch.arange(num_blocks) + else: + widx = _empty.clone() + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum) + widx = torch.cat((widx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + widx = widx + wincs = widx * block * block + wincs[1:] -= widx[:-1] * block * block + wincs = wincs.view(-1, 1).repeat(1, div) + if trans: + wincs[:, 1:] = step + wincs[:, 0] -= (div - 1) * step + else: + wincs[:, 1:] = step * block + wincs[:, 0] -= (div - 1) * step * block + wincs[offsets[segments > 0], 0] = widx[offsets[segments > 0]] + wincs = wincs.view(-1) + # adjust offset and segment size + offsets *= 2 * div + segments *= div + # create header + width = column.size(0) + offsets += 6 * width + header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous() + incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous() + incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype))) + # create lut + lut = torch.cat((header, incs)) + lut = lut.type(torch.int32).to(device) + # create locks + num_locks = max(1, lockid.max()) + return lut, num_locks, width, None + + @staticmethod + def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): + # shapes / dtypes + AS0 = a.size(0) + AS1 = a.size(1) + AS2 = a.size(3 if trans_a else 2) + AS3 = a.size(2 if trans_a else 3) + BS0 = spdims[0] + BS1 = block * spdims[2 if trans_b else 1] + BS2 = block * spdims[1 if trans_b else 2] + dtype = a.dtype + # kernel + key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) + if key not in _matmul.dds_cache: + defines = { + 'TM': 128, 'TN': block, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else 'lda', + 'STRIDE_AK': 'lda' if trans_a else 1, 'STRIDE_BN': block if trans_b else 1, 'STRIDE_BK': + 1 if trans_b else block, 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': 'ldc' if trans_c else '1', + 'NAME': 'dds_kernel', 'DDS': True + } + _matmul.dds_cache[key] = triton.kernel(src, device=a.device, defines=defines) + kernel = _matmul.dds_cache[key] + # output + CS0 = AS0 + CS1 = AS1 + CS2 = BS2 if trans_c else AS2 + CS3 = AS2 if trans_c else BS2 + locks = _matmul.get_locks(2 * AS0 * AS2 // 32 * num_locks, a.device) + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), a.stride(2), block, c.stride(2), a.stride(0), b.stride(0), + c.stride(0), a.stride(1), b.stride(1), c.stride(1), AS2, BS2, 0, 0, lut.data_ptr(), locks.data_ptr(), + num_locks, grid=lambda opt: [width, triton.cdiv(AS2, opt.TM), AS0]) + return c + + @staticmethod + def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs): + # shapes / dtypes + AS0 = spdims[0] + AS1 = block * spdims[2 if trans_a else 1] + AS2 = block * spdims[1 if trans_a else 2] + BS0 = b.size(0) + BS1 = b.size(1) + BS2 = b.size(3 if trans_b else 2) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # kernel + key = (block, a.device, a.dtype, b.dtype, trans_a, trans_b, trans_c) + if key not in _matmul.dsd_cache: + defines = { + 'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TYPE': dtype, 'STRIDE_AM': 1 if trans_a else block, + 'STRIDE_AK': block if trans_a else 1, 'STRIDE_BN': 'ldb' if trans_b else '1', 'STRIDE_BK': + '1' if trans_b else 'ldb', 'STRIDE_CM': '1' if trans_c else 'ldc', 'STRIDE_CN': + 'ldc' if trans_c else '1', 'NAME': 'dsd_kernel', 'DSD': True + } + _matmul.dsd_cache[key] = triton.kernel(src, device=a.device, defines=defines) + kernel = _matmul.dsd_cache[key] + # output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + locks = _matmul.get_locks(2 * BS0 * BS3 // 32 * num_locks, a.device) + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), block, b.stride(2), c.stride(2), a.stride(0), b.stride(0), + c.stride(0), a.stride(1), b.stride(1), c.stride(1), BS3, AS1, 0, 0, lut.data_ptr(), locks.data_ptr(), + num_locks, grid=lambda opt: [width, triton.cdiv(BS3, opt.TN), BS0]) + return c + + fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)} + + @staticmethod + def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs, da_lut, + da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs): + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width, c_packs) + # save for backward + ctx.save_for_backward(a, b) + ctx.da_num_locks = da_num_locks + ctx.da_lut = da_lut + ctx.da_width = da_width + ctx.da_packs = da_packs + ctx.db_lut = db_lut + ctx.db_num_locks = db_num_locks + ctx.db_width = db_width + ctx.db_packs = db_packs + ctx.mode = mode + ctx.spdims = spdims + ctx.block = block + ctx.trans_a = trans_a + ctx.trans_b = trans_b + return c + + @staticmethod + def backward(ctx, dc): + # saved for backward + a, b = ctx.saved_tensors + mode = ctx.mode + # gradients w.r.t. a + if ctx.needs_input_grad[0]: + mode_da = mode[1] + mode[0] + mode[2] + da = _matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, + ctx.da_num_locks, ctx.da_width, ctx.da_packs) + # gradients w.r.t. b + if ctx.needs_input_grad[1]: + mode_db = mode[2] + mode[1] + mode[0] + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, + ctx.db_num_locks, ctx.db_width, ctx.db_packs) + return da, db, None, None, None,\ + None, None, None, None,\ + None, None, None, None, None, None,\ + None, None, None, None, None, None,\ + None, None, None, None, None, None class matmul: - - def make_lut(self, dtype, device): - key = (dtype, device) - if key in self.lut_cache: - return self.lut_cache[key] - # C look-up table - layout, block = self.layout, self.block - step = 8 if dtype == torch.float32 else 16 - if self.mode == 'sdd': - c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device) - elif self.mode == 'dsd': - c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device) - elif self.mode == 'dds': - c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device) - # DA look-up table - if self.mode == 'sdd': - da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device) - elif self.mode == 'dsd': - da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device) - elif self.mode == 'dds': - da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, device) - # DB look-up table - if self.mode == 'sdd': - db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device) - elif self.mode == 'dsd': - db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device) - elif self.mode == 'dds': - db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device) - self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\ - da_lut, da_num_locks, da_width, da_packs,\ - db_lut, db_num_locks, db_width, db_packs) - return self.lut_cache[key] + def make_lut(self, dtype, device): + key = (dtype, device) + if key in self.lut_cache: + return self.lut_cache[key] + # C look-up table + layout, block = self.layout, self.block + step = 8 if dtype == torch.float32 else 16 + if self.mode == 'sdd': + c_lut, c_num_locks, c_width, c_packs = _matmul.make_sdd_lut(layout, block, dtype, device) + elif self.mode == 'dsd': + c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_a, device) + elif self.mode == 'dds': + c_lut, c_num_locks, c_width, c_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_b, device) + # DA look-up table + if self.mode == 'sdd': + da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, True, device) + elif self.mode == 'dsd': + da_lut, da_num_locks, da_width, da_packs = _matmul.make_sdd_lut(layout, block, dtype, device) + elif self.mode == 'dds': + da_lut, da_num_locks, da_width, da_packs = _matmul.make_dxx_lut(layout, block, step, not self.trans_b, + device) + # DB look-up table + if self.mode == 'sdd': + db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, False, device) + elif self.mode == 'dsd': + db_lut, db_num_locks, db_width, db_packs = _matmul.make_dxx_lut(layout, block, step, self.trans_a, device) + elif self.mode == 'dds': + db_lut, db_num_locks, db_width, db_packs = _matmul.make_sdd_lut(layout, block, dtype, device) + self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\ + da_lut, da_num_locks, da_width, da_packs,\ + db_lut, db_num_locks, db_width, db_packs) + return self.lut_cache[key] - def __init__(self, layout, block, mode, trans_a = False, trans_b = False): - if mode not in ['sdd', 'dsd', 'dds']: - raise NotImplementedError('Supported modes are: sdd, dsd, dds') - # look-up table cache - self.lut_cache = dict() - # attributes - self.trans_a = trans_a - self.trans_b = trans_b - self.mode = mode - self.spdims = layout.shape - self.block = block - self.layout = layout - - # pad shapes of a tensor to make it - # compatible with kernel calls - @staticmethod - def _pad_shape(x, is_sparse): - max_dim = 3 if is_sparse else 4 - for i in range(max_dim - x.dim()): - x = x.unsqueeze(0) - return x + def __init__(self, layout, block, mode, trans_a=False, trans_b=False): + if mode not in ['sdd', 'dsd', 'dds']: + raise NotImplementedError('Supported modes are: sdd, dsd, dds') + # look-up table cache + self.lut_cache = dict() + # attributes + self.trans_a = trans_a + self.trans_b = trans_b + self.mode = mode + self.spdims = layout.shape + self.block = block + self.layout = layout - def __call__(self, a, b): - c_lut, c_num_locks, c_width, c_packs,\ - da_lut, da_num_locks, da_width, da_packs,\ - db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device) - # pad shapes with ones - a = matmul._pad_shape(a, self.mode == 'dsd') - b = matmul._pad_shape(b, self.mode == 'dds') - # execute - c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, - self.mode, self.spdims, self.block, - c_lut, c_num_locks, c_width, c_packs, - da_lut, da_num_locks, da_width, da_packs, - db_lut, db_num_locks, db_width, db_packs) - return c \ No newline at end of file + # pad shapes of a tensor to make it + # compatible with kernel calls + @staticmethod + def _pad_shape(x, is_sparse): + max_dim = 3 if is_sparse else 4 + for i in range(max_dim - x.dim()): + x = x.unsqueeze(0) + return x + + def __call__(self, a, b): + c_lut, c_num_locks, c_width, c_packs,\ + da_lut, da_num_locks, da_width, da_packs,\ + db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device) + # pad shapes with ones + a = matmul._pad_shape(a, self.mode == 'dsd') + b = matmul._pad_shape(b, self.mode == 'dds') + # execute + c = _matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, + c_num_locks, c_width, c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, + db_width, db_packs) + return c diff --git a/python/triton/ops/matmul.c b/python/triton/ops/matmul.c index 2875e54e2..3410649c7 100644 --- a/python/triton/ops/matmul.c +++ b/python/triton/ops/matmul.c @@ -1,9 +1,9 @@ #define STM 8 #define STN 8 -__global__ void matmul(TYPE * A __noalias __readonly __aligned(16), - TYPE * B __noalias __readonly __aligned(16), - TYPE * C __noalias __aligned(16), +__global__ void matmul(TYPE *A __noalias __readonly __aligned(16), + TYPE *B __noalias __readonly __aligned(16), + TYPE *C __noalias __aligned(16), float alpha, int M, int N, @@ -11,87 +11,88 @@ __global__ void matmul(TYPE * A __noalias __readonly __aligned(16), int lda __multipleof(LDA_POW2_DIV), int ldb __multipleof(LDB_POW2_DIV), int ldc __multipleof(LDC_POW2_DIV), - int* locks) { - // prologue - int pid = get_program_id(0); - int pidz = get_program_id(2); - int gridm = (M + TM - 1) / TM; - int gridn = (N + TN - 1) / TN; + int *locks) { + // prologue + int pid = get_program_id(0); + int pidz = get_program_id(2); + int gridm = (M + TM - 1) / TM; + int gridn = (N + TN - 1) / TN; - // swizzle for better L2 performance - int width = STM*gridn; - int stm = pid / width; - int RSTM = min(gridm - stm*STM, STM); - int stn = (pid % width) / (RSTM*STN); - int RSTN = min(gridn - stn*STN, STN); - int laneid = pid % (RSTM * RSTN); - int lanem = laneid / RSTN; - int lanen = laneid % RSTN; - int pidm = stm*STM + lanem; - int pidn = stn*STN + lanen; - int rm[TM] = pidm * TM + 0 ... TM; - int rn[TN] = pidn * TN + 0 ... TN; + // swizzle for better L2 performance + int width = STM * gridn; + int stm = pid / width; + int RSTM = min(gridm - stm * STM, STM); + int stn = (pid % width) / (RSTM * STN); + int RSTN = min(gridn - stn * STN, STN); + int laneid = pid % (RSTM * RSTN); + int lanem = laneid / RSTN; + int lanen = laneid % RSTN; + int pidm = stm * STM + lanem; + int pidn = stn * STN + lanen; + int rm[TM] = pidm * TM + 0 ... TM; + int rn[TN] = pidn * TN + 0 ... TN; - // split-k for better parrallelism - K = K / TZ; - int rk[TK] = 0 ... TK; - // pointers to operands - int offa[TM, TK] = (pidz*K + rk[newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; - int offb[TK, TN] = (pidz*K + rk[:, newaxis]) * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; - TYPE* pa[TM, TK] = A + offa; - TYPE* pb[TK, TN] = B + offb; + // split-k for better parrallelism + K = K / SPLITK; + int rk[TK] = 0 ... TK; + // pointers to operands + int offa[TM, TK] = (pidz * K + rk [newaxis, :]) * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; + int offb[TK, TN] = (pidz * K + rk[:, newaxis]) * STRIDE_BK + rn [newaxis, :] * STRIDE_BN; + TYPE *pa[TM, TK] = A + offa; + TYPE *pb[TK, TN] = B + offb; - // prefetches operands - bool checka[TM, TK] = rk[newaxis, :] < K; - bool checkb[TK, TN] = rk[:, newaxis] < K; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; + // prefetches operands + bool checka[TM, TK] = rk [newaxis, :] < K; + bool checkb[TK, TN] = rk[:, newaxis] < K; + TYPE a[TM, TK] = checka ? *pa : 0; + TYPE b[TK, TN] = checkb ? *pb : 0; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; - // reduction loop - float acc[TM, TN] = 0; - for(int k = K; k > 0; k -= TK){ -#if (IS_TK_DIV_K==1) - bool checkk[TK] = k > TK; + // reduction loop + float acc[TM, TN] = 0; + for (int k = K; k > 0; k -= TK) { +#if (IS_TK_DIV_K == 1) + bool checkk[TK] = k > TK; #else - bool checkk[TK] = rk < k - TK; + bool checkk[TK] = rk < k - TK; #endif - bool checka[TM, TK] = checkk[newaxis, :]; - bool checkb[TK, TN] = checkk[:, newaxis]; - acc += a @ b; -#if (IS_TK_DIV_K==1) - a = *?(checka)pa; - b = *?(checkb)pb; + bool checka[TM, TK] = checkk [newaxis, :]; + bool checkb[TK, TN] = checkk[:, newaxis]; + acc += a @b; +#if (IS_TK_DIV_K == 1) + a = *? (checka)pa; + b = *? (checkb)pb; #else - a = checka ? *pa : 0; - b = checkb ? *pb : 0; + a = checka ? *pa : 0; + b = checkb ? *pb : 0; #endif - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; + } + acc = acc * alpha; + TYPE c[TM, TN] = acc; - // epilogue - int rcm[TM] = pidm * TM + 0 ... TM; - int rcn[TN] = pidn * TN + 0 ... TN; - int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :]; - TYPE* pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn[newaxis, :] < N; -#if (TZ==1) - *?(checkc) pc = c; + // epilogue + int rcm[TM] = pidm * TM + 0 ... TM; + int rcn[TN] = pidn * TN + 0 ... TN; + int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn [newaxis, :]; + TYPE *pc[TM, TN] = C + offc; + bool checkc[TM, TN] = rcm[:, newaxis] < M && rcn [newaxis, :] < N; +#if (SPLITK == 1) + *? (checkc)pc = c; #else - // accumulate partial result using spin-locks - int *plock = locks + pid; - int *pcount = plock + get_num_programs(0); - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % TZ); - atomic_xchg(plock, 0); + // accumulate partial result using spin-locks + int *plock = locks + pid; + int *pcount = plock + get_num_programs(0); + for (int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)) + ; + int count = *pcount; + if (count == 0) + *? (checkc)pc = c; + else + *? (checkc)pc = c + *? (checkc)pc; + atomic_xchg(pcount, (count + 1) % SPLITK); + atomic_xchg(plock, 0); #endif } \ No newline at end of file diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index ce8b908b4..6033ca780 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -3,29 +3,32 @@ import triton import os class _matmul(torch.autograd.Function): - src = triton.read(os.path.join(os.path.dirname(__file__), 'matmul.c')) + src = triton.read(os.path.join(os.path.dirname(__file__), "matmul.c")) _DEFAULT_CONFIGS = [ - ({'TM': '128', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4), - ({'TM': '64', 'TN': '128', 'TK': '32', 'TZ': '1'}, 4), - ({'TM': '128', 'TN': '64', 'TK': '32', 'TZ': '1'}, 4), - ({'TM': '64', 'TN': '64', 'TK': '64', 'TZ': '1'}, 4), - ({'TM': '32', 'TN': '128', 'TK': '64', 'TZ': '1'}, 4), - ({'TM': '128', 'TN': '32', 'TK': '64', 'TZ': '1'}, 4), - ({'TM': '64', 'TN': '32', 'TK': '64', 'TZ': '1'}, 2), - ({'TM': '32', 'TN': '64', 'TK': '64', 'TZ': '1'}, 2), - ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4), - ({'TM': '32', 'TN': '128', 'TK': '32', 'TZ': '2'}, 4), - ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4), - ({'TM': '128', 'TN': '32', 'TK': '32', 'TZ': '4'}, 4), + ({"TM": "128", "TN": "128", "TK": "32", "SPLITK": "1"}, 4), + ({'TM': '64', 'TN': '128', 'TK': '32', 'SPLITK': '1'}, 4), + ({'TM': '128', 'TN': '64', 'TK': '32', 'SPLITK': '1'}, 4), + ({'TM': '64', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 4), + ({'TM': '32', 'TN': '128', 'TK': '64', 'SPLITK': '1'}, 4), + ({'TM': '128', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 4), + ({'TM': '64', 'TN': '32', 'TK': '64', 'SPLITK': '1'}, 2), + ({'TM': '32', 'TN': '64', 'TK': '64', 'SPLITK': '1'}, 2), + # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4), + # ({'TM': '32', 'TN': '128', 'TK': '32', 'SPLITK': '2'}, 4), + # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4), + # ({'TM': '128', 'TN': '32', 'TK': '32', 'SPLITK': '4'}, 4), ] _CONFIGS = _DEFAULT_CONFIGS @staticmethod def largest_pow2_divisor(N): - if N % 8 == 0: return 8 - if N % 4 == 0: return 4 - if N % 2 == 0: return 2 + if N % 8 == 0: + return 8 + if N % 4 == 0: + return 4 + if N % 2 == 0: + return 2 return 1 _locks = dict() @@ -40,8 +43,10 @@ class _matmul(torch.autograd.Function): K, N = b.shape c = torch.empty((M, N), dtype=dtype, device=device) # handle non-contiguous inputs if necessary - if a.stride(0) > 1 and a.stride(1) > 1: a = a.contiguous() - if b.stride(0) > 1 and b.stride(1) > 1: b = b.contiguous() + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() # kernel hash is_a_row = a.stride(1) == 1 is_b_row = b.stride(1) == 1 @@ -52,28 +57,60 @@ class _matmul(torch.autograd.Function): ldb_pow2_div = _matmul.largest_pow2_divisor(ldb) ldc_pow2_div = _matmul.largest_pow2_divisor(ldc) is_tk_div_k = K % 64 == 0 - key = (device, dtype, is_a_row, is_b_row, lda_pow2_div, ldb_pow2_div, ldc_pow2_div, is_tk_div_k) + key = ( + device, + dtype, + is_a_row, + is_b_row, + lda_pow2_div, + ldb_pow2_div, + ldc_pow2_div, + is_tk_div_k, + ) if key not in _matmul._kernels: defines = { - 'TYPE': dtype, 'STRIDE_AM': 'lda' if is_a_row else '1', 'STRIDE_AK': '1' if is_a_row else 'lda', - 'STRIDE_BK': 'ldb' if is_b_row else '1', 'STRIDE_BN': '1' if is_b_row else 'ldb', 'LDA_POW2_DIV': - lda_pow2_div, 'LDB_POW2_DIV': ldb_pow2_div, 'LDC_POW2_DIV': ldc_pow2_div, 'IS_TK_DIV_K': - int(is_tk_div_k) + "TYPE": dtype, + "STRIDE_AM": "lda" if is_a_row else "1", + "STRIDE_AK": "1" if is_a_row else "lda", + "STRIDE_BK": "ldb" if is_b_row else "1", + "STRIDE_BN": "1" if is_b_row else "ldb", + "LDA_POW2_DIV": lda_pow2_div, + "LDB_POW2_DIV": ldb_pow2_div, + "LDC_POW2_DIV": ldc_pow2_div, + "IS_TK_DIV_K": int(is_tk_div_k), } - _matmul._kernels[key] = triton.kernel(_matmul.src, - device, - defines=defines, - autotune_vals=_matmul._CONFIGS, - autotune_key=['M', 'N', 'K']) + _matmul._kernels[key] = triton.kernel( + _matmul.src, + device, + defines=defines, + autotune_vals=_matmul._CONFIGS, + autotune_key=["M", "N", "K"], + ) kernel = _matmul._kernels[key] # # locks for split-k if device not in _matmul._locks: _matmul._locks[device] = torch.zeros(1024 * 1024, dtype=torch.int32, device=device) locks = _matmul._locks[device] # enqueue - alpha = 1. - args = [a.data_ptr(), b.data_ptr(), c.data_ptr(), alpha, M, N, K, lda, ldb, ldc, locks.data_ptr()] - grid = lambda opt: [triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), 1, opt.TZ] + alpha = 1.0 + args = [ + a.data_ptr(), + b.data_ptr(), + c.data_ptr(), + alpha, + M, + N, + K, + lda, + ldb, + ldc, + locks.data_ptr(), + ] + grid = lambda opt: [ + triton.cdiv(M, opt.TM) * triton.cdiv(N, opt.TN), + 1, + opt.SPLITK, + ] kernel(*args, grid=grid) return c diff --git a/python/triton/testing.py b/python/triton/testing.py index 27812df56..2415e8f43 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,21 +1,33 @@ import torch + def sparsify_tensor(x, mask, block): - ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device) + ret = torch.empty( + (x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device + ) for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): - ret[:, idx, :, :] = x[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] + ret[:, idx, :, :] = x[ + :, h, i * block : (i + 1) * block, j * block : (j + 1) * block + ] return ret + def mask_tensor(x, mask, block, value=0): ret = x.clone() for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)): - ret[:, h, i * block:(i + 1) * block, j * block:(j + 1) * block] = value + ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value return ret + def allclose(x, y): assert x.dtype == y.dtype - rtol, atol = {torch.float32: (1e-4, 1e-5), torch.float16: (1e-2, 1e-3)}[x.dtype] - return torch.allclose(x, y, atol=atol, rtol=rtol) + diff = abs(x - y) + x_max = torch.max(x) + y_max = torch.max(y) + tol = 1e-2 + err = torch.max(diff) / torch.max(x_max, y_max) + return err < tol + def do_bench(fn, flops=0, warmup=10, rep=50): start_event = torch.cuda.Event(enable_timing=True) @@ -32,8 +44,11 @@ def do_bench(fn, flops=0, warmup=10, rep=50): time_ms = start_event.elapsed_time(end_event) / rep return time_ms + class Benchmark: - def __init__(self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args): + def __init__( + self, x_names, x_vals, y_name, y_vals, y_lines, ylabel, loglog, plot_name, args + ): self.x_names = x_names self.x_vals = x_vals self.y_name = y_name @@ -44,6 +59,7 @@ class Benchmark: self.plot_name = plot_name self.args = args + class Mark: def __init__(self, fn, benchmarks): self.fn = fn @@ -53,26 +69,31 @@ class Mark: import matplotlib.pyplot as plt import pandas as pd import os + df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines) for x in bench.x_vals: x_args = {x_name: x for x_name in bench.x_names} - row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals] + row = [ + self.fn(**x_args, **{bench.y_name: y}, **bench.args) + for y in bench.y_vals + ] df.loc[len(df)] = [x] + row if with_plot and bench.plot_name: - xlabel = ' = '.join(bench.x_names) + xlabel = " = ".join(bench.x_names) plot = df.plot(x=bench.x_names[0], y=bench.y_lines) plot.set_xlabel(xlabel) plot.set_ylabel(bench.ylabel) plot.set_title(bench.plot_name) - plot.set_xscale('log' if bench.loglog else 'linear') - plot.set_yscale('log' if bench.loglog else 'linear') - plt.savefig(os.path.join(result_path, f'{bench.plot_name}.png')) - df.to_csv(os.path.join(result_path, f'{bench.plot_name}.csv')) + plot.set_xscale("log" if bench.loglog else "linear") + plot.set_yscale("log" if bench.loglog else "linear") + plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png")) + df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv")) def run(self, result_path, with_plot): for bench in self.benchmarks: self._run(bench, result_path, with_plot) + def perf_report(benchmarks): wrapper = lambda fn: Mark(fn, benchmarks) return wrapper diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc index 1188b222d..53817e9c8 100644 --- a/tutorials/01-matmul.cc +++ b/tutorials/01-matmul.cc @@ -66,18 +66,19 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), bool checkb[TK, TN] = rk[:, newaxis] < K; TYPE a[TM, TK] = checka ? *pa : 0; TYPE b[TK, TN] = checkb ? *pb : 0; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; // reduction loop float acc[TM, TN] = 0; for(int k = K; k > 0; k -= TK){ bool checka[TM, TK] = k > TK; bool checkb[TK, TN] = k > TK; - acc += a @ b; - a = *?(checka)pa; - b = *?(checkb)pb; pa += TK * STRIDE_AK; pb += TK * STRIDE_BK; + TYPE anext[TM, TK] = *?(checka)pa; + TYPE bnext[TK, TN] = *?(checkb)pb; + acc += a @ b; + a = anext; + b = bnext; +// __debug_barrier(); } acc = acc * alpha; TYPE c[TM, TN] = acc; @@ -166,7 +167,7 @@ float triton_dot(drv::context* context, drv::stream* stream, opt.defines["TYPE"] = ty; opt.defines["TM"] = "128"; opt.defines["TN"] = "128"; - opt.defines["TK"] = "32" ; + opt.defines["TK"] = "64" ; opt.defines["TZ"] = "1"; opt.num_warps = 4; // arguments