diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 99481f694..a69687875 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -258,7 +258,8 @@ public: const std::vector& shapes, const std::vector &values_, ir::type *ty, - analysis::align* align, target *tgt); + analysis::align* align, target *tgt, + bool is_tmp = false); void accept(layout_visitor* vst) { vst->visit_layout_shared(this); } // accessors size_t get_size() { return size_; } @@ -276,6 +277,7 @@ public: int get_mma_strided() { return mma_strided_; } bool allow_swizzle() const { return allow_swizzle_; } data_layout* get_arg_layout() { return arg_layout_; } + bool is_tmp() const { return is_tmp_; } private: size_t size_; @@ -290,6 +292,7 @@ private: int mma_strided_; bool allow_swizzle_ = true; target *tgt_; + bool is_tmp_; }; diff --git a/include/triton/codegen/transform/coalesce.h b/include/triton/codegen/transform/coalesce.h index 869ca9975..e16ffe5fe 100644 --- a/include/triton/codegen/transform/coalesce.h +++ b/include/triton/codegen/transform/coalesce.h @@ -32,11 +32,12 @@ private: ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map& seen); public: - coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts); + coalesce(analysis::align* align, triton::codegen::analysis::layouts *layouts, bool has_sm80); triton::ir::value *simplify(ir::instruction* i, triton::ir::builder &builder); void run(ir::module &mod); private: + bool has_sm80_; analysis::align* align_; analysis::layouts* layout_; }; diff --git a/include/triton/codegen/transform/cts.h b/include/triton/codegen/transform/cts.h index 70fbc474b..30b421b52 100644 --- a/include/triton/codegen/transform/cts.h +++ b/include/triton/codegen/transform/cts.h @@ -15,18 +15,26 @@ namespace ir { } namespace codegen{ + +namespace analysis{ +class layouts; +} + namespace transform{ class cts { private: - void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared); + bool is_shmem_op(ir::instruction* i, int op); + bool is_shmem_res(ir::value* i); +void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map& copies); public: - cts(bool use_async = false): use_async_(use_async) {} + cts(analysis::layouts* layouts, bool has_sm80 = false): layouts_(layouts), has_sm80_(has_sm80) {} void run(ir::module &mod); private: - bool use_async_; + bool has_sm80_; + analysis::layouts* layouts_; }; } diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 0cb622679..74028f822 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -142,9 +142,9 @@ public: value *create_or(value *lhs, value *rhs); // Input/Output value *create_load(value *arg, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); - value *create_store(value *ptr, value *val); + value *create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction); value *create_masked_load(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile); - value *create_masked_store(value *ptr, value *val, value *mask); + value *create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction); // Struct instructions value *create_insert_value(value* val, value *elt, size_t idx); value *create_extract_value(value* val, size_t idx); @@ -176,7 +176,7 @@ public: value 
*create_cos(value* arg); value *create_sin(value* arg); value *create_log(value* arg); - value *create_dot(value *A, value *B, value *C, bool allow_tf32); + value *create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32); value *create_trans(value *A, const std::vector &perm = {}); value *create_sqrt(value *A); value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 4e76e60a4..61ec2a6ae 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -112,7 +112,7 @@ public: static function *create(function_type *ty, linkage_types_t linkage, const std::string &name, module *mod); // blocks - const blocks_t &blocks() { return blocks_; } + blocks_t &blocks() { return blocks_; } const blocks_t &blocks() const { return blocks_; } void insert_block(basic_block* block, basic_block *next = nullptr); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 734ea2b42..402208a8b 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -435,13 +435,31 @@ private: //===----------------------------------------------------------------------===// class io_inst: public instruction { +public: + + enum EVICTION_POLICY : uint32_t { + NORMAL=0, + EVICT_FIRST, + EVICT_LAST, + }; + protected: - io_inst(type *ty, value_id_t id, unsigned num_ops, + io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); + std::string get_eviction_policy_repr() const { + if (eviction_ == EVICT_FIRST) return ".L1::evict_first"; + if (eviction_ == EVICT_LAST) return ".L2::evict_last"; + return ""; + } + public: // accessors value *get_pointer_operand() { return get_operand(0); } + EVICTION_POLICY get_eviction_policy() const { return eviction_; } + +protected: + EVICTION_POLICY eviction_; }; // load @@ -453,14 +471,8 @@ public: CG, }; - enum EVICTION_POLICY : uint32_t { - NORMAL=0, - EVICT_FIRST, - EVICT_LAST, - }; CACHE_MODIFIER get_cache_modifier() const { return cache_; } - EVICTION_POLICY get_eviction_policy() const { return eviction_; } bool get_is_volatile() const { return is_volatile_; } protected: @@ -472,12 +484,6 @@ protected: if (cache_ == CG) return ".cg"; return ""; } - std::string get_eviction_policy_repr() const { - if (eviction_ == EVICT_FIRST) return ".L1::evict_first"; - if (eviction_ == EVICT_LAST) return ".L2::evict_last"; - return ""; - } - EVICTION_POLICY eviction_; CACHE_MODIFIER cache_; std::string get_volatile_repr() { @@ -553,7 +559,7 @@ public: // store class store_inst: public io_inst { protected: - store_inst(value *ptr, value_id_t id, unsigned num_ops, + store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); public: @@ -564,11 +570,11 @@ public: class unmasked_store_inst: public store_inst{ private: std::string repr_impl() const { return "unmasked_store"; } - unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next); + unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next); public: // factory method - static unmasked_store_inst* create(value* ptr, value *v, + static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(unmasked_store_inst) @@ -578,14 +584,14 @@ 
public: class masked_store_inst: public store_inst{ private: std::string repr_impl() const { return "masked_store"; } - masked_store_inst(value *ptr, value *v, value *mask, + masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction, const std::string &name, instruction *next); public: // accessors value *get_mask_operand() { return get_operand(2); } // factory method - static masked_store_inst* create(value *ptr, value *v, value *mask, + static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(masked_store_inst) @@ -755,6 +761,8 @@ private: class atomic_inst: public io_inst { public: using io_inst::io_inst; + atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next): + io_inst(ty, id, num_ops, NORMAL, name, next) {} }; class atomic_rmw_inst: public atomic_inst { @@ -856,6 +864,8 @@ public: bool is_prefetched() const { return is_prefetched_; } void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } bool allow_tf32() const { return allow_tf32_; } + bool is_trans_a() const { return AT_ == Trans; } + bool is_trans_b() const { return BT_ == Trans; } public: static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr); @@ -872,6 +882,8 @@ private: DataType C_type_ = DataType::FP32; DataType A_type_ = DataType::FP16; DataType B_type_ = DataType::FP16; + TransT AT_; + TransT BT_; }; //class outer_inst: public builtin_inst { diff --git a/include/triton/ir/utils.h b/include/triton/ir/utils.h index 893edd122..1fad79181 100644 --- a/include/triton/ir/utils.h +++ b/include/triton/ir/utils.h @@ -22,6 +22,7 @@ public: }; void for_each_instruction(ir::module& mod, const std::function &fn); +void for_each_instruction_backward(module &mod, const std::function &do_work); void for_each_value(ir::module& mod, const std::function &fn); } diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index 3af40c2cc..f842c0f61 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -92,8 +92,10 @@ void allocation::run(ir::module &mod) { } // Save maximum size of induced memory space allocated_size_ = 0; - for(shared_layout* x: V) + for(shared_layout* x: V){ allocated_size_ = std::max(allocated_size_, starts[x] + x->get_size()); + // std::cout << "start: " << starts[x] << " | end: " << starts[x] + x->get_size() << std::endl; + } } } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index a19be19ef..69c36f752 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -212,11 +212,9 @@ mma_layout::mma_layout(size_t num_warps, order_ = {0, 1}; } else{ - // fpw_ = {1, 1, 1}; spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32 contig_per_thread_ = {1, 2}; order_ = {1, 0}; - // rep_ = {2, 2, 1}; } /* warps per tile */ @@ -233,24 +231,45 @@ mma_layout::mma_layout(size_t num_warps, }while(wpt_nm1 != wpt_); } else { bool changed = false; - do { - changed = false; - if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps) - break; - if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) { - if (wpt_[0] < shape_[0] / spw_[0]) { - wpt_[0] *= 2; - changed = true; + // try to have a warp own entire rows of the output + // this makes it easier to fuse multiple mmas by fusing + // registers + 
bool one_warp_per_row = false; + for(ir::value* v: values) + for(ir::user* u: v->get_users()){ + auto* dot = dynamic_cast(u); + auto* cts = dynamic_cast(u); + if((dot && dot->get_operand(2)!=v) || !layout_a->to_shared() || cts) + one_warp_per_row = shape[0] / spw_[0] >= num_warps; + } + // std::cout << one_warp_per_row << std::endl; + + if(one_warp_per_row){ + wpt_[1] = 1; + wpt_[0] = num_warps; + } + else{ + do { + changed = false; + if (wpt_[0] * wpt_[1] * wpt_[2] >= num_warps) + break; + if (shape_[0] / spw_[0] / wpt_[0] >= shape_[1] / (spw_[1]*2) / wpt_[1]) { + if (wpt_[0] < shape_[0] / spw_[0]) { + wpt_[0] *= 2; + changed = true; + } + } else { + if (wpt_[1] < shape_[1] / (spw_[1]*2)) { + wpt_[1] *= 2; + changed = true; + } } - } else { - if (wpt_[1] < shape_[1] / (spw_[1]*2)) { - wpt_[1] *= 2; - changed = true; - } - } - } while (changed); + } while(changed); + } } + // std::cout << wpt_[0] << " " << wpt_[1] << std::endl; + /* shape per block */ shape_per_cta_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; } @@ -430,8 +449,8 @@ shared_layout::shared_layout(data_layout *arg, const std::vector& shape, const std::vector &values, ir::type *ty, - analysis::align* align, target *tgt) - : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) { + analysis::align* align, target *tgt, bool is_tmp) + : data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt), is_tmp_(is_tmp){ size_ = 0; arg_layout_ = arg; @@ -619,7 +638,7 @@ void layouts::create_tmp_layout(size_t id, data_layout *arg, ir::instruction *i, bool is_index) { ir::type *ty = is_index ? ir::type::get_int32_ty(i->get_type()->get_context()) : i->get_type()->get_scalar_ty(); - layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_); + layouts_[id] = new shared_layout(arg, axes, shape, {i}, ty, align_, tgt_, true); if (is_index) { tmp_index_[i] = id; } else { diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index 7beae21a1..535df4eb9 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -14,43 +14,108 @@ namespace analysis{ void liveness::run(ir::module &mod) { intervals_.clear(); - // Assigns index to each instruction - std::map indices; - for(ir::function *fn: mod.get_function_list()){ - slot_index index = 0; - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *instr: block->get_inst_list()){ - index += 1; - indices.insert({instr, index}); + std::map> layouts_map; + for(auto &x: layouts_->get_all()){ + shared_layout* layout = x.second->to_shared(); + if(!layout || layout->is_tmp()) + continue; + for(ir::value* v:layout->get_values()){ + layouts_map[v].insert(layout); } } - // create live intervals + + + std::map> live_in; + while(true){ + bool changed = false; + ir::instruction* last_inst = nullptr; + ir::for_each_instruction_backward(mod, [&](ir::instruction* i){ + // gen + std::set gen; + for(ir::value* v: i->ops()) + for(shared_layout* layout: layouts_map[v]) + gen.insert(layout); + // kill + std::set kill; + for(shared_layout* layout: layouts_map[i]) + kill.insert(layout); + // temporaries are handled separately + if(layouts_->has_tmp(i)){ + gen.insert(layouts_->get(layouts_->tmp(i))->to_shared()); + kill.insert(layouts_->get(layouts_->tmp(i))->to_shared()); + } + if(layouts_->has_tmp_index(i)){ + gen.insert(layouts_->get(layouts_->tmp_index(i))->to_shared()); + kill.insert(layouts_->get(layouts_->tmp_index(i))->to_shared()); + } + // live-out + std::set live_out; + std::vector succs = {last_inst}; + if(i == 
i->get_parent()->get_inst_list().back()) + for(ir::basic_block* succ: i->get_parent()->get_successors()) + succs.push_back(succ->get_inst_list().front()); + for(ir::instruction* succ: succs) + for(shared_layout* layout: live_in[succ]) + if(!layout->is_tmp()) + live_out.insert(layout); + + // new sets + std::set live_out_minus_kill; + std::set_difference(live_out.begin(), live_out.end(), kill.begin(), kill.end(), + std::inserter(live_out_minus_kill, live_out_minus_kill.end())); + std::set new_live_in; + std::set_union(gen.begin(), gen.end(), live_out_minus_kill.begin(), live_out_minus_kill.end(), + std::inserter(new_live_in, new_live_in.end())); + + changed = changed || (new_live_in != live_in[i]); + live_in[i] = new_live_in; + last_inst = i; + }); + if(!changed) + break; + } + + // ir::for_each_instruction(mod, [&](ir::instruction* i){ + // i->print(std::cout); + // std::cout << " live_in: " << live_in[i].size() << std::endl; + // }); + + + + // Assigns index to each instruction + std::map indices; + slot_index index = 0; + ir::for_each_instruction(mod, [&](ir::instruction* instr){ + index += 1; + indices.insert({instr, index}); + }); + + + for(auto &x: layouts_->get_all()){ + shared_layout* layout = x.second->to_shared(); + if(layout) + intervals_[layout] = segment{INT32_MAX, 0}; + } + + for(auto& x: live_in) + for(shared_layout* layout: x.second) + intervals_[layout].start = std::min(intervals_[layout].start, indices[x.first]); + + for(auto& x: live_in) + for(shared_layout* layout: x.second){ + intervals_[layout].end = std::max(intervals_[layout].end, indices[x.first] + 1); + } + + for(auto &x: layouts_->get_all()) { shared_layout* layout = x.second->to_shared(); if(!layout) continue; - // users - std::set users; - for(ir::value *v: layout->get_values()){ - for(ir::user *u: v->get_users()) - users.insert(u); - } - // compute intervals - unsigned start = INT32_MAX; - for(ir::value *v: layout->get_values()) - if(indices.find(v) != indices.end()) - start = std::min(start, indices.at(v)); - unsigned end = 0; - for(ir::user *u: users) - if(indices.find(u) != indices.end()) - end = std::max(end, indices.at(u)); - if(end == 0) - end = start + 1; - intervals_[layout] = segment{start, end}; + // std::cout << intervals_[layout].start << " " << intervals_[layout].end << std::endl; } - + } diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc index 5737f80a0..08843bbf7 100644 --- a/lib/codegen/analysis/swizzle.cc +++ b/lib/codegen/analysis/swizzle.cc @@ -28,12 +28,15 @@ void swizzle::run(ir::module &) { } auto ord = layout->get_order(); scanline_layout* in_layout = dynamic_cast(layout->get_arg_layout()); - if(!in_layout) - continue; + int per_phase = 1; int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + if(in_layout) + per_phase = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + else + per_phase = 1; if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80){ int inner = mma_dot_a ? 0 : 1; - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + per_phase_[layout] = per_phase; max_phase_[layout] = (ord[inner] == 1 ? 
8 : 4) / per_phase_[layout]; if(mma_dot_a) vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0); @@ -46,7 +49,7 @@ void swizzle::run(ir::module &) { max_phase_[layout] = 1; vec_[layout] = 1; } else { - per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1); + per_phase_[layout] = per_phase; max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout]; vec_[layout] = layout->get_mma_vec(); } diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 4ba423d20..412e2f4c8 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -31,27 +31,28 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); // optimizations - bool cts_use_async = target->as_nvidia() && target->as_nvidia()->sm() >= 80; + bool has_sm80 = target->as_nvidia() && target->as_nvidia()->sm() >= 80; // create passes codegen::analysis::align align; codegen::transform::inliner inliner; codegen::analysis::axes axes; - codegen::transform::cts cts(cts_use_async); - codegen::transform::pipeline pipeline(cts_use_async, num_stages); + codegen::transform::pipeline pipeline(has_sm80, num_stages); codegen::transform::disassociate disassociate; codegen::analysis::layouts layouts(&axes, &align, num_warps, target); + codegen::transform::cts cts(&layouts, has_sm80); codegen::analysis::liveness liveness(&layouts); codegen::analysis::swizzle swizzle(&layouts, target); codegen::analysis::allocation allocation(&liveness); codegen::transform::dce dce; codegen::transform::peephole peephole(target, &layouts); - codegen::transform::coalesce coalesce(&align, &layouts); + codegen::transform::coalesce coalesce(&align, &layouts, has_sm80); codegen::transform::prefetch prefetch_s(target); codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); // run passes inliner.run(ir); dce.run(ir); + // ir.print(std::cout); peephole.run(ir); dce.run(ir); pipeline.run(ir); @@ -84,10 +85,15 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC axes.run(ir); layouts.run(ir); swizzle.run(ir); + // std::cout << "---" << std::endl; + // ir.print(std::cout); + // std::cout << "---" << std::endl; + // ir.print(std::cout); liveness.run(ir); allocation.run(ir); prefetch_s.run(ir); barriers.run(ir); + // exit(1); // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 8d95a2790..e69b0acee 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -744,11 +744,13 @@ void generator::visit_load_inst(ir::load_inst* x){ BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); Value *tid = tgt_->get_local_id(module, *builder_, 0); + Value *lane = urem(tid, i32(32)); ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); // compute vector width size_t vec = 1; + bool is_mma_first_row = false; if(op->get_type()->is_block_ty()){ auto ord = ords_.at(op); size_t aln = alignment_->get(op, ord[0]); @@ -757,11 +759,15 @@ void generator::visit_load_inst(ir::load_inst* x){ max_eq = std::max(max_eq, 1); aln = std::min(aln, max_eq); } - auto layout = layouts_->get(x)->to_scanline(); 
- if(layout){ - size_t nts = layout->nts(ord[0]); - vec = std::min(nts, aln); - } + analysis::distributed_layout* layout = dynamic_cast(layouts_->get(x)); + assert(layout); + + vec = std::min(layout->contig_per_thread(ord[0]), aln); + // TODO: generalize + is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && + (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1)); + if(is_mma_first_row) + vec = std::min(2, aln); } // code generation auto idxs = idxs_.at(x); @@ -795,8 +801,8 @@ void generator::visit_load_inst(ir::load_inst* x){ int tot_width = nbits*vec; int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); - bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; - has_evict_policy = false; // currently disable until supported in `store` + bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; + // has_evict_policy = false; // currently disable until supported in `store` // ----- // create inline asm string // ----- @@ -810,7 +816,7 @@ void generator::visit_load_inst(ir::load_inst* x){ if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; - if (has_evict_policy) asm_oss << ".L2::cache_hint"; + if (has_l2_evict_policy) asm_oss << ".L2::cache_hint"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size @@ -822,7 +828,7 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_oss << "}"; asm_oss << ", [ $" << n_words + 1; // load asm_oss << " + " << in_off << "]"; // constant offset - if (has_evict_policy) asm_oss << ", $" << n_words + 2; + if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; asm_oss << ";"; bool has_other = other && (other != UndefValue::get(other->getType())); std::vector others; @@ -844,7 +850,7 @@ void generator::visit_load_inst(ir::load_inst* x){ if(ConstantInt* cst = dyn_cast(v)) asm_oss << "0x" << std::hex << cst->getSExtValue(); else{ - asm_oss << "$" << n_words + has_evict_policy + 2 + ii; + asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii; others.push_back(v); } asm_oss.flags(flags); @@ -859,7 +865,7 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); - if (has_evict_policy) + if (has_l2_evict_policy) arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); // --- @@ -875,7 +881,7 @@ void generator::visit_load_inst(ir::load_inst* x){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); } - if (has_evict_policy) + if (has_l2_evict_policy) asm_cstrt += ",l"; // --- // finally call inline ASM @@ -884,7 +890,7 @@ void generator::visit_load_inst(ir::load_inst* x){ std::vector args = {pred, ptr}; for(Value *v: others) args.push_back(v); - if (has_evict_policy) + if (has_l2_evict_policy) args.push_back(policies_.at(x->get_eviction_policy())); @@ -935,6 +941,9 @@ void generator::visit_store_inst(ir::store_inst * x){ // operands ir::value *ptr_op = x->get_pointer_operand(); ir::value *val_op = x->get_value_operand(); + ir::value *msk_op = nullptr; + if(auto* msk_st = dynamic_cast(x)) + msk_op = msk_st->get_mask_operand(); // vector size size_t vec = 1; if(val_op->get_type()->is_block_ty()){ @@ -946,36 +955,107 @@ void generator::visit_store_inst(ir::store_inst * x){ max_eq = std::max(max_eq, 1); aln = std::min(aln, max_eq); } - vec = std::min(nts, aln); + analysis::distributed_layout* layout = dynamic_cast(layouts_->get(ptr_op)); + assert(layout); + // vec = std::min(nts, aln); + vec = std::min(layout->contig_per_thread(ord[0]), aln); + // TODO: generalize + bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && + (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1)); + if(is_mma_first_row) + vec = std::min(2, aln); } + bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store ty = f16_ty; + if(ty->isIntegerTy(1)) + ty = builder_->getInt8Ty(); for(size_t i = 0; i < idxs.size(); i += vec){ - auto idx = idxs[i]; - // pointer + indices_t idx = idxs[i]; + // pointers Value *ptr = vals_[ptr_op][idx]; - // vectorize - Type *v_ty = vec_ty(ty, vec); - ptr = bit_cast(ptr, v_ty->getPointerTo(1)); - // value - Value* val = UndefValue::get(v_ty); - for(size_t ii = 0; ii < vec; ii++) - val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii); - if(mx){ - Value *msk = vals_[mx->get_mask_operand()][idx]; - Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); - builder_->SetInsertPoint(no_op->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false); - dummy->removeFromParent(); - builder_->SetInsertPoint(term); - store(val, ptr); - builder_->SetInsertPoint(no_op); + size_t dtsize = std::max(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8); + GetElementPtrInst *in_gep = dyn_cast(ptr); + size_t in_off; + if(in_gep){ + ConstantInt* cst = dyn_cast(in_gep->idx_begin()); + in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; + ptr = cst ? in_gep->getPointerOperand() : in_gep; } - else - store(val, ptr); + else{ + in_off = 0; + } + // mask + Value *pred = msk_op ? 
vals_[msk_op][idx] : builder_->getTrue(); + size_t nbits = dtsize*8; + // pack sub-words (< 32/64bits) into words + // each load has width min(nbits*vec, 32/64) + // and there are (nbits * vec)/width of them + int max_word_width = std::max(32, nbits); + int tot_width = nbits*vec; + int width = std::min(tot_width, max_word_width); + int n_words = std::max(1, tot_width / width); + // ----- + // create inline asm string + // ----- + std::ostringstream asm_oss; + asm_oss << "@$0"; // predicate + asm_oss << " st.global"; + if (has_l2_evict_policy) asm_oss << ".L2::cache_hint"; + if(n_words > 1) + asm_oss << ".v" << n_words; // vector width + asm_oss << ".b" << width; // word size + asm_oss << " [ $1 + " << in_off << "]"; + asm_oss << " , {"; + for(int i = 0; i < n_words; i++){ // return values + if(i > 0) asm_oss << ","; + asm_oss << "$" << 2 + i; + } + asm_oss << "}"; + if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; + asm_oss << ";"; + // ---- + // create inline ASM signature + // --- + Type* val_arg_ty = IntegerType::get(*ctx_, width); + std::vector arg_tys = {pred->getType(), ptr->getType()}; + for(int ii = 0; ii < n_words; ii++) + arg_tys.push_back(val_arg_ty); + if (has_l2_evict_policy) + arg_tys.push_back(i64_ty); + FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false); + // --- + // create inline ASM constraints + // --- + std::string asm_cstrt = "b,l"; + for(int ii = 0; ii < n_words; ii++){ + asm_cstrt += ","; + asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + } + if (has_l2_evict_policy) + asm_cstrt += ",l"; + // --- + // finally call inline ASM + // --- + InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); + std::vector args = {pred, ptr}; + for(unsigned int ii = 0; ii < n_words; ii++){ + size_t n_subw = width / nbits; + Value* curr = UndefValue::get(vec_ty(ty, n_subw)); + for(unsigned int jj = 0; jj < n_subw; jj++){ + Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]]; + if(new_elt->getType()->isIntegerTy(1)) + new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty()); + new_elt = bit_cast(new_elt, ty); + curr = builder_->CreateInsertElement(curr, new_elt, jj); + } + args.push_back(bit_cast(curr, val_arg_ty)); + } + if (has_l2_evict_policy) + args.push_back(policies_.at(x->get_eviction_policy())); + call(_asm, args); } } void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) { @@ -1098,6 +1178,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){ InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false); for(auto idx: idxs_.at(x)){ Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e); + // Value *ex2arg = vals_[x->get_operand(0)][idx]; vals_[x][idx] = call(ex2, std::vector{ex2arg}); } } @@ -1291,6 +1372,18 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va // order auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); + bool is_a_trans = C->is_trans_a(); + // is_a_trans = false; + if(C->is_trans_a()){ + std::swap(ord_a[0], ord_a[1]); + std::swap(shape_a[0], shape_a[1]); + std::swap(offset_a_m_, offset_a_k_); + } + // std::cout << "visiting" << std::endl; + // if(C->is_trans_b()){ + // std::swap(ord_b[0], ord_b[1]); + // std::swap(shape_b[0], shape_b[1]); + // } // layouts analysis::mma_layout* layout_c = layouts_->get(C)->to_mma(); analysis::shared_layout* layout_a = layouts_->get(A)->to_shared(); @@ -1322,6 +1415,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value 
*B, ir::va int step_b0 = is_b_row ? stride_rep_n : stride_rep_k; int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1); + + // max_phase_a = 4; + // vec_a = 8; + // std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl; + // std::cout << vec_a << " " << vec_b << std::endl; + /* --------------------------------- */ /* --- pre-compute pointer lanes --- */ /* --------------------------------- */ @@ -1916,12 +2015,17 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: auto shape_a = A->get_type()->get_block_shapes(); auto shape_b = B->get_type()->get_block_shapes(); auto ord_a = layouts_->get(A)->get_order(); + if(C->is_trans_a()){ + std::swap(ord_a[0], ord_a[1]); + std::swap(shape_a[0], shape_a[1]); + } auto ord_b = layouts_->get(B)->get_order(); + if(C->is_trans_b()){ + std::swap(ord_b[0], ord_b[1]); + std::swap(shape_b[0], shape_b[1]); + } + NK = shape_a[1]; analysis::mma_layout* layout = layouts_->get(C)->to_mma(); - analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); - analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); - bool is_a_row = ord_a[0] == 1; - bool is_b_row = ord_b[0] == 1; std::vector mma_instr_shape = layout->get_mma_instr_shape(); const int mma_instr_m = mma_instr_shape[0]; @@ -1933,10 +2037,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: const int mat_shape_n = mat_shape[1]; const int mat_shape_k = mat_shape[2]; - const int per_phase_a = swizzle_->get_per_phase(layout_a); - const int max_phase_a = swizzle_->get_max_phase(layout_a); - const int per_phase_b = swizzle_->get_per_phase(layout_b); - const int max_phase_b = swizzle_->get_max_phase(layout_b); const int num_rep_m = shapes[0] / layout->shape_per_cta(0); const int num_rep_n = shapes[1] / layout->shape_per_cta(1); @@ -2001,7 +2101,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); - if(FirstBB != CurrBB) + + // if true, this will move pointer declarations to the entry basic block + // not prefetched cases tend to be more limited in resource usage + // so we don't pre-compute ptrs to save registers + bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB); + if(licm_ptrs) builder_->SetInsertPoint(FirstBB->getTerminator()); Value* thread = tgt_->get_local_id(mod_, *builder_, 0); @@ -2015,47 +2120,137 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + ir::phi_node* phiA = dynamic_cast(A); + ir::phi_node* phiB = dynamic_cast(B); + auto register_lds2 = + [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { + if (k < 2 && is_prefetch) { + ir::basic_block* inc_block = phiA->get_incoming_block(inc); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); + } else + vals[{mn, k}] = val; + }; + // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride // v (s0_0(0), s1_0(2), | *num_rep_k // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2) // ----------- // *num_rep_m (stride in num of 
matrices(mat_stride_am): 2*layout->wpt(0)) - mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, - {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, - per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); - std::vector off_a = a_loader.compute_offs(warp_m, lane); - int num_ptr_a = a_loader.get_num_ptr(); + std::function load_a; + analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared(); + bool is_a_shared = layout_a != nullptr; + if(is_a_shared) { + const int per_phase_a = swizzle_->get_per_phase(layout_a); + const int max_phase_a = swizzle_->get_max_phase(layout_a); + mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, + {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, + per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); + std::vector off_a = a_loader.compute_offs(warp_m, lane); + int num_ptr_a = a_loader.get_num_ptr(); + // pointers + std::vector ptrs_a(num_ptr_a); + if(licm_ptrs) + builder_->SetInsertPoint(CurrBB); + for(int i = 0; i < num_ptr_a; i++) + ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); + if(licm_ptrs) + builder_->SetInsertPoint(FirstBB->getTerminator()); + // loading function + load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable { + auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], + shared_next_ptr_[layout_a], off_a, ptrs_a, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(ha, m, k, inc, ha0, is_prefetch); + register_lds2(ha, m+1, k, inc, ha1, is_prefetch); + register_lds2(ha, m, k+1, inc, ha2, is_prefetch); + register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); + }; + } + else { + load_a = [&](int m, int k, int inc, bool is_prefetch) { + distributed_axis ax_n = axes_.at(a_axes_->get(A, 1)); + int ldm = ax_n.values.size(); + if(ldm != num_rep_k*4) + throw std::runtime_error("Internal compiler error when trying to fuse matmuls!"); + // std::cout << m << " " << k << std::endl; + // std::cout << idxs_[A].size() << std::endl; + // std::cout << (m+1)*ldm + k*2 + 3 << std::endl; + // int ldm = num_rep_k*4; + Value* ha0 = UndefValue::get(fp16x2_ty); + Value* ha1 = UndefValue::get(fp16x2_ty); + Value* ha2 = UndefValue::get(fp16x2_ty); + Value* ha3 = UndefValue::get(fp16x2_ty); + ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0)); + ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1)); + ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0)); + ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1)); + ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0)); + ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1)); + ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0)); + ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1)); + ha[{m, k}] = ha0; + ha[{m+1, k}] = ha1; + ha[{m, k+1}] = ha2; + ha[{m+1, k+1}] = ha3; + }; + } + // | -> n (col-major) // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1)) // ----------- // *num_rep_k (stride in num of matrices(mat_stride_bk): 2) - mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b, - {mma_instr_k, mma_instr_n}, {mat_shape_k, 
mat_shape_n}, + analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared(); + const int per_phase_b = swizzle_->get_per_phase(layout_b); + const int max_phase_b = swizzle_->get_max_phase(layout_b); + std::vector mma_instr_b{mma_instr_k, mma_instr_n}; + std::vector mat_shape_b{mat_shape_k, mat_shape_n}; + int k_order_b = 0; + // if(C->is_trans_b()){ + // std::swap(mma_instr_b[0], mma_instr_b[1]); + // std::swap(mat_shape_b[0], mat_shape_b[1]); + // k_order_b = k_order_b ^ 1; + // std::swap(ord_b[0], ord_b[1]); + // std::swap(shape_b[0], shape_b[1]); + // } + + mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b, + mma_instr_b, mat_shape_b, per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep); std::vector off_b = b_loader.compute_offs(warp_n, lane); - int num_ptr_b = b_loader.get_num_ptr(); - builder_->SetInsertPoint(CurrBB); - // A pointer - std::vector ptrs_a(num_ptr_a); - for(int i = 0; i < num_ptr_a; i++) - ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); - // B pointer + if(licm_ptrs) + builder_->SetInsertPoint(CurrBB); + // pointers + int num_ptr_b = b_loader.get_num_ptr(); std::vector ptrs_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++) ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); - InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + + + // loading function + std::function load_b; + load_b = [&](int n, int k, int inc, bool is_prefetch) { + auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], + shared_next_ptr_[layout_b], off_b, ptrs_b, + ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); + register_lds2(hb, n, k, inc, hb0, is_prefetch); + register_lds2(hb, n+1, k, inc, hb2, is_prefetch); + register_lds2(hb, n, k+1, inc, hb1, is_prefetch); + register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); + }; + + + + // create mma & unpack result, m, n, k are offsets in mat + auto call_mma = [&](unsigned m, unsigned n, unsigned k) { + InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + " {$0, $1, $2, $3}," " {$4, $5, $6, $7}," " {$8, $9}," " {$10, $11, $12, $13};", "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true); - - // create mma & unpack result, m, n, k are offsets in mat - auto call_mma = [&](unsigned m, unsigned n, unsigned k) { unsigned cols_per_thread = num_rep_n * 2; std::vector idx = { (m + 0)*cols_per_thread + (n*2 + 0), @@ -2072,39 +2267,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: fc[idx[2]] = extract_val(nc, std::vector{2}); fc[idx[3]] = extract_val(nc, std::vector{3}); }; - - ir::phi_node* phiA = dynamic_cast(A); - ir::phi_node* phiB = dynamic_cast(B); - - auto register_lds2 = - [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { - if (k < 2 && is_prefetch) { - ir::basic_block* inc_block = phiA->get_incoming_block(inc); - lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); - } else - vals[{mn, k}] = val; - }; - - auto load_a = [&](int m, int k, int inc, bool is_prefetch) { - auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], - shared_next_ptr_[layout_a], off_a, ptrs_a, - ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); - register_lds2(ha, m, k, inc, ha0, is_prefetch); - register_lds2(ha, m+1, k, inc, ha1, is_prefetch); - register_lds2(ha, m, k+1, inc, ha2, is_prefetch); - register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); - }; - - auto load_b = [&](int n, int k, int inc, 
bool is_prefetch) { - auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], - shared_next_ptr_[layout_b], off_b, ptrs_b, - ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); - register_lds2(hb, n, k, inc, hb0, is_prefetch); - register_lds2(hb, n+1, k, inc, hb2, is_prefetch); - register_lds2(hb, n, k+1, inc, hb1, is_prefetch); - register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); - }; - if (C->is_prefetched()) { // create phis builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); @@ -2163,6 +2325,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir:: i = 0; vals_[C][idx] = fcs.at(key)[i++]; }; + } /** @@ -2384,7 +2547,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va } else if (layout->to_mma()) { shuffle_width = 4; warps_per_inner = layout->to_mma()->wpt(1); - col_per_thread = 16; + col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size(); warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; } assert(warp_j != nullptr); @@ -2403,7 +2566,8 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va Value* is_warp0 = icmp_eq(warp, i32(0)); Value* is_thread0 = icmp_eq(thread, i32(0)); Value* lane_j = urem(lane, i32(shuffle_width)); - add_barrier(); + if(warps_per_inner > 1) + add_barrier(); // compute partial sum for each warp, and store to shared memory for(size_t i = 0; i < n_elts/col_per_thread; i++){ std::pair acc; @@ -2425,13 +2589,21 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va // store partial result to shared memory auto x_idxs = idxs_[x][i]; Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; - Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); - call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first}); - if (with_index) { - call(st_shared_index, - {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second}); + // single warp on the reduce dimension -- no need to use shmem + if(warps_per_inner==1){ + vals_[x][idxs_[x][i]] = with_index ? 
acc.second : acc.first; + } + else{ + Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j); + call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first}); + if (with_index) { + call(st_shared_index, + {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second}); + } } } + if(warps_per_inner==1) + return; add_barrier(); // at this point, partial accumulator synchronized in shared memory // Just need to reduce `warp_per_inner` numbers in shared memory @@ -2559,6 +2731,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { case ir::reduce_inst::FMAX: return max_num(x, y); case ir::reduce_inst::FMIN: return min_num(x, y); case ir::reduce_inst::XOR: return xor_(x, y); + default: throw std::runtime_error("unreachable"); } }; @@ -2639,7 +2812,9 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){ analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(in)); analysis::distributed_layout* out_layout = dynamic_cast(layouts_->get(out)); Value *base; - base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out))))); + int off = alloc_->offset(layouts_->get(layouts_->tmp(out))); + // std::cout << off << std::endl; + base = gep(shmem_, i32(off)); base = bit_cast(base, ptr_ty(ty, 3)); std::vector n_reps; for(int i = 0; i < shape.size(); i++){ @@ -2821,15 +2996,26 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]); + if(in_layout->to_mma()){ + mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]); + mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]); + per_phase = 1; + max_phase = 8; + } int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; - int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); int n_shared_0 = std::max(in_vec / out_vec, 1); + int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); + if(in_layout->to_mma()){ + n_shared_0 = 8; + n_shared_1 = 1; + } BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); auto shapes = cts->get_type()->get_block_shapes(); + // store to shared Value *current = nullptr; std::map, Value*> ptrs; @@ -2844,9 +3030,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); - int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; - int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; - int off = (off_1*shapes[in_order[0]] + off_0); + // std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl; std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ if(FirstBB->getTerminator()) @@ -2865,6 +3049,13 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { builder_->SetInsertPoint(CurrBB); ptrs[key] = gep(shmems_.at(cts), {off}); } + int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0; + int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1; + if(in_layout->to_mma()){ + off_0 = id_0/n_shared_0*n_shared_0*8; + off_1 = id_1/n_shared_1*n_shared_1*8; + } + int off = (off_1*shapes[in_order[0]] + off_0); Value* ptr = gep(ptrs[key], {i32(off)}); ptr = bit_cast(ptr, current->getType()->getPointerTo(3)); // asm @@ -3069,7 +3260,7 @@ void generator::visit_function(ir::function* fn) { if(tgt_->as_nvidia()->sm() >= 80) for(ir::load_inst::EVICTION_POLICY evict: 
{ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){ std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last"; - std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;"; + std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;"; InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false); policies_[evict] = call(iasm); } @@ -3299,7 +3490,6 @@ void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ - // i->print(std::cout); visit_value(i); // std::cout << "done" << std::endl; } @@ -3324,7 +3514,10 @@ void generator::init_idx(ir::value *v) { std::vector axes(rank); std::vector ord(rank); // compute axes + // std::cout << "axes" << std::endl; for(size_t d = 0; d < shapes.size(); d++){ + // std::cout << d << " " << shapes[d] << std::endl; + // std::cout << a_axes_->get(v, d) << std::endl; if(shapes[d] > 1){ unsigned x = a_axes_->get(v, d); axes[d] = axes_.at(x); @@ -3334,6 +3527,7 @@ void generator::init_idx(ir::value *v) { axes[d].values = {i32(0)}; } } + // std::cout << "axes ok" << std::endl; // compute order analysis::data_layout* layout = layouts_->get(v); std::iota(ord.begin(), ord.end(), 0); @@ -3480,6 +3674,7 @@ void generator::finalize_phi_node(ir::phi_node *x) { for(indices_t idx: idxs_.at(x)){ PHINode *phi = (PHINode*)vals_[x][idx]; Value *inc = vals_[x->get_incoming_value(n)][idx]; + // x->print(std::cout); phi->addIncoming(inc, block); } } diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 8092ac527..8b5ad3625 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -12,8 +12,8 @@ namespace triton { namespace codegen{ namespace transform{ -coalesce::coalesce(analysis::align* align, analysis::layouts *layouts) - : align_(align), layout_(layouts) { } +coalesce::coalesce(analysis::align* align, analysis::layouts *layouts, bool has_sm80) + : align_(align), layout_(layouts), has_sm80_(has_sm80) { } // simplify layout conversions using the following simple rules: @@ -64,15 +64,18 @@ void coalesce::run(ir::module &mod) { if(op->get_type()->is_block_ty()) if(op->get_type()->get_tile_rank() == 2) if(invalidated.find(layout_->get(op)) == invalidated.end()) - if(layout_->get(op)->to_mma()){ + if(layout_->get(op)->to_mma()) + if(dynamic_cast(i)->get_eviction_policy()==ir::io_inst::NORMAL){ ir::instruction* new_op = ir::cvt_layout_inst::create(op); builder.set_insert_point(i); builder.insert(new_op); i->replace_uses_of_with(op, new_op); } // coalesce before copy_to_shared - // It's dirty, but the backend is being rewritten from scratch. 
:) - if(dynamic_cast(i)) + // only necessary for sm < 80 as Ampere+ can handle reduction + // on MMA layout + if(!has_sm80_) + if(dynamic_cast(i) || dynamic_cast(i)) if(ir::value* op = i->get_operand(0)) if(op->get_type()->is_block_ty()) if(op->get_type()->get_tile_rank() == 2) @@ -89,7 +92,8 @@ void coalesce::run(ir::module &mod) { if(auto x = dynamic_cast(i)) if(x->get_type()->is_block_ty()) if(x->get_type()->get_tile_rank()==2) - if(layout_->get(x)->to_mma()){ + if(layout_->get(x)->to_mma()) + if(!has_sm80_ || dynamic_cast(i)->get_eviction_policy()==ir::io_inst::NORMAL){ builder.set_insert_point_after(x); ir::instruction* new_x = ir::cvt_layout_inst::create(x); builder.insert(new_x); diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc index c223d2413..4606b0f57 100644 --- a/lib/codegen/transform/cts.cc +++ b/lib/codegen/transform/cts.cc @@ -1,8 +1,10 @@ +#include "triton/codegen/analysis/layout.h" #include "triton/codegen/transform/cts.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" +#include "triton/ir/utils.h" #include namespace triton { @@ -10,9 +12,9 @@ namespace codegen{ namespace transform{ -inline bool is_shmem_op(ir::instruction* i, int op) { +bool cts::is_shmem_op(ir::instruction* i, int op) { if(i->get_id() == ir::INST_DOT) - return op==0 || op==1; + return op == 0 || op == 1; if(i->get_id() == ir::INST_COPY_FROM_SHARED) return op==0; if(i->get_id() == ir::INST_TRANS) @@ -20,7 +22,7 @@ inline bool is_shmem_op(ir::instruction* i, int op) { return false; } -inline bool is_shmem_res(ir::value* v){ +bool cts::is_shmem_res(ir::value* v){ ir::instruction* i = dynamic_cast(v); if(!i) return false; @@ -35,7 +37,7 @@ inline bool is_shmem_res(ir::value* v){ // run pass on module -void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { +void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map& copies) { auto *i = dynamic_cast(x); // not an instruction if(!i) { @@ -51,7 +53,7 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, // phi node if(auto* phi = dynamic_cast(x)) { for(unsigned i = 0; i < phi->get_num_incoming(); ++i) - add_copy(phi, phi->get_incoming_value(i), builder, to_shared); + add_copy(phi, phi->get_incoming_value(i), builder, to_shared, copies); return; } // already in shared memory @@ -65,30 +67,49 @@ void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, } else copy = builder.create_copy_from_shared(x); - parent->replace_uses_of_with(x, copy); + copies.insert({x, copy}); + parent->replace_uses_of_with(x, copies.at(x)); } void cts::run(ir::module &mod) { - // Add shared copies - ir::builder &builder = mod.get_builder(); - for(ir::function* fn: mod.get_function_list()){ - for(ir::basic_block* block: fn->blocks()) - for(ir::instruction* i: block->get_inst_list()){ - size_t num_op = i->get_num_operands(); - // copy to shared operands - for(size_t k = 0; k < num_op; k++) - if(is_shmem_op(i, k)){ - add_copy(i, i->get_operand(k), builder, true); - } - // copy from shared operands - for(size_t k = 0; k < num_op; k++) - if(!dynamic_cast(i) && - !is_shmem_op(i,k) && - is_shmem_res(i->get_operand(k))){ - add_copy(i, i->get_operand(k), builder, false); - } + // Precompute where copies should be added + std::set shmem_ops; + std::set shmem_res; + ir::for_each_instruction(mod, [&](ir::instruction* i) { + if(i->get_id() == 
ir::INST_DOT){ + ir::dot_inst* dot = dynamic_cast(i); + ir::value* lhs = i->get_operand(0); + ir::type* ty = lhs->get_type()->get_scalar_ty(); + analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma(); + // TODO: V100 + bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a()); + if(is_lhs_shmem) + shmem_ops.insert(lhs); + shmem_ops.insert(i->get_operand(1)); } - } + if(i->get_id() == ir::INST_COPY_FROM_SHARED) + shmem_ops.insert(i->get_operand(0)); + if(i->get_id() == ir::INST_TRANS) + shmem_ops.insert(i->get_operand(0)); + if(i->get_id() == ir::INST_TRANS || + i->get_id() == ir::INST_COPY_TO_SHARED || + i->get_id() == ir::INST_MASKED_LOAD_ASYNC) + shmem_res.insert(i); + }); + + // Add shared copies + std::map copies; + ir::builder &builder = mod.get_builder(); + ir::for_each_instruction(mod, [&](ir::instruction* i) { + size_t num_op = i->get_num_operands(); + for(size_t k = 0; k < num_op; k++){ + ir::value* op = i->get_operand(k); + // copy to shared operands + bool is_shmem_op = shmem_ops.find(op) != shmem_ops.end(); + if(is_shmem_op) + add_copy(i, op, builder, true, copies); + } + }); } diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index c25a252a8..a7d3f8240 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -87,7 +87,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ ir::value *a = dot->get_operand(0); ir::value *b = dot->get_operand(1); builder.set_insert_point(add); - ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name())); + ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other, dot->is_trans_a(), dot->is_trans_b(), dot->allow_tf32(), dot->get_name())); add->replace_all_uses_with(new_dot); return true; } diff --git a/lib/ir/basic_block.cc b/lib/ir/basic_block.cc index 93caef2c3..0bbc3af0f 100644 --- a/lib/ir/basic_block.cc +++ b/lib/ir/basic_block.cc @@ -26,7 +26,10 @@ void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) auto* curr_phi = dynamic_cast(i); if(!curr_phi) break; - curr_phi->replace_uses_of_with(before, after); + // curr_phi->replace_uses_of_with(before, after); + for (size_t idx = 0; idx < curr_phi->get_num_incoming(); ++idx) + if (curr_phi->get_incoming_block(idx) == before) + curr_phi->set_incoming_block(idx, after); } } diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 4060f23bb..510994fd8 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -299,16 +299,16 @@ value *builder::create_load(value *ptr, load_inst::CACHE_MODIFIER cache, load_in return insert(unmasked_load_inst::create(ptr, cache, eviction, is_volatile)); } -value *builder::create_store(value *ptr, value *val){ - return insert(unmasked_store_inst::create(ptr, val)); +value *builder::create_store(value *ptr, value *val, store_inst::EVICTION_POLICY eviction){ + return insert(unmasked_store_inst::create(ptr, val, eviction)); } value *builder::create_masked_load(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile){ return insert(masked_load_inst::create(ptr, mask, false_value, cache, eviction, is_volatile)); } -value *builder::create_masked_store(value *ptr, value *val, value *mask){ - return insert(masked_store_inst::create(ptr, val, mask)); +value *builder::create_masked_store(value *ptr, value *val, value *mask, store_inst::EVICTION_POLICY eviction){ + return 
insert(masked_store_inst::create(ptr, val, mask, eviction)); } //===----------------------------------------------------------------------===// @@ -412,8 +412,8 @@ value *builder::create_log(value *arg){ return insert(log_inst::create(arg)); } -value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) { - return insert(dot_inst::create_nn(A, B, C, allow_tf32)); +value *builder::create_dot(value *A, value *B, value *C, bool trans_a, bool trans_b, bool allow_tf32) { + return insert(dot_inst::create(A, B, C, trans_a, trans_b, allow_tf32)); } value *builder::create_trans(value *A, const std::vector& perm) { diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 325976504..dbee5e0ee 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -69,6 +69,7 @@ void phi_node::set_incoming_block(unsigned i, basic_block *block){ // Add incoming void phi_node::add_incoming(value *v, basic_block *block){ + assert(v && "PHI node got a null value!!"); resize_ops(get_num_operands() + 1); blocks_.resize(get_num_operands() + 1); set_incoming_value(get_num_operands() - 1, v); @@ -494,13 +495,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vectorget_type()), id, num_ops, name, next), cache_(cache), eviction_(eviction), is_volatile_(is_volatile) + : io_inst(get_pointee_type(ptr->get_type()), id, num_ops, eviction, name, next), cache_(cache), is_volatile_(is_volatile) { } // load @@ -557,34 +558,35 @@ masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, // store -store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next) - : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next) +store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, const std::string &name, instruction *next) + : io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, eviction, name, next) { } // unmasked_store -unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, +unmasked_store_inst::unmasked_store_inst(value *ptr, value *val, EVICTION_POLICY eviction, const std::string &name, instruction *next) - : store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) { + : store_inst(ptr, INST_UNMASKED_STORE, 2, eviction, name, next) { set_operand(0, ptr); set_operand(1, val); } -unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, +unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val, EVICTION_POLICY eviction, const std::string &name, instruction *next) { - return new unmasked_store_inst(ptr, val, name, next); + return new unmasked_store_inst(ptr, val, eviction, name, next); } // masked store -masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, +masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, EVICTION_POLICY eviction, const std::string &name, instruction *next) - : store_inst(ptr, INST_MASKED_STORE, 3, name, next) { + : store_inst(ptr, INST_MASKED_STORE, 3, eviction, name, next) { set_operand(0, ptr); set_operand(1, val); set_operand(2, mask); } -masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) { - return new masked_store_inst(ptr, val, mask, name, next); +masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, EVICTION_POLICY eviction, + const std::string &name, instruction *next) { + return new masked_store_inst(ptr, val, mask, 
+  return new masked_store_inst(ptr, val, mask, eviction, name, next);
 }
 //===----------------------------------------------------------------------===//
@@ -679,7 +681,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
 dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
                    const std::string &name, instruction *next)
-  : builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
+  : builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT){
   set_operand(0, A);
   set_operand(1, B);
   set_operand(2, C);
diff --git a/lib/ir/utils.cc b/lib/ir/utils.cc
index cbfb4baf9..9abaef5c0 100644
--- a/lib/ir/utils.cc
+++ b/lib/ir/utils.cc
@@ -43,6 +43,15 @@ std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
   return result;
 }
+void for_each_instruction_backward(module &mod, const std::function<void(instruction*)> &do_work) {
+  for(ir::function *fn: mod.get_function_list())
+    for(ir::basic_block *block: cfg::post_order(fn)){
+      auto inst_list = block->get_inst_list();
+      for(auto it = inst_list.rbegin(); it != inst_list.rend() ; it++)
+        do_work(*it);
+    }
+}
+
 void for_each_instruction(module &mod, const std::function<void(instruction*)> &do_work) {
   for(ir::function *fn: mod.get_function_list())
     for(ir::basic_block *block: cfg::reverse_post_order(fn))
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a5fb0acba..6987d0c26 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -840,10 +840,10 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 @pytest.mark.parametrize("epilogue, allow_tf32, dtype",
                          [(epilogue, allow_tf32, dtype)
-                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
+                          for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
                           for allow_tf32 in [True, False]
-                          for dtype in ['float32', 'int8']
-                          if not (allow_tf32 and (dtype == 'int8'))])
+                          for dtype in ['float16']
+                          if not (allow_tf32 and (dtype in ['float16']))])
 def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
     cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
     if cc < 80:
@@ -852,21 +852,30 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
     elif dtype == 'float32' and allow_tf32:
         pytest.skip("Only test tf32 on devices with sm >= 80")
+    M, N, K = 128, 128, 64
+    num_warps = 8
+    trans_a, trans_b = False, False
+
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
                Y, stride_yk, stride_yn,
+               W, stride_wn, stride_wl,
                Z, stride_zm, stride_zn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
-               ALLOW_TF32: tl.constexpr):
+               ALLOW_TF32: tl.constexpr,
+               DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
+               TRANS_A: tl.constexpr, TRANS_B: tl.constexpr):
         off_m = tl.arange(0, BLOCK_M)
         off_n = tl.arange(0, BLOCK_N)
+        off_l = tl.arange(0, BLOCK_N)
         off_k = tl.arange(0, BLOCK_K)
         Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
         Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
+        Ws = W + off_n[:, None] * stride_wn + off_l[None, :] * stride_wl
         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
-        z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
+        z = tl.dot(tl.load(Xs), tl.load(Ys), trans_a=TRANS_A, trans_b=TRANS_B, allow_tf32=ALLOW_TF32)
         if ADD_MATRIX:
             z += tl.load(Zs)
         if ADD_ROWS:
@@ -875,39 +884,65 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         if ADD_COLS:
             ZCs = Z + off_n * stride_zn
             z += tl.load(ZCs)[None, :]
+        if DO_SOFTMAX:
+            max = tl.max(z, 1)
+            z = z - max[:, None]
+            num = tl.exp(z)
+            den = tl.sum(num, 1)
+            z = num / den[:, None]
+        if CHAIN_DOT:
+            # tl.store(Zs, z)
+            # tl.debug_barrier()
+            z = tl.dot(z.to(tl.float16), tl.load(Ws), trans_a=TRANS_A)
         tl.store(Zs, z)
     # input
-    M, N, K = 64, 64, 32
     rs = RandomState(17)
-    x = numpy_random((M, K), dtype_str=dtype, rs=rs)
-    y = numpy_random((K, N), dtype_str=dtype, rs=rs)
+    x = numpy_random((K, M) if trans_a else (M, K), dtype_str=dtype, rs=rs) * .1
+    y = numpy_random((N, K) if trans_b else (K, N), dtype_str=dtype, rs=rs) * .1
+    w = numpy_random((N, N), dtype_str=dtype, rs=rs) * .1
     if allow_tf32:
         x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
         y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
+        w = (w.view('uint32') & np.uint32(0xffffe000)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
+    w_tri = to_triton(w, device=device)
     # triton result
-    z = numpy_random((M, N), dtype_str=dtype, rs=rs)
+    z = 1 + numpy_random((M, N), dtype_str=dtype, rs=rs) * .1
     z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          y_tri, y_tri.stride(0), y_tri.stride(1),
+                         w_tri, w_tri.stride(0), w_tri.stride(1),
                          z_tri, z_tri.stride(0), z_tri.stride(1),
+                         TRANS_A=trans_a, TRANS_B=trans_b,
                          BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
                          ADD_MATRIX=epilogue == 'add-matrix',
                          ADD_ROWS=epilogue == 'add-rows',
                          ADD_COLS=epilogue == 'add-cols',
-                         ALLOW_TF32=allow_tf32)
+                         DO_SOFTMAX=epilogue == 'softmax',
+                         CHAIN_DOT=epilogue == 'chain-dot',
+                         ALLOW_TF32=allow_tf32,
+                         num_warps=num_warps)
     # torch result
-    z_ref = np.matmul(x, y)
+    x_ref = x.T if trans_a else x
+    y_ref = y.T if trans_b else y
+    z_ref = np.matmul(x_ref, y_ref)
     if epilogue == 'add-matrix':
         z_ref += z
     if epilogue == 'add-rows':
         z_ref += z[:, 0][:, None]
     if epilogue == 'add-cols':
         z_ref += z[0, :][None, :]
+    if epilogue == 'softmax':
+        num = np.exp(z_ref - np.max(z_ref, axis=-1, keepdims=True))
+        denom = np.sum(num, axis=-1, keepdims=True)
+        z_ref = num / denom
+    if epilogue == 'chain-dot':
+        z_ref = np.matmul(z_ref.T if trans_a else z_ref, w)
     # compare
+    # print(z_ref[:,0], z_tri[:,0])
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 90a031a30..27c9e1bfe 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -211,7 +211,7 @@ class ValueConstructor:
             return phi
         v = unique_handles.pop()
         phi.handle.replace_all_uses_with(v)
-        phi.handle.erase_from_parent()
+        # phi.handle.erase_from_parent()
         # TODO: remove trivial phis recursively
         return triton.language.tensor(v, phi.type)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 1c54ef2c7..3ba2d2f21 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -732,7 +732,7 @@ def reshape(input, shape, _builder=None):
 @builtin
-def dot(input, other, allow_tf32=True, _builder=None):
+def dot(input, other, trans_a=False, trans_b=False, allow_tf32=True, _builder=None):
     """
     Returns the matrix product of two blocks.
@@ -744,7 +744,7 @@ def dot(input, other, allow_tf32=True, _builder=None):
     :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
     """
     allow_tf32 = _constexpr_to_value(allow_tf32)
-    return semantic.dot(input, other, allow_tf32, _builder)
+    return semantic.dot(input, other, trans_a, trans_b, allow_tf32, _builder)
 # -----------------------
@@ -782,7 +782,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
 @builtin
-def store(pointer, value, mask=None, _builder=None):
+def store(pointer, value, eviction_policy="", mask=None, _builder=None):
     """
     Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
@@ -799,7 +799,7 @@ def store(pointer, value, mask=None, _builder=None):
     value = _to_tensor(value, _builder)
     if mask is not None:
         mask = _to_tensor(mask, _builder)
-    return semantic.store(pointer, value, mask, _builder)
+    return semantic.store(pointer, value, mask, eviction_policy, _builder)
 # -----------------------
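Usage note (illustrative sketch, not part of the patch): with the hunks above, tl.store accepts the same eviction_policy strings that tl.load already parses ("evict_first" / "evict_last"); anything else raises the ValueError added in _parse_eviction_policy below. The kernel name and the doubling are made up for illustration:

import triton
import triton.language as tl


@triton.jit
def copy2x_kernel(X, Y, N, BLOCK: tl.constexpr):
    # hypothetical streaming kernel: the input is read once and the output is
    # not re-read soon, so hint the L2 cache accordingly
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask, other=0.0, eviction_policy="evict_last")
    tl.store(Y + offs, x * 2, mask=mask, eviction_policy="evict_first")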
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 15a5cb648..2d137b904 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -648,6 +648,18 @@ def cast(input: tl.tensor,
 # ===----------------------------------------------------------------------===//
+def _parse_eviction_policy(eviction_policy):
+    eviction = ir.EVICTION_POLICY.NORMAL  # default
+    if eviction_policy:
+        if eviction_policy == "evict_last":
+            eviction = ir.EVICTION_POLICY.EVICT_LAST
+        elif eviction_policy == "evict_first":
+            eviction = ir.EVICTION_POLICY.EVICT_FIRST
+        else:
+            raise ValueError(f"Eviction policy {eviction_policy} not supported")
+    return eviction
+
+
 def load(ptr: tl.tensor,
          mask: Optional[tl.tensor],
         other: Optional[tl.tensor],
@@ -684,14 +696,7 @@ def load(ptr: tl.tensor,
         raise ValueError(f"Cache modifier {cache_modifier} not supported")
     # eviction policy
-    eviction = ir.EVICTION_POLICY.NORMAL  # default
-    if eviction_policy:
-        if eviction_policy == "evict_last":
-            eviction = ir.EVICTION_POLICY.EVICT_LAST
-        elif eviction_policy == "evict_first":
-            eviction = ir.EVICTION_POLICY.EVICT_FIRST
-        else:
-            raise ValueError(f"Eviction policy {eviction_policy} not supported")
+    eviction = _parse_eviction_policy(eviction_policy)
     if ptr.type.is_block():
         shape = ptr.type.get_block_shapes()
@@ -721,6 +726,7 @@ def load(ptr: tl.tensor,
 def store(ptr: tl.tensor,
           val: tl.tensor,
           mask: Optional[tl.tensor],
+          eviction_policy: str,
           builder: ir.builder) -> tl.tensor:
     if not ptr.type.scalar.is_ptr():
         raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
@@ -735,14 +741,15 @@ def store(ptr: tl.tensor,
         elt_ty_ptr = tl.int8
         ptr_ty = tl.pointer_type(elt_ty_ptr, ptr_ty.address_space)
         ptr = cast(ptr, ptr_ty, builder)
-
+    # eviction policy
+    eviction = _parse_eviction_policy(eviction_policy)
     # cast to target data-type
     val = cast(val, elt_ty, builder)
     if not mask:
-        return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
+        return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void)
     if not mask.type.scalar.is_bool():
         raise ValueError("Mask must have boolean scalar type")
-    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
+    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void)
 #########
 # atomic
 #########
@@ -897,27 +904,31 @@ def atomic_xchg(ptr: tl.tensor,
 # ===----------------------------------------------------------------------===//
-def dot(lhs: tl.tensor,
-        rhs: tl.tensor,
+def dot(a: tl.tensor,
+        b: tl.tensor,
+        trans_a: bool,
+        trans_b: bool,
         allow_tf32: bool,
         builder: ir.builder) -> tl.tensor:
-    assert lhs.type.is_block() and rhs.type.is_block()
-    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
-    assert lhs.shape[-1] == rhs.shape[0]
-    assert lhs.shape[0] >= 16 and lhs.shape[1] >= 16 and rhs.shape[1] >= 16,\
+    in_a = 1 if not trans_a else 0
+    in_b = 1 if trans_b else 0
+    assert a.type.is_block() and b.type.is_block()
+    assert len(a.shape) == 2 and len(b.shape) == 2
+    assert a.shape[in_a] == b.shape[in_b]
+    assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\
        "small blocks not supported!"
-    if lhs.type.scalar.is_int():
+    if a.type.scalar.is_int():
        _0 = builder.get_int32(0)
        ret_scalar_ty = tl.int32
    else:
        _0 = builder.get_float32(0)
        ret_scalar_ty = tl.float32
-    M = lhs.type.shape[0]
-    N = rhs.type.shape[1]
+    M = a.type.shape[in_a ^ 1]
+    N = b.type.shape[in_b ^ 1]
    _0 = builder.create_splat(_0, [M, N])
    ret_ty = tl.block_type(ret_scalar_ty, [M, N])
-    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
-                     ret_ty)
+    ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32)
+    return tl.tensor(ret, ret_ty)
 # ===----------------------------------------------------------------------===//
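Usage note (illustrative sketch, not part of the patch): trans_a / trans_b let a kernel consume an operand in its stored layout instead of materializing a transpose. Per semantic.dot above, with trans_a=True the shared dimension is checked as a.shape[0] == b.shape[0] and the result is a.shape[1] x b.shape[1]. A hypothetical single-block kernel (names, strides and fp16 inputs are assumptions; block sizes must be at least 16):

import triton
import triton.language as tl


@triton.jit
def at_b_kernel(A, B, C,
                stride_ak, stride_am, stride_bk, stride_bn, stride_cm, stride_cn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # A is stored K-major as (K, M); no transpose is materialized
    a = tl.load(A + offs_k[:, None] * stride_ak + offs_m[None, :] * stride_am)  # (BLOCK_K, BLOCK_M)
    b = tl.load(B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)  # (BLOCK_K, BLOCK_N)
    c = tl.dot(a, b, trans_a=True)  # (BLOCK_M, BLOCK_N) = A^T @ B
    tl.store(C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, c)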
diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
new file mode 100644
index 000000000..eb9b40c60
--- /dev/null
+++ b/python/tutorials/06-fused-attention.py
@@ -0,0 +1,198 @@
+import pytest
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel(
+    Q, K, V,
+    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+    Out,
+    stride_qz, stride_qh, stride_qm, stride_qk,
+    stride_kz, stride_kh, stride_kk, stride_kn,
+    stride_vz, stride_vh, stride_vk, stride_vn,
+    stride_oz, stride_oh, stride_om, stride_on,
+    Z, H, N_CTX,
+    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    start_qm = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    # initialize offsets
+    offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
+    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    # Initialize pointers to Q, K, V
+    q_ptrs = Q + off_q
+    k_ptrs = K + off_k
+    v_ptrs = V + off_v
+    # initialize pointer to m and l
+    t_ptrs = TMP + off_hz * N_CTX + offs_m
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+
+    q = tl.load(q_ptrs)
+    for start_n in range(0, start_qm + 1):
+        # -- compute qk ----
+        k = tl.load(k_ptrs)
+        qk = tl.dot(q, k)
+        qk += tl.where(offs_m[:, None] >= (start_n * BLOCK_N + offs_n[None, :]), 0, float("-inf"))
+        # -- compute m_ij, p, l_ij
+        m_ij = tl.max(qk, 1)
+        p = tl.exp(qk - m_ij[:, None])
+        l_ij = tl.sum(p, 1)
+        # -- update m_i and l_i
+        m_i_new = tl.maximum(m_i, m_ij)
+        alpha = tl.exp(m_i - m_i_new)
+        beta = tl.exp(m_ij - m_i_new)
+        l_i_new = alpha * l_i + beta * l_ij
+        # -- update output accumulator --
+        # scale p
+        p_scale = beta / l_i_new
+        p = p * p_scale[:, None]
+        p = p.to(tl.float16)
+        # scale acc
+        acc_scale = l_i / l_i_new * alpha
+        tl.store(t_ptrs, acc_scale)
+        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
+        acc = acc * acc_scale[:, None]
+        # update acc
+        v = tl.load(v_ptrs)
+        acc += tl.dot(p, v)
+        k_ptrs += BLOCK_N * stride_kn
+        v_ptrs += BLOCK_N * stride_vk
+        # r_ptrs += BLOCK_N
+        l_i = l_i_new
+        m_i = m_i_new
+
+    start_qm = tl.program_id(0)
+    offs_m = start_qm * BLOCK_M + tl.arange(0, BLOCK_M)
+    # write back l and m
+    l_ptrs = L + off_hz * N_CTX + offs_m
+    m_ptrs = M + off_hz * N_CTX + offs_m
+    tl.store(l_ptrs, l_i)
+    tl.store(m_ptrs, m_i)
+    # initialize pointers to output
+    offs_n = tl.arange(0, BLOCK_DMODEL)
+    off_out = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+    out_ptrs = Out + off_out
+    tl.store(out_ptrs, acc)
+
+
+class _attention(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, q, k, v):
+        BLOCK = 128
+        # shape constraints
+        Lq, Lk = q.shape[-1], k.shape[-2]
+        assert Lq == Lk
+        o = torch.empty_like(q)
+        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
+        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        _fwd_kernel[grid](
+            q, k, v,
+            tmp, L, m,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+            BLOCK_DMODEL=64, num_warps=4,
+            num_stages=1,
+        )
+        ctx.save_for_backward(q, k, v, o, L, m)
+        ctx.BLOCK = BLOCK
+        ctx.grid = grid
+        return o
+
+
+attention = _attention.apply
+
+
+@pytest.mark.parametrize('Z, H, N_CTX, D_MODEL', [(2, 3, 1024, 64)])
+def test_op(Z, H, N_CTX, D_MODEL, dtype=torch.float16):
+    torch.manual_seed(20)
+    q = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
+    k = .5 * torch.randn((Z, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
+    v = .5 * torch.randn((Z, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
+    # triton implementation
+    tri_out = attention(q, k, v)
+    # reference implementation
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
+    ref_qk = torch.matmul(q, k)
+    for z in range(Z):
+        for h in range(H):
+            ref_qk[:, :, M == 0] = float("-inf")
+    ref_qk = torch.softmax(ref_qk, dim=-1)
+    ref_out = torch.matmul(ref_qk, v)
+    # compare
+    triton.testing.assert_almost_equal(ref_out, tri_out)
+
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func
+    HAS_FLASH = True
+except BaseException:
+    HAS_FLASH = False
+
+BATCH, N_HEADS, N_CTX, D_HEAD = 4, 64, 2048, 64
+# vary batch size for fixed heads / seq
+batch_bench = triton.testing.Benchmark(
+    x_names=['BATCH'],
+    x_vals=[2**i for i in range(0, 8)],
+    line_arg='provider',
+    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
+    styles=[('red', '-'), ('blue', '-')],
+    ylabel='ms',
+    plot_name=f'fused-attention-seq{N_CTX}-head{N_HEADS}-d{D_HEAD}',
+    args={'H': N_HEADS, 'N_CTX': N_CTX, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
+)
+# vary seq length for fixed head and batch=4
+seq_bench = triton.testing.Benchmark(
+    x_names=['N_CTX'],
+    x_vals=[2**i for i in range(10, 16)],
+    line_arg='provider',
+    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
+    styles=[('red', '-'), ('blue', '-')],
+    ylabel='ms',
+    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}',
+    args={'H': D_HEAD, 'BATCH': BATCH, 'D_MODEL': D_HEAD, 'dtype': torch.float16}
+)
+
+
+@triton.testing.perf_report([batch_bench, seq_bench])
+def bench_flash_attention(BATCH, H, N_CTX, D_MODEL, provider, dtype=torch.float16, device="cuda"):
+    warmup = 25
+    rep = 500
+    if provider == "triton":
+        q = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
+        k = torch.randn((BATCH, H, D_MODEL, N_CTX), dtype=dtype, device="cuda", requires_grad=True)
+        v = torch.randn((BATCH, H, N_CTX, D_MODEL), dtype=dtype, device="cuda", requires_grad=True)
+        fn = lambda: attention(q, k, v)
+        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+        return ms
+    if provider == "flash":
+        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
+        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
+        cu_seqlens[1:] = lengths.cumsum(0)
+        qkv = torch.randn((BATCH * N_CTX, 3, H, D_MODEL), dtype=dtype, device=device, requires_grad=True)
+        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
+        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+        return ms
+
+
+bench_flash_attention.run(save_path='.', print_data=True)
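Usage note (illustrative sketch, not part of the patch): the tutorial's attention autograd function expects q and v shaped (Z, H, N_CTX, D_MODEL) and k already transposed to (Z, H, D_MODEL, N_CTX), with D_MODEL equal to the hard-coded BLOCK_DMODEL=64. A minimal driver mirroring test_op, assuming it runs in the same module as the tutorial so that attention and triton.testing are in scope:

import torch

Z, H, N_CTX, D_MODEL = 2, 3, 1024, 64
q = torch.randn((Z, H, N_CTX, D_MODEL), dtype=torch.float16, device="cuda")
k = torch.randn((Z, H, D_MODEL, N_CTX), dtype=torch.float16, device="cuda")
v = torch.randn((Z, H, N_CTX, D_MODEL), dtype=torch.float16, device="cuda")

# fused causal attention, computed tile by tile by _fwd_kernel
tri_out = attention(q, k, v)

# reference, as in test_op: causally masked softmax(q @ k) @ v
mask = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
qk = torch.matmul(q, k)
qk[:, :, mask == 0] = float("-inf")
ref_out = torch.matmul(torch.softmax(qk, dim=-1), v)
triton.testing.assert_almost_equal(ref_out, tri_out)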