From d8d6b715c878b9156bf42130908a30de7d0bdaf7 Mon Sep 17 00:00:00 2001 From: daadaada Date: Mon, 21 Jun 2021 14:25:13 +0800 Subject: [PATCH] [CODEGEN] Performance improvement on A100 (#125) Improved codegen for the Ampere GPUs. * Make the layout pass recognize the multistage pipelined pattern. * Now the pipeline pass can automate the multistage pipelining transformation. * Remove extra barriers (from the prefetch pass & WAR) on Ampere. * Update the code generator (generator.cc) to make Triton generate n-buffered shared memory loads/stores. --- include/triton/codegen/analysis/layout.h | 14 + include/triton/codegen/pass.h | 2 +- include/triton/codegen/selection/generator.h | 4 + include/triton/codegen/transform/membar.h | 9 +- include/triton/codegen/transform/pipeline.h | 4 +- include/triton/codegen/transform/prefetch.h | 5 + include/triton/ir/instructions.h | 1 + include/triton/ir/print.h | 8 +- lib/codegen/analysis/layout.cc | 101 +++++- lib/codegen/pass.cc | 9 +- lib/codegen/selection/generator.cc | 322 +++++++++++++++---- lib/codegen/transform/membar.cc | 78 ++++- lib/codegen/transform/pipeline.cc | 255 +++++++++++++-- lib/codegen/transform/prefetch.cc | 26 +- lib/driver/module.cc | 2 +- lib/ir/print.cc | 48 +++ python/src/triton.cc | 9 +- python/test/test_blocksparse.py | 2 +- python/test/test_matmul.py | 87 ++--- python/triton/code_gen.py | 33 +- python/triton/testing.py | 10 + 21 files changed, 855 insertions(+), 174 deletions(-) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 14c760bf1..0cf7faf2e 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -141,10 +141,19 @@ struct double_buffer_info_t { ir::phi_node* phi; }; +struct N_buffer_info_t { + std::vector firsts; // not necessarily ordered as input order + ir::value* latch; + ir::phi_node* phi; + std::map firsts_idx; +}; + +// abstract for dot and coresponding smem values class shared_layout: public data_layout { private: static bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator); static void extract_double_bufferable(ir::value *v, std::shared_ptr& res); + static void extract_N_bufferable(ir::value *v, std::shared_ptr& res, int &prev_stages); public: shared_layout(data_layout *arg, @@ -158,6 +167,10 @@ public: size_t get_size() { return size_; } ir::type* get_type() { return ty_; } double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); } + N_buffer_info_t* get_N_buffer() { return N_buffer_.get(); } + int get_num_stages() const; + size_t get_per_stage_size() const { return size_ / get_num_stages(); } + size_t get_per_stage_elements() const; size_t get_num_per_phase() { return num_per_phase_; } ir::value* hmma_dot_a() { return hmma_dot_a_; } ir::value* hmma_dot_b() { return hmma_dot_b_; } @@ -169,6 +182,7 @@ private: size_t size_; ir::type *ty_; std::shared_ptr double_buffer_; + std::shared_ptr N_buffer_; size_t num_per_phase_; ir::value* hmma_dot_a_; ir::value* hmma_dot_b_; diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h index c1b67372c..94ba25a0f 100644 --- a/include/triton/codegen/pass.h +++ b/include/triton/codegen/pass.h @@ -21,7 +21,7 @@ namespace codegen{ // TODO: // There should be a proper pass manager there! 
-void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, +void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps, int num_stages, driver::module*& mod, driver::kernel*& ker, size_t& shared_mem); diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index c4a5abed6..2a65ba27c 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -223,6 +223,10 @@ private: std::map shoffs_; std::map> idxs_; std::map> vals_; + /// idx for multi-stage pipeline + std::map read_smem_idx_; + std::map write_smem_idx_; + /// triton bb -> llvm bb std::map bbs_; std::map> ords_; diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h index d35bd10ba..21145a4fe 100644 --- a/include/triton/codegen/transform/membar.h +++ b/include/triton/codegen/transform/membar.h @@ -32,6 +32,8 @@ class shared_layout; namespace transform{ +class prefetch; + class membar { private: typedef std::pair interval_t; @@ -40,6 +42,7 @@ private: private: bool intersect(const val_set_t &X, const val_set_t &Y); + bool check_safe_war(ir::instruction* i); int group_of(triton::ir::value *i, std::vector &async_write); bool intersect_with(analysis::shared_layout* a_layout, analysis::shared_layout* b_layout); val_set_t intersect_with(const val_set_t& as, const val_set_t& bs); @@ -47,14 +50,16 @@ private: std::set &safe_war, bool &inserted, ir::builder &builder); public: - membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, target* tgt): - liveness_(liveness), layouts_(layouts), alloc_(alloc), tgt_(tgt) {} + membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc, + transform::prefetch *prefetch, target* tgt): + liveness_(liveness), layouts_(layouts), alloc_(alloc), prefetch_(prefetch), tgt_(tgt) {} void run(ir::module &mod); private: analysis::liveness *liveness_; analysis::layouts *layouts_; analysis::allocation *alloc_; + transform::prefetch *prefetch_; target* tgt_; }; diff --git a/include/triton/codegen/transform/pipeline.h b/include/triton/codegen/transform/pipeline.h index 4d0650529..35472de04 100644 --- a/include/triton/codegen/transform/pipeline.h +++ b/include/triton/codegen/transform/pipeline.h @@ -14,11 +14,13 @@ namespace transform { class pipeline { public: - pipeline(bool has_copy_async): has_copy_async_(has_copy_async) {} + pipeline(bool has_copy_async, int num_stages) + : has_copy_async_(has_copy_async), num_stages_(num_stages) {} void run(ir::module &module); private: bool has_copy_async_; + int num_stages_; }; } // namespace transform diff --git a/include/triton/codegen/transform/prefetch.h b/include/triton/codegen/transform/prefetch.h index 01fad8875..6843b5463 100644 --- a/include/triton/codegen/transform/prefetch.h +++ b/include/triton/codegen/transform/prefetch.h @@ -1,9 +1,12 @@ #ifndef TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H #define TRITON_INCLUDE_TRITON_CODEGEN_TRANSFORM_PREFETCH_H +#include + // forward dclaration namespace triton::ir{ class module; +class value; } namespace triton::codegen { @@ -13,9 +16,11 @@ class target; namespace triton::codegen::transform { class prefetch { target* tgt_; + std::set prefetched_vals_; public: prefetch(target *tgt) : tgt_(tgt) {} void run(ir::module &module); + bool is_prefetched(ir::value* v) { return prefetched_vals_.find(v) != prefetched_vals_.end(); } }; } diff --git a/include/triton/ir/instructions.h 
b/include/triton/ir/instructions.h index 2e3b1e9ed..cf1f295dc 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -832,6 +832,7 @@ public: static async_wait_inst* create(context &ctx, int N, const std::string &name = "", instruction *next = nullptr); int get_N() { return N_; } + void set_N(int n) { N_ = n; } private: int N_; diff --git a/include/triton/ir/print.h b/include/triton/ir/print.h index 471948d5f..6dbf2fe02 100644 --- a/include/triton/ir/print.h +++ b/include/triton/ir/print.h @@ -1,5 +1,3 @@ -#pragma once - #ifndef _TRITON_IR_PRINT_H_ #define _TRITON_IR_PRINT_H_ @@ -9,8 +7,14 @@ namespace triton{ namespace ir{ class module; +class function; +class basic_block; +class instruction; void print(module &mod, std::ostream& os); +void print(function &func, std::ostream& os); +void print(basic_block &bb, std::ostream& os); +void print(instruction &instr, std::ostream& os); } } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index dc280aba5..2b3c3d1ac 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -7,6 +7,7 @@ #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/utils.h" +// #include "triton/ir/type.h" namespace triton{ namespace codegen{ @@ -273,6 +274,81 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr(v) || + dynamic_cast(v)) + return true; + else + return false; +} + +/// param: +/// value_1: next_value +static bool is_multistage_pipe_phi(ir::phi_node* phi, ir::basic_block* bb0, ir::basic_block* bb1, + std::vector& values_0, ir::value*& value_1) { + ir::value* next = phi; + while (auto cphi = dynamic_cast(next)) { + // smem from previous bb & phi/smem from current bb + ir::value* c0 = cphi->get_incoming_value(0); + ir::value* c1 = cphi->get_incoming_value(1); + ir::basic_block *cbb0 = cphi->get_incoming_block(0); + ir::basic_block *cbb1 = cphi->get_incoming_block(1); + + if (is_smem(c0)) { + assert(cbb0 == bb0); + values_0.push_back(c0); + if (auto phi1 = dynamic_cast(c1)) { + next = phi1; + continue; + } else { + if (is_smem(c1)) { + value_1 = c1; + assert(cbb1 == bb1); + return true; + } else { + return false; + } + } + } else + return false; + } +} + +void shared_layout::extract_N_bufferable(ir::value *v, std::shared_ptr &res, int &prev_stages) { + auto* phi = dynamic_cast(v); + // if the phi node is nested + if (!phi) + return; + + ir::basic_block *bb0 = phi->get_incoming_block(0); + ir::basic_block *bb1 = phi->get_incoming_block(1); + + std::vector values_0; + ir::value* value_1; + + if (!is_multistage_pipe_phi(phi, bb0, bb1, values_0, value_1)) + return; + + // double-buffer is a special case + if (values_0.size() == 1) + return; + + // compute original values_0 input order + std::map order; + int idx = 0; + for (ir::instruction* instr : *bb0) { + if (std::find(values_0.begin(), values_0.end(), instr) != values_0.end()) + order[static_cast(instr)] = idx++; + } + assert(order.size() == values_0.size() && "order size incorrect"); + + int curr_stages = values_0.size() + 1; + if (curr_stages > prev_stages) { + res.reset(new N_buffer_info_t{values_0, value_1, phi, order}); + prev_stages = curr_stages; + } +} + shared_layout::shared_layout(data_layout *arg, const std::vector& axes, @@ -284,9 +360,15 @@ shared_layout::shared_layout(data_layout *arg, size_ = 0; arg_layout_ = arg; + // N-stage buffering + int prev_stages = 0; + for (ir::value *v : values) + extract_N_bufferable(v, N_buffer_, prev_stages); + // double-buffering - 
for(ir::value *v: values) - extract_double_bufferable(v, double_buffer_); + if (!N_buffer_) + for(ir::value *v: values) + extract_double_bufferable(v, double_buffer_); // order std::vector arg_order = arg ? arg->get_order() : std::vector{0}; @@ -311,8 +393,22 @@ shared_layout::shared_layout(data_layout *arg, size_ *= s; if(double_buffer_) size_ *= 2; + if (N_buffer_) { + size_ *= (N_buffer_->firsts.size() + 1); + } } +int shared_layout::get_num_stages() const { + if (double_buffer_) + return 2; + if (N_buffer_) + return N_buffer_->firsts.size() + 1; + return 1; +} + +size_t shared_layout::get_per_stage_elements() const { + return get_per_stage_size()/(ty_->get_primitive_size_in_bits()/8); +} /* -------------------------------- * * ---- Layouts Inference Pass ---- * @@ -403,7 +499,6 @@ void layouts::run(ir::module &mod) { for(const auto& x: values_) create(x.first, x.second); - // create temporaries size_t id = values_.size(); ir::for_each_instruction(mod, [this, &id](ir::instruction* i) { diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index d0d2f34ec..d3b78f28f 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -26,7 +26,7 @@ namespace codegen { // TODO: // There should be a proper pass manager there! -void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, +void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, int num_stages, driver::module *&mod, driver::kernel *&ker, size_t &shared_mem) { // generate llvm code llvm::LLVMContext ctx; @@ -39,26 +39,27 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, codegen::analysis::align align; codegen::analysis::axes axes; codegen::transform::cts cts(cts_use_async); - codegen::transform::pipeline pipeline(cts_use_async); + codegen::transform::pipeline pipeline(cts_use_async, num_stages); codegen::transform::disassociate disassociate; codegen::analysis::layouts layouts(&axes, &align, num_warps, target.get()); codegen::analysis::liveness liveness(&layouts); codegen::analysis::swizzle swizzle(&layouts, target.get()); codegen::analysis::allocation allocation(&liveness); - codegen::transform::membar barriers(&liveness, &layouts, &allocation, target.get()); codegen::transform::dce dce; codegen::transform::peephole peephole(target.get(), &layouts); // codegen::transform::reassociate reassociate; codegen::transform::coalesce coalesce(&align, &layouts); codegen::transform::prefetch prefetch_s(target.get()); + codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target.get()); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps); // run passes dce.run(ir); peephole.run(ir); dce.run(ir); + // ir::print(ir, std::cout); pipeline.run(ir); dce.run(ir); - //ir::print(ir, std::cout); + // ir::print(ir, std::cout); disassociate.run(ir); dce.run(ir); align.run(ir); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 0791ee6fc..e93fe7895 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -212,18 +212,41 @@ void generator::visit_value(ir::value* v) { return; if(v->get_type()->is_block_ty()){ if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){ - auto double_buffer = layout->get_double_buffer(); + analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer(); + analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer(); + // offset Value *offset = nullptr; - if(double_buffer && v == 
double_buffer->phi) - offset = shared_off_[layout]; // base pointer Value *ptr = shared_ptr_[layout]; - if(double_buffer && v == double_buffer->latch) - ptr = shared_next_ptr_[layout]; - else if(double_buffer && v == double_buffer->first) - ptr = shared_pre_ptr_[layout]; + + if (n_buffer) { + // ptr = base (shared_ptr_[layout]) + smem_idx * size + // read_smem_idx + if (v == n_buffer->phi) { + ptr = shared_ptr_[layout]; + } + // write_smem_idx + if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) { + int write_smem_idx = /*stage_idx*/n_buffer->firsts_idx.at(v); + int elements = write_smem_idx * layout->get_per_stage_elements(); + ptr = gep(shared_pre_ptr_[layout], i32(elements)); + } else if (v == n_buffer->latch) { + Value* write_smem_idx = write_smem_idx_[layout]; + Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements())); + ptr = gep(shared_pre_ptr_[layout], elements); + } + } else if (double_buffer) { + if(v == double_buffer->phi) + offset = shared_off_[layout]; + if(v == double_buffer->latch) + ptr = shared_next_ptr_[layout]; + else if(v == double_buffer->first) + ptr = shared_pre_ptr_[layout]; + } // else do nothing + // what visit_dot & vist_cts & ... see shmems_[v] = ptr; + // now only latches have offset (PHINode), only used by finalize_share_layout() shoffs_[v] = offset; } } @@ -1223,24 +1246,21 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va * \brief Code Generation for `mma.16816` (A100) */ //TODO: clean-up -void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { - const auto& shapes = dot->get_type()->get_block_shapes(); - +void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { + const std::vector& shapes = C->get_type()->get_block_shapes(); std::map, std::vector> fcs; - - for(indices_t idx: idxs_.at(dot)){ + for(indices_t idx: idxs_.at(C)){ std::vector key(idx.size() - 2); std::copy(idx.begin() + 2, idx.end(), key.begin()); fcs[key].push_back(vals_[D][idx]); }; - auto shape_a = A->get_type()->get_block_shapes(); auto shape_b = B->get_type()->get_block_shapes(); auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); - analysis::mma_layout* layout = layouts_->get(dot)->to_mma(); - analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(dot->get_operand(0)); - analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(dot->get_operand(1)); + analysis::mma_layout* layout = layouts_->get(C)->to_mma(); + analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); + analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); bool is_a_row = ord_a[0] == 1; bool is_b_row = ord_b[0] == 1; std::string a_trans = is_a_row ? 
"" : ".trans"; @@ -1264,8 +1284,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir int vec_a = 8; int vec_b = 8; - - Type *fp32_ty = f32_ty; Type *fp16x2_ty = vec_ty(f16_ty, 2); Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); @@ -1276,7 +1294,6 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir std::map, std::pair> ha; std::map, Value*> hb; - BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); builder_->SetInsertPoint(FirstBB->getTerminator()); @@ -1339,66 +1356,167 @@ void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir "{$0, $1, $2, $3}, " "{$4, $5, $6, $7}, " "{$8, $9}, " - "{$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", false); + "{$10, $11, $12, $13};", + "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true); + unsigned num_rep_0 = shapes[0] / layout->spt(0); unsigned num_rep_1 = shapes[1] / layout->spt(1); - for(unsigned K = 0; K < NK; K += 16) - for(unsigned m = 0; m < num_rep_0; m++) - for(unsigned n = 0; n < num_rep_1; n++){ - if(ha.find({m, K}) == ha.end()){ - Value* ptra = ptrs_a[(is_a_row ? K/16 : m) % num_ptr_a]; + + // create mma & unpack result + auto call_mma = [&](unsigned m, unsigned n, unsigned K) { + unsigned cols_per_thread = num_rep_0 * 2; + std::vector idx = { + (m*2 + 0) + (n*2 + 0)*cols_per_thread, + (m*2 + 0) + (n*2 + 1)*cols_per_thread, + (m*2 + 1) + (n*2 + 0)*cols_per_thread, + (m*2 + 1) + (n*2 + 1)*cols_per_thread + }; + Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second, + hb[{n, K}], hb[{n, K+8}], + fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); + fc[idx[0]] = extract_val(nc, std::vector{0}); + fc[idx[1]] = extract_val(nc, std::vector{1}); + fc[idx[2]] = extract_val(nc, std::vector{2}); + fc[idx[3]] = extract_val(nc, std::vector{3}); + }; + + ir::phi_node* phiA = dynamic_cast(A); + ir::phi_node* phiB = dynamic_cast(B); + + auto register_lds = + [&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) { + if (K <= 8 && is_prefetch) { + ir::basic_block* inc_block = phiA->get_incoming_block(inc); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block)); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block)); + } else + vals[{m, K}] = {val0, val1}; + }; + + auto register_lds2 = + [&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) { + if (K <= 8 && is_prefetch) { + ir::basic_block* inc_block = phiA->get_incoming_block(inc); + lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block)); + } else + vals[{m, K}] = val; + }; + + auto load_a = [&](int m, int K, int inc, bool is_prefetch) { + int offidx = (is_a_row ? K/16 : m) % num_ptr_a; + Value* ptra; + if(K == 0 && is_prefetch){ + if(inc == 0) + ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); + else + ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); + } + else + ptra = ptrs_a[offidx]; int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); int step_ak = is_a_row ? 
K / (num_ptr_a*16)*(num_ptr_a*16) : K; InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 " - "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", "=r,=r,=r,=r,r", false); + "{$0, $1, $2, $3}, [$4 + " + + std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", + "=r,=r,=r,=r,r", true); Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra}); + if(K == 0 && inc == 1 && is_prefetch) + prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa); Value *ha0 = extract_val(haa, std::vector{0}); Value *ha1 = extract_val(haa, std::vector{1}); Value *ha2 = extract_val(haa, std::vector{2}); Value *ha3 = extract_val(haa, std::vector{3}); - ha[{m, K}] = std::make_pair(ha0, ha1); - ha[{m, K+8}] = std::make_pair(ha2, ha3); - } - if(hb.find({n, K})==hb.end()){ - Value* ptrb = ptrs_b[(is_b_row ? n : K/16) % num_ptr_b]; + register_lds(ha, m, K, inc, ha0, ha1, is_prefetch); + register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch); + }; + + auto load_b = [&](int n, int K, int inc, bool is_prefetch) { + int offidx = (is_b_row ? n : K/16) % num_ptr_b; + Value* ptrb; + if(K == 0 && is_prefetch){ + if(inc == 0) + ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); + else + ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); + } + else + ptrb = ptrs_b[offidx]; int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8); InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 " - "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", "=r,=r,=r,=r,r", false); + "{$0, $1, $2, $3}, [$4 + " + + std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", + "=r,=r,=r,=r,r", true); Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb}); + if(K == 0 && inc == 1 && is_prefetch) + prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb); Value *hb0 = extract_val(hbb, std::vector{0}); Value *hb1 = extract_val(hbb, std::vector{1}); Value *hb2 = extract_val(hbb, std::vector{2}); Value *hb3 = extract_val(hbb, std::vector{3}); - hb[{n, K}] = hb0; - hb[{n+1, K}] = hb2; - hb[{n, K+8}] = hb1; - hb[{n+1, K+8}] = hb3; - } - unsigned cols_per_thread = num_rep_0 * 2; - std::vector idx = { - (m*2 + 0) + (n*2 + 0)*cols_per_thread, - (m*2 + 0) + (n*2 + 1)*cols_per_thread, - (m*2 + 1) + (n*2 + 0)*cols_per_thread, - (m*2 + 1) + (n*2 + 1)*cols_per_thread - }; - Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second, - hb[{n, K}], hb[{n, K+8}], - fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); - fc[idx[0]] = extract_val(nc, std::vector{0}); - fc[idx[1]] = extract_val(nc, std::vector{1}); - fc[idx[2]] = extract_val(nc, std::vector{2}); - fc[idx[3]] = extract_val(nc, std::vector{3}); - } + register_lds2(hb, n, K, inc, hb0, is_prefetch); + register_lds2(hb, n+1, K, inc, hb2, is_prefetch); + register_lds2(hb, n, K+8, inc, hb1, is_prefetch); + register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch); + }; + if (C->is_prefetched()) { + // create phis + builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); + for(unsigned m = 0; m < num_rep_0; m++){ + ha[{m, 0}].first = phi(fp16x2_ty, 2); + ha[{m, 0}].second = phi(fp16x2_ty, 2); + ha[{m, 8}].first = phi(fp16x2_ty, 2); + ha[{m, 8}].second = phi(fp16x2_ty, 2); + } + for(unsigned n = 0; n < num_rep_1; n+=2){ + hb[{n, 0}] = 
phi(fp16x2_ty, 2); + hb[{n+1, 0}] = phi(fp16x2_ty, 2); + hb[{n, 8}] = phi(fp16x2_ty, 2); + hb[{n+1, 8}] = phi(fp16x2_ty, 2); + } + // insert prefetched lds at the end of loop header + builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); + for(unsigned m = 0; m < num_rep_0; m++) + load_a(m, 0, 0, true); + for(unsigned n = 0; n < num_rep_1; n+=2) + load_b(n, 0, 0, true); + // update accumulators + builder_->SetInsertPoint(CurrBB); + for(unsigned K = 0; K < NK; K += 16){ + int NEXTK = (K + 16) % NK; + // prefetch A + for(unsigned m = 0; m < num_rep_0; m++) + load_a(m, NEXTK, 1, true); + // prefetch B + for(unsigned n = 0; n < num_rep_1; n+=2) + load_b(n, NEXTK, 1, true); + // tensor core ops + for(unsigned m = 0; m < num_rep_0; m++) + for(unsigned n = 0; n < num_rep_1; n++){ + call_mma(m, n, K); + } + } + } + else{ + for(unsigned K = 0; K < NK; K += 16) + for(unsigned m = 0; m < num_rep_0; m++) + for(unsigned n = 0; n < num_rep_1; n++){ + if(ha.find({m, K}) == ha.end()) + load_a(m, K, 0, false); + if(hb.find({n, K})==hb.end()) + load_b(n, K, 0, false); + call_mma(m, n, K); + } + } // write back unsigned i = 0; - for(indices_t idx: idxs_.at(dot)){ + for(indices_t idx: idxs_.at(C)){ std::vector key(idx.size() - 2); std::copy(idx.begin() + 2, idx.end(), key.begin()); if(i >= fcs.at(key).size()) i = 0; - vals_[dot][idx] = fcs.at(key)[i++]; + vals_[C][idx] = fcs.at(key)[i++]; }; } @@ -2252,8 +2370,35 @@ void generator::visit_layout_scanline(analysis::scanline_layout* layout) { void generator::visit_layout_shared(analysis::shared_layout* layout) { Type* ty = cvt(layout->get_type()); PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace()); - // double-buffered - if(layout->get_double_buffer()) { + if (layout->get_N_buffer()) { + // create pointers + shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); + shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty); + + BasicBlock *current = builder_->GetInsertBlock(); + + auto info = *layout->get_N_buffer(); + ir::phi_node *phi = info.phi; + BasicBlock *parent = bbs_.at(phi->get_parent()); + if(parent->empty()) + builder_->SetInsertPoint(parent); + else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) { + builder_->SetInsertPoint(&*parent->getFirstNonPHI()); + } else + builder_->SetInsertPoint(parent); + + // create smem_idx + read_smem_idx_[layout] = phi(i32_ty, 2); + write_smem_idx_[layout] = phi(i32_ty, 2); + + // create pointers + // ptr of the current iteration + shared_ptr_[layout] = phi(ptr_ty, 2); + // ptr of the next iteration + shared_next_ptr_[layout] = phi(ptr_ty, 2); + + builder_->SetInsertPoint(current); + } else if(layout->get_double_buffer()) { BasicBlock *current = builder_->GetInsertBlock(); auto info = *layout->get_double_buffer(); ir::phi_node *phi = info.phi; @@ -2269,8 +2414,7 @@ void generator::visit_layout_shared(analysis::shared_layout* layout) { shared_off_[layout] = phi(i32_ty, 2); shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr"); builder_->SetInsertPoint(current); - } - else{ + } else{ size_t offset = alloc_->offset(layout); shared_ptr_[layout] = gep(shmem_, i32(offset)); shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty); @@ -2354,7 +2498,67 @@ void generator::init_idx(ir::value *v) { } void generator::finalize_shared_layout(analysis::shared_layout *shared) { - if(shared->get_double_buffer()) { + if (auto n_buffer = shared->get_N_buffer()) { + // if (*_smem_idx == #stages-1) { + // 
*_smem_idx = 0; + // } else *_smem_idx++; + auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) { + // insert point + Value *idx = smem_idx[shared]; + builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator()); + Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1)); + PHINode *_ret = phi(i32_ty, 2); + Instruction *then_term = nullptr; + Instruction *else_term = nullptr; + Instruction *dummy = builder_->CreateRet(nullptr); + llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr); + dummy->removeFromParent(); + builder_->SetInsertPoint(then_term); + Value *zero_smem_idx = i32(0); + builder_->SetInsertPoint(else_term); + Value *inc_smem_idx = add(idx, i32(1)); + builder_->SetInsertPoint(_ret->getParent()); + _ret->addIncoming(zero_smem_idx, then_term->getParent()); + _ret->addIncoming(inc_smem_idx, else_term->getParent()); + // update ir::bb -> llvm::bb mapping + bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock(); + // idx = init_stage; + // loop: ... + if (auto idx_phi = llvm::dyn_cast(smem_idx[shared])) { + idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0))); + idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1))); + } else + throw std::runtime_error("Should be PHINode"); + }; + + // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2 + finalize_smem_idx(read_smem_idx_, 2); + finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1); + + // finalize pointers + ir::phi_node *pn = n_buffer->phi; + BasicBlock *header = bbs_.at(pn->get_incoming_block(0)); + BasicBlock *loop = bbs_.at(pn->get_incoming_block(1)); + // %curr_ptr = phi %shared_pre_ptr, %next_ptr + // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size)) + if (auto curr_ptr = dyn_cast(shared_ptr_[shared])) { + curr_ptr->addIncoming(shared_pre_ptr_[shared], header); + curr_ptr->addIncoming(shared_next_ptr_[shared], loop); + } else + throw std::runtime_error("Should be PHINode"); + + BasicBlock *current = builder_->GetInsertBlock(); + builder_->SetInsertPoint(header->getTerminator()); + Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements())); + builder_->SetInsertPoint(current->getTerminator()); + + assert(isa(shared_next_ptr_[shared])); + static_cast(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header); + + Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements())); + Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset); + static_cast(shared_next_ptr_[shared])->addIncoming(next_ptr, loop); + } else if(shared->get_double_buffer()) { auto info = *shared->get_double_buffer(); ir::phi_node *phi = info.phi; PHINode *ptr = (PHINode*)shmems_[phi]; diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index de2552c32..96249bcd5 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -4,6 +4,7 @@ #include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/transform/membar.h" +#include "triton/codegen/transform/prefetch.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -20,9 +21,14 @@ namespace transform{ int membar::group_of(ir::value* v, std::vector &async_write) { if(ir::phi_node* phi = dynamic_cast(v)){ analysis::shared_layout* layout = layouts_->get(v)->to_shared(); - 
analysis::double_buffer_info_t* info = layout->get_double_buffer(); - if(info) + if (analysis::double_buffer_info_t* info = layout->get_double_buffer()) return group_of(info->first, async_write); + else if (analysis::N_buffer_info_t* info = layout->get_N_buffer()) { + if (v == info->phi) + return group_of(info->firsts[0], async_write); + else // prefetched value + return group_of(info->firsts[1], async_write); + } std::vector groups(phi->get_num_operands()); std::transform(phi->op_begin(), phi->op_end(), groups.begin(), [&](ir::value* v){ return group_of(v, async_write);}); return *std::max_element(groups.begin(), groups.end()); @@ -69,12 +75,31 @@ membar::val_set_t membar::intersect_with(const val_set_t& as, const val_set_t& b return ret; } +bool membar::check_safe_war(ir::instruction* i) { + bool is_i_shared_block = i->get_type()->is_block_ty() && + layouts_->get(i)->to_shared(); + bool is_i_double_buffered = is_i_shared_block && + layouts_->get(i)->to_shared()->get_double_buffer(); + bool is_i_n_buffered = is_i_shared_block && + layouts_->get(i)->to_shared()->get_N_buffer(); + + if (is_i_double_buffered || is_i_n_buffered) { + // with async copy & prefetch_s disabled, WARs are not safe + if (dynamic_cast(i) && !prefetch_->is_prefetched(i)) + return false; + else + return true; + } + return false; +} + void membar::transfer(ir::basic_block *block, val_vec_t& async_write, val_set_t& sync_write, val_set_t& sync_read, std::set& safe_war, bool& inserted, ir::builder& builder) { + std::vector async_waits; ir::basic_block::inst_list_t instructions = block->get_inst_list(); for(ir::instruction *i: instructions){ if(dynamic_cast(i)) @@ -105,18 +130,14 @@ void membar::transfer(ir::basic_block *block, async_wait = (ir::async_wait_inst*)builder.create_async_wait(async_write.size() - 1 - N); barrier = (ir::barrier_inst*)builder.create_barrier(); inserted = true; + async_waits.push_back(async_wait); } } // RAW, WAR - bool is_i_double_buffered = i->get_type()->is_block_ty() && - layouts_->get(i)->to_shared() && - layouts_->get(i)->to_shared()->get_double_buffer(); + bool is_safe_war = check_safe_war(i); // WAR barrier is not required when data is double-buffered - // TODO: how about other patterns, like WWAR? 
- if(!intersect_with(read, sync_write).empty() || - (!intersect_with({i}, sync_read).empty() && !is_i_double_buffered) || - // force WAR barrier on A100 - (!intersect_with({i}, sync_read).empty() && tgt_->as_nvidia()->sm() >= 80)){ + if(!intersect_with(read, sync_write).empty() || + (!intersect_with({i}, sync_read).empty() && !is_safe_war)) { builder.set_insert_point(i); barrier = (ir::barrier_inst*)builder.create_barrier(); inserted = true; @@ -132,7 +153,41 @@ void membar::transfer(ir::basic_block *block, sync_read.clear(); } sync_read.insert(read.begin(), read.end()); + } + // coalesce barriers + // fixme: to support more general cases + if (async_waits.size() == 2) { + // (aw N; bar; prefetch; aw N-1; bar; prefetch; => aw N-1; bar; 2*prefetch;) + for (int idx=0; idx to_erase; + ir::basic_block::inst_list_t instructions = block->get_inst_list(); + for(auto iter = instructions.begin(); iter != instructions.end(); ++iter){ + ir::instruction *i = *iter; + if (static_cast(first_async_wait) == i) { + // peak next 5 instructions + auto peak_iter = std::next(iter); + if (std::distance(peak_iter, instructions.end()) >= 5) { + auto first_bar = dynamic_cast(*peak_iter++); + auto first_pf = dynamic_cast(*peak_iter++); + auto second_async_wait = dynamic_cast(*peak_iter++); + auto second_bar = dynamic_cast(*peak_iter++); + auto second_pf = dynamic_cast(*peak_iter); + if (first_bar && first_pf && second_async_wait && second_bar && second_pf) { + int first_n = first_async_wait->get_N(); + int second_n = second_async_wait->get_N(); + to_erase.push_back(second_async_wait); + to_erase.push_back(second_bar); + first_async_wait->set_N(second_n); + } + } else + break; + for (ir::instruction *i : to_erase) + block->erase(i); + } + } + } } } @@ -144,7 +199,7 @@ void membar::run(ir::module &mod) { std::set safe_war; for(const auto& x: layouts_->get_all()){ analysis::shared_layout* layout = x.second->to_shared(); - if(!layout || !layout->get_double_buffer()) + if(!layout || !layout->get_double_buffer() || !layout->get_N_buffer()) continue; for(ir::value *v: layout->get_values()) if(v != layout->get_double_buffer()->phi){ @@ -153,7 +208,6 @@ void membar::run(ir::module &mod) { } for(ir::function *fn: mod.get_function_list()){ - // TODO: (dyan) we need DominatorTree here. 
std::vector rpo = ir::cfg::reverse_post_order(fn); std::map async_writes; std::map sync_writes; diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc index 00520e9d6..39b058ae7 100644 --- a/lib/codegen/transform/pipeline.cc +++ b/lib/codegen/transform/pipeline.cc @@ -23,6 +23,60 @@ void recursive_deps(ir::value* v, ir::basic_block* block, std::vector& prev_phi_vals) { + ir::instruction* i = dynamic_cast(v); + if(!i) + return v; + if(ir::phi_node* phi = dynamic_cast(v)) { + if (prev_phi_vals.find(phi) == prev_phi_vals.end()) + throw std::runtime_error("Don't have that phi node\n"); + return prev_phi_vals.at(phi); + } + + std::vector new_ops; + for(ir::value* op: i->ops()){ + new_ops.push_back(rematerialize_vals(builder, op, prev_phi_vals)); + } + ir::instruction* ret = i->clone(); + for(size_t k = 0; k < new_ops.size(); k++) + ret->set_operand(k, new_ops[k]); + builder.insert(ret); + return ret; +} + +void get_induction_vars(ir::value* cond, std::set& phis) { + auto instr = dynamic_cast(cond); + for (auto op : instr->ops()) { + if (auto phi_op = dynamic_cast(op)) { + phis.insert(phi_op); + return; + } + if (dynamic_cast(op)) + get_induction_vars(op, phis); + } +} + +/// Returns phi_val if sees a phi node +ir::value* rematerialize_val(ir::builder& builder, ir::value* v, ir::value* phi_val) { + ir::instruction* i = dynamic_cast(v); + if(!i) + return v; + if(ir::phi_node* phi = dynamic_cast(v)) + return phi_val; + + std::vector new_ops; + for(ir::value* op: i->ops()){ + new_ops.push_back(rematerialize_val(builder, op, phi_val)); + } + ir::instruction* ret = i->clone(); + for(size_t k = 0; k < new_ops.size(); k++) + ret->set_operand(k, new_ops[k]); + builder.insert(ret); + return ret; +} + ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){ ir::instruction* i = dynamic_cast(v); if(!i) @@ -41,6 +95,28 @@ ir::value* rematerialize(ir::builder& builder, ir::value* v, size_t phi_idx){ return ret; } +/// moving the prev phi vals to the next iteration +void update_prev_phi_vals(ir::builder& builder, std::map& prev_phi_vals) { + for (auto& [phi, val] : prev_phi_vals) { + // TODO: handling nested phis + val = rematerialize_val(builder, phi->get_incoming_value(1), val); + } +} + +void finalize_iv_vals(ir::builder& builder, std::map& load_ivs, + std::map& next_load_ivs) { + for (auto& [phi, val] : load_ivs) { + if (auto new_phi = dynamic_cast(val)) { + ir::value* next_k = rematerialize_vals(builder, phi->get_incoming_value(1), load_ivs); + assert(new_phi->get_num_operands() == 1 && "should be incomplete phi"); + new_phi->add_incoming(next_k, phi->get_incoming_block(1)); + // cache next_k (to be used by next_mask) + next_load_ivs[phi] = next_k; + } else + throw std::runtime_error("must be phi"); + } +} + void pipeline::run(ir::module &mod) { // *Very* conservative heuristics for pre-fetching. 
// A load instruction can be pipelined if: @@ -60,6 +136,8 @@ void pipeline::run(ir::module &mod) { // do the pipelining std::vector new_loads; ir::builder &builder = mod.get_builder(); + const int num_stages = num_stages_; + std::vector>> preheader_loads; // Used to reorder loads for(auto info: to_pipeline){ ir::load_inst* load = info.first; ir::phi_node* ptr = info.second; @@ -70,40 +148,155 @@ void pipeline::run(ir::module &mod) { assert(block_br); assert(header_br); ir::type* ty = load->get_type(); - // pre-fetch first iteration - builder.set_insert_point(header->get_inst_list().back()); - ir::value* first_ptr = ptr->get_value_for_block(header); - ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes()); - ir::value* false_value; - if(auto* masked_load = dynamic_cast(load)){ - ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0); - ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0); - first_mask = builder.create_and(first_mask, remat_mask); - false_value = remat_false_value; + // multi-stage pipe + if (has_copy_async_ && num_stages > 2) { + ir::value* header_cond = header_br->get_cond(); + ir::value* block_cond = block_br->get_cond(); + // 1. collect induction variables + std::set induction_vars; + get_induction_vars(block_cond, induction_vars); + + std::vector first_ptrs(num_stages-1); + std::vector first_loads(num_stages-1); + std::vector first_masks(num_stages-1); + std::vector loop_conds(num_stages-1); + + std::map prev_phi_vals; + // initialize prev_phi_vals + // note: we assume that ptr & other values only depend on ptr & iv (phis) + // TODO: can we just add all phis here? + prev_phi_vals[ptr] = ptr->get_value_for_block(header); + for (ir::phi_node* iv : induction_vars) + prev_phi_vals[iv] = iv->get_value_for_block(header); + prev_phi_vals[ptr] = ptr->get_value_for_block(header); + + builder.set_insert_point(header->get_inst_list().back()); + first_ptrs[0] = ptr->get_value_for_block(header); + loop_conds[0] = header_cond; + first_masks[0] = builder.create_splat(loop_conds[0], ty->get_block_shapes()); + ir::value* false_value = nullptr; + if (auto* masked_load = dynamic_cast(load)) { + ir::value* remat_mask =rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals) ; + ir::value* remat_false_value = + rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals); + first_masks[0] = builder.create_and(first_masks[0], remat_mask); + false_value = remat_false_value; + } else + false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); + first_loads[0] = builder.create_masked_load(first_ptrs[0], first_masks[0], false_value); + + for (int stage = 1; stage < num_stages-1; ++stage) { + // mask is the loop condition of the previous iteration + loop_conds[stage] = rematerialize_vals(builder, block_cond, prev_phi_vals); + update_prev_phi_vals(builder, prev_phi_vals); + first_ptrs[stage] = rematerialize_vals(builder, ptr, prev_phi_vals); + first_masks[stage] = builder.create_splat(loop_conds[stage], ty->get_block_shapes()); + if (auto* masked_load = dynamic_cast(load)) { + ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), prev_phi_vals); + ir::value* remat_false_value = + rematerialize_vals(builder, masked_load->get_false_value_operand(), prev_phi_vals); + first_masks[stage] = builder.create_and(first_masks[stage], remat_mask); + false_value = remat_false_value; + } 
+ first_loads[stage] = builder.create_masked_load(first_ptrs[stage], first_masks[stage], false_value); + } + + // create new phis for induction variables + builder.set_insert_point(block->get_first_non_phi()); + std::map load_ivs; + std::map next_load_ivs; + for (ir::phi_node* iv : induction_vars) { + ir::phi_node* pn = builder.create_phi(iv->get_type(), 2); + pn->add_incoming(prev_phi_vals[iv], header); + load_ivs[iv] = pn; + } + // add incoming for phis & update next_load_ivs + finalize_iv_vals(builder, load_ivs, next_load_ivs); + + // pre-fetch next iteration + builder.set_insert_point(block->get_inst_list().back()); + ir::value* next_ptr = ptr->get_value_for_block(block); + ir::value* next_mask = builder.create_splat( + rematerialize_vals(builder, block_cond, load_ivs), ty->get_block_shapes()); + if (auto* masked_load = dynamic_cast(load)) { + ir::value* remat_mask = rematerialize_vals(builder, masked_load->get_mask_operand(), next_load_ivs); + // TODO: false may depends on some other phi nodes + ir::value* remat_false_value = + rematerialize_vals(builder, masked_load->get_false_value_operand(), next_load_ivs); + next_mask = builder.create_and(next_mask, remat_mask); + false_value = remat_false_value; + } + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value); + + + // phi node + ptr->set_incoming_value(0, first_ptrs.back()); + builder.set_insert_point(block->get_first_non_phi()); + // nested phis for load + std::vector new_load_phis(num_stages-1); + for (auto& pn : new_load_phis) + pn = builder.create_phi(ty, 2); + for (int i=0; iadd_incoming(first_loads[i], header); + new_load_phis[i]->add_incoming(new_load_phis[i+1], block); + } + new_load_phis.back()->add_incoming(first_loads.back(), header); + new_load_phis.back()->add_incoming(next_load, block); + load->replace_all_uses_with(new_load_phis.front()); + new_loads.push_back(new_load_phis.back()); + + // record first_loads to reorder them + preheader_loads.push_back({new_load_phis.front(), first_loads}); + } else { + // pre-fetch first iteration + builder.set_insert_point(header->get_inst_list().back()); + ir::value* first_ptr = ptr->get_value_for_block(header); + ir::value* first_mask = builder.create_splat(header_br->get_cond(), ty->get_block_shapes()); + ir::value* false_value; + if(auto* masked_load = dynamic_cast(load)){ + ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 0); + ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 0); + first_mask = builder.create_and(first_mask, remat_mask); + false_value = remat_false_value; + } + else + false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); + ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value); + // pre-fetch next iteration + builder.set_insert_point(block->get_inst_list().back()); + ir::value* next_ptr = ptr->get_value_for_block(block); + ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes()); + if(auto* masked_load = dynamic_cast(load)){ + ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1); + ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1); + next_mask = builder.create_and(next_mask, remat_mask); + false_value = remat_false_value; + } + ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value); + // phi node + 
builder.set_insert_point(block->get_first_non_phi()); + ir::phi_node* new_load = builder.create_phi(ty, 2); + new_load->add_incoming(first_load, header); + new_load->add_incoming(next_load, block); + load->replace_all_uses_with(new_load); + new_loads.push_back(new_load); } - else - false_value = builder.create_splat(ir::undef_value::get(ty->get_scalar_ty()), ty->get_block_shapes()); - ir::value* first_load = builder.create_masked_load(first_ptr, first_mask, false_value); - // pre-fetch next iteration - builder.set_insert_point(block->get_inst_list().back()); - ir::value* next_ptr = ptr->get_value_for_block(block); - ir::value* next_mask = builder.create_splat(block_br->get_cond(), ty->get_block_shapes()); - if(auto* masked_load = dynamic_cast(load)){ - ir::value* remat_mask = rematerialize(builder, masked_load->get_mask_operand(), 1); - ir::value* remat_false_value = rematerialize(builder, masked_load->get_false_value_operand(), 1); - next_mask = builder.create_and(next_mask, remat_mask); - false_value = remat_false_value; - } - ir::value* next_load = builder.create_masked_load(next_ptr, next_mask, false_value); - // phi node - builder.set_insert_point(block->get_first_non_phi()); - ir::phi_node* new_load = builder.create_phi(ty, 2); - new_load->add_incoming(first_load, header); - new_load->add_incoming(next_load, block); - load->replace_all_uses_with(new_load); - new_loads.push_back(new_load); } + // try to reorder prefetched value from a0, a1, a2, ..., b0, b1, b2, ... to + // a0, b0, a1, b1, ... + if (!preheader_loads.empty()) { + ir::basic_block* header = preheader_loads.begin()->first->get_incoming_block(0); + builder.set_insert_point(header->get_inst_list().back()); + for (int i=1; i(iter->second.at(i)); + ir::instruction* moved_load = original_load->clone(); + builder.insert(moved_load); + original_load->replace_all_uses_with(moved_load); + } + } + } // try to move dot_inst after loads // for better overlap of io and compute diff --git a/lib/codegen/transform/prefetch.cc b/lib/codegen/transform/prefetch.cc index 96bcf57d7..6c6d28ad7 100644 --- a/lib/codegen/transform/prefetch.cc +++ b/lib/codegen/transform/prefetch.cc @@ -25,13 +25,12 @@ static void recursive_defs(ir::value *v, ir::basic_block *bb, std::vector to_prefetch; ir::for_each_instruction(mod, [&](ir::instruction *i) { if (auto *dot = dynamic_cast(i)) { - // Now only do prefetching when dot is fp16 & volta/turing - if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID || - tgt_->as_nvidia()->sm() >= 80) + // Now only do prefetching when dot is fp16 + if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::HalfTyID) return; auto *a = dynamic_cast(dot->get_operand(0)); auto *b = dynamic_cast(dot->get_operand(1)); @@ -56,16 +55,31 @@ void prefetch::run(ir::module &mod) { // 1. in the loop header (first iteration) builder.set_insert_point(loop_header->get_inst_list().back()); - builder.create_barrier(); assert(a && b); builder.create_prefetch_s(a->get_incoming_value(0), /*inc*/ 0); builder.create_prefetch_s(b->get_incoming_value(0), /*inc*/ 0); // 2. 
at the end of the loop body (next iteration) builder.set_insert_point(loop_body->get_inst_list().back()); - builder.create_barrier(); builder.create_prefetch_s(a->get_incoming_value(1), /*inc*/ 1); builder.create_prefetch_s(b->get_incoming_value(1), /*inc*/ 1); + + prefetched_vals_.insert(a->get_incoming_value(0)); + prefetched_vals_.insert(b->get_incoming_value(0)); + // nested phis + ir::value* next_a = a->get_incoming_value(1); + while (auto* next_a_phi = dynamic_cast(next_a)) { + prefetched_vals_.insert(next_a_phi->get_incoming_value(0)); + next_a = next_a_phi->get_incoming_value(1); + } + prefetched_vals_.insert(next_a); + + ir::value* next_b = b->get_incoming_value(1); + while (auto* next_b_phi = dynamic_cast(next_b)) { + prefetched_vals_.insert(next_b_phi->get_incoming_value(0)); + next_b = next_b_phi->get_incoming_value(1); + } + prefetched_vals_.insert(next_b); } // move loads to the beginning of the loop diff --git a/lib/driver/module.cc b/lib/driver/module.cc index c5d04fa4f..abe839d7e 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -324,7 +324,7 @@ void cu_module::init_from_ptx(const std::string& ptx, driver::cu_device* device) } catch(exception::cuda::invalid_ptx const &){ //#ifdef TRITON_LOG_PTX_ERROR - std::cout << ptx << std::endl; + // std::cout << ptx << std::endl; std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; // exit(1); //#endif diff --git a/lib/ir/print.cc b/lib/ir/print.cc index 47e1d9b59..1552193fa 100644 --- a/lib/ir/print.cc +++ b/lib/ir/print.cc @@ -77,6 +77,54 @@ void print(module &mod, std::ostream& os) { } } +void print(function &fn, std::ostream &os) { + // +} + +void print(basic_block &bb, std::ostream &os) { + auto const &predecessors = bb.get_predecessors(); + os << bb.get_name() << ":"; + if(!predecessors.empty()){ + os << " "; + os << "; preds = "; + auto const &predecessors = bb.get_predecessors(); + for(ir::basic_block *pred: predecessors) + os << pred->get_name() << (pred!=predecessors.back()?", ":""); + } + os << std::endl; + for(ir::instruction *inst: bb.get_inst_list()){ + print(*inst, os); + } +} + +void print(instruction &instr, std::ostream &os) { + instruction *inst = &instr; + os << " "; + if(!inst->get_type()->is_void_ty()){ + os << instr.get_name(); + os << " = "; + } + ir::type* type = inst->get_type(); + os << inst->repr() << " " << type->repr(); + ir::instruction::ops_t ops = inst->ops(); + size_t num_ops = inst->get_num_operands(); + if(num_ops > 0) + os << " ";; + for(unsigned i = 0; i < num_ops; i++){ + if(auto *x = dynamic_cast(ops[i])) + os << x->repr(); + else + os << ops[i]->get_name(); + os << (i < num_ops - 1?", ":""); + } + os << ";"; +// os << " ("; +// for(ir::user* usr: inst->get_users()) +// os << get_name(usr, cnt++) << ", " ; +// os << " )"; + os << std::endl; +} + } } diff --git a/python/src/triton.cc b/python/src/triton.cc index 2c4e65f1a..c72f753a5 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -33,7 +33,10 @@ void init_triton_driver(py::module &&m) { CUdevice handle; drv::dispatch::cuDeviceGet(&handle, dev_id); return new drv::cu_device(handle, take_ownership); - })); + })) + .def("max_shared_memory", [](drv::cu_device *self) { + return self->max_shared_memory(); + }); // host device py::class_(m, "host_device") .def(py::init<>()); @@ -75,11 +78,11 @@ void init_triton_driver(py::module &&m) { void init_triton_codegen(py::module &&m) { m.def( - "add_passes_to_emit_bin", [](ir::module &ir, drv::device *dev, int num_warps) { + "add_passes_to_emit_bin", 
[](ir::module &ir, drv::device *dev, int num_warps, int num_stages) { drv::module *mod; drv::kernel *ker; size_t shared_mem; - triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, mod, ker, shared_mem); + triton::codegen::add_passes_to_emit_bin(ir, dev, num_warps, num_stages, mod, ker, shared_mem); std::stringstream ss; ir::print(ir, ss); return std::make_tuple(mod, ker, shared_mem, ss.str()); diff --git a/python/test/test_blocksparse.py b/python/test/test_blocksparse.py index 6e748b6c0..991a34c26 100644 --- a/python/test/test_blocksparse.py +++ b/python/test/test_blocksparse.py @@ -27,7 +27,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B) ra = triton.testing.sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a rb = triton.testing.sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b - rc = op(ra, rb) + rc = triton.testing.catch_oor(lambda : op(ra, rb), pytest) # torch result ta = triton.testing.mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a tb = triton.testing.mask_tensor(b, layout, BLOCK) if MODE == "dds" else b diff --git a/python/test/test_matmul.py b/python/test/test_matmul.py index 163269dc3..b5f60cf4a 100644 --- a/python/test/test_matmul.py +++ b/python/test/test_matmul.py @@ -5,56 +5,69 @@ import torch @pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE", + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE", itertools.chain( *[ [ # 1 warp - (16, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 16, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (32, 16, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 32, 32, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), - (64, 16, 64, 1, 1, None, None, None, AT, BT, DTYPE), - (16, 64, 64, 1, 1, None, None, None, AT, BT, DTYPE), + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE), # 2 warp - (64, 32, 64, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 64, 1, 2, None, None, None, AT, BT, DTYPE), - (64, 32, 16, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 64, 16, 1, 2, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 2, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 2, None, None, None, AT, BT, DTYPE), + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE), # 4 warp - (128, 64, 16, 1, 4, None, None, None, AT, BT, DTYPE), - (64, 128, 16, 1, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 32, 1, 4, None, None, None, AT, BT, DTYPE), - (32, 128, 32, 1, 4, None, None, None, AT, BT, DTYPE), - (128, 32, 64, 1, 4, 
None, None, None, AT, BT, DTYPE), - (32, 128, 64, 1, 4, None, None, None, AT, BT, DTYPE), + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE), # 8 warp - (128, 256, 16, 1, 8, None, None, None, AT, BT, DTYPE), - (256, 128, 16, 1, 8, None, None, None, AT, BT, DTYPE), - (256, 128, 32, 1, 8, None, None, None, AT, BT, DTYPE), - # # split-k - (64, 64, 16, 2, 4, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 4, 4, None, None, None, AT, BT, DTYPE), - (64, 64, 16, 8, 4, None, None, None, AT, BT, DTYPE), - # # variable input - (128, 128, 32, 1, 4, 1024, 1024, 1024, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 384, 128, 640, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 107, 233, 256, AT, BT, DTYPE), - (128, 128, 32, 1, 4, 107, 233, 311, AT, BT, DTYPE), + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE), + # split-k + (64, 64, 16, 2, 4, 2, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 4, 4, 2, None, None, None, AT, BT, DTYPE), + (64, 64, 16, 8, 4, 2, None, None, None, AT, BT, DTYPE), + # variable input + (128, 128, 32, 1, 4, 2, 1024, 1024, 1024, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 2, 384, 128, 640, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 2, 107, 233, 256, AT, BT, DTYPE), + (128, 128, 32, 1, 4, 2, 107, 233, 311, AT, BT, DTYPE), ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] + ], + # n-stage + *[ + [ + (16, 16, 16, 1, 1, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), + (64, 32, 64, 1, 2, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), + (128, 64, 16, 1, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), + (256, 128, 32, 1, 8, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), + (128, 128, 32, 1, 4, STAGES, 384, 128, 640, AT, BT, DTYPE), + # split-k + (64, 64, 16, 8, 4, STAGES, 1024, 1024, 1024, AT, BT, DTYPE), + (64, 64, 16, 8, 4, STAGES, 1024, 1024, 32, AT, BT, DTYPE), + ] for DTYPE in ["float16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [2, 3, 4] ] ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE): +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, DTYPE): torch.manual_seed(0) # nuke kernel decorators -- will set meta-parameters manually META = {'BLOCK_M': BLOCK_M, 'BLOCK_N': BLOCK_N, 'BLOCK_K': BLOCK_K, 'SPLIT_K': SPLIT_K, 'GROUP_M': 8} - configs = [triton.Config(meta=META, num_warps=NWARP)] + configs = [triton.Config(meta=META, num_warps=NWARP, num_stages=NSTAGE)] kernel = triton.ops._matmul.kernel decorators = kernel.kernel_decorators kernel.kernel_decorators = [] @@ -72,5 +85,5 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, M, N, K, AT, BT, DTYPE): b = b.t() if BT else b # run test th_c = torch.matmul(a, b) - tt_c = triton.ops.matmul(a, b) + tt_c = triton.testing.catch_oor(lambda : triton.ops.matmul(a, b), pytest) assert triton.testing.allclose(th_c, tt_c) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index ce4e5e12b..236723057 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -407,13 +407,14 @@ class CodeGenerator(ast.NodeVisitor): class Binary: - def __init__(self, module, kernel, num_warps, shared_mem, ir_asm): + 
def __init__(self, module, kernel, num_warps, num_stages, shared_mem, ir_asm): # cache ir asm self.ir_asm = ir_asm self.module = module self.kernel = kernel self.shared_mem = shared_mem self.num_warps = num_warps + self.num_stages = num_stages self.sass = None def asm(self, mode): @@ -447,6 +448,13 @@ class CompilationError(Exception): self.message += '\n Error: ' + str(err) super().__init__(self.message) +class OutOfResources(Exception): + def __init__(self, required, limit, name): + self.message = f'out of resource: {name}'\ + f'Required: {required}'\ + f'Hardware limit: {limit}' + super().__init__(self.message) + class Kernel: @staticmethod @@ -513,7 +521,7 @@ class Kernel: def __init__(self, fn): self.fn = fn - def _compile(self, *wargs, device, attributes, constants, num_warps, **meta): + def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages, **meta): # explicitly set device torch.cuda.set_device(device.index) # create IR module @@ -535,10 +543,12 @@ class Kernel: raise CompilationError(self.fn.src, node, e) tt_device = _triton.driver.cu_device(device.index, False) # Compile to machine code - mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps) - return Binary(mod, ker, num_warps, shared_mem, ir_asm) + mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, tt_device, num_warps, num_stages) + if shared_mem > tt_device.max_shared_memory(): + raise OutOfResources(shared_mem, tt_device.max_shared_memory(), "shared memory") + return Binary(mod, ker, num_warps, num_stages, shared_mem, ir_asm) - def __call__(self, *wargs, grid, num_warps=4, **meta): + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **meta): # device inference tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] if len(tensor_idxs) == 0: @@ -554,12 +564,12 @@ class Kernel: attr_key = frozenset(attributes.items()) meta_key = frozenset(meta.items()) const_key = frozenset(constants.items()) - key = (device.type, device.index, types_key, attr_key, num_warps, meta_key, const_key) + key = (device.type, device.index, types_key, attr_key, num_warps, num_stages, meta_key, const_key) cache = self.fn.cache if key not in cache: # compile and cache configuration if necessary cache[key] = self._compile( - *wargs, device=device, attributes=attributes, num_warps=num_warps, constants=constants, **meta + *wargs, device=device, attributes=attributes, num_warps=num_warps, num_stages=num_stages, constants=constants, **meta ) # pack arguments fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg.__class__) for i, arg in enumerate(wargs)]) @@ -585,7 +595,7 @@ class Launcher: class Autotuner: def __init__(self, kernel, arg_names, configs, key): if not configs: - self.configs = [Config(dict(), num_warps=4)] + self.configs = [Config(dict(), num_warps=4, num_stages=2)] else: self.configs = configs self.key_idx = [arg_names.index(k) for k in key] @@ -603,7 +613,7 @@ class Autotuner: ) # augment meta-parameters with tunable ones current = dict(meta, **config.meta) - kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, **current) + kernel_call = lambda: self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) return triton.testing.do_bench(kernel_call) def __call__(self, *args, **meta): @@ -616,7 +626,7 @@ class Autotuner: config = self.cache[key] else: config = self.configs[0] - return self.kernel(*args, num_warps=config.num_warps, **meta, 
**config.meta) + return self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **meta, **config.meta) class JITFunction: @@ -671,9 +681,10 @@ class JITFunction: class Config: - def __init__(self, meta, num_warps=4): + def __init__(self, meta, num_warps=4, num_stages=2): self.meta = meta self.num_warps = num_warps + self.num_stages = num_stages def autotune(configs, key): diff --git a/python/triton/testing.py b/python/triton/testing.py index e67d792c4..27f039d38 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -1,5 +1,6 @@ import torch import os +from .code_gen import OutOfResources try: import triton._C.libtriton.cutlass as _cutlass @@ -8,6 +9,15 @@ except ImportError: _cutlass = None has_cutlass = False +def catch_oor(kernel, pytest_handle=None): + try: + res = kernel() + except OutOfResources as e: + if pytest_handle: + pytest_handle.skip(str(e)) + return None + return res + def sparsify_tensor(x, mask, block): ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
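
Usage sketch (illustrative): the Python-side pieces above fit together as follows -- triton.Config now carries num_stages, Kernel/Autotuner thread it down to add_passes_to_emit_bin, and triton.testing.catch_oor converts the new OutOfResources error (raised when the requested staging needs more shared memory than the device offers) into a skip or a None result. The shapes, dtype and meta-parameters below are placeholders; only the Config/num_stages, catch_oor and OutOfResources names come from this patch, and wiring custom configs into an op follows the override pattern used in python/test/test_matmul.py.

    import torch
    import triton

    # Configs now carry num_stages: 2 keeps classic double-buffering,
    # while values above 2 request the new multistage (N-buffered)
    # pipeline on Ampere. Configs like these are what
    # triton.autotune(configs, key) consumes; the bundled matmul op
    # builds its own list internally.
    configs = [triton.Config(meta={'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32,
                                   'SPLIT_K': 1, 'GROUP_M': 8},
                             num_warps=4, num_stages=s)
               for s in (2, 3, 4)]

    a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)

    # catch_oor runs the kernel and, when it raises OutOfResources,
    # returns None (or skips the test if a pytest handle is passed)
    # instead of hard-failing.
    c = triton.testing.catch_oor(lambda: triton.ops.matmul(a, b))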