From 840140bf266e3a7bbf5b605cd76f0b6fb08a3bf7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 7 May 2021 17:54:37 -0400 Subject: [PATCH] [CODEGEN] Removed dedicated reassociate pass to merge it into LLVM isel (#101) This massively simplifies implementation of `reassociate` and also fixes a bunch of bugs. The pass could still be improved, but can already be used to generate constant pointer offsets in e.g. the matmul epilogue --- include/triton/codegen/selection/generator.h | 36 ++- .../triton/codegen/transform/reassociate.h | 49 ---- include/triton/ir/instructions.h | 46 +-- include/triton/ir/visitor.h | 4 +- lib/codegen/analysis/align.cc | 8 - lib/codegen/pass.cc | 7 +- lib/codegen/selection/generator.cc | 134 ++++++--- lib/codegen/transform/reassociate.cc | 267 ------------------ lib/driver/kernel.cc | 29 +- lib/ir/instructions.cc | 34 +-- python/tutorials/03-matrix-multiplication.py | 24 +- tutorials/01-matmul.cc | 233 --------------- 12 files changed, 204 insertions(+), 667 deletions(-) delete mode 100644 include/triton/codegen/transform/reassociate.h delete mode 100644 lib/codegen/transform/reassociate.cc delete mode 100644 tutorials/01-matmul.cc diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 7c3b51ae5..3d33bf3e2 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -74,6 +74,33 @@ struct distributed_axis { Value* thread_id; }; +class adder{ +public: + adder(Builder** builder): builder_(builder) { } + Value* operator()(Value* x, Value* y, const std::string& name = ""); + +private: + Builder** builder_; +}; + +class multiplier{ +public: + multiplier(Builder** builder): builder_(builder) { } + Value* operator()(Value* x, Value* y, const std::string& name = ""); +private: + Builder** builder_; +}; + +class geper{ +public: + geper(Builder** builder): builder_(builder) { } + Value* operator()(Value *ptr, Value* off, const std::string& 
name = ""); + Value* operator()(Type* ty, Value*ptr, std::vector vals, const std::string& name = ""); + +private: + Builder** builder_; +}; + class generator: public ir::visitor, public analysis::layout_visitor { private: void init_idx(ir::value *x); @@ -143,9 +170,9 @@ public: void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); void visit_barrier_inst(ir::barrier_inst*); void visit_async_wait_inst(ir::async_wait_inst*); - void visit_make_range_dyn(ir::make_range_dyn*); +// void visit_make_range_dyn(ir::make_range_dyn*); void visit_make_range(ir::make_range*); - void visit_make_range_sta(ir::make_range_sta*); +// void visit_make_range_sta(ir::make_range_sta*); void visit_undef_value(ir::undef_value*); void visit_constant_int(ir::constant_int*); void visit_constant_fp(ir::constant_fp*); @@ -195,6 +222,11 @@ private: std::map bbs_; std::map> ords_; + // helper for creating llvm values + adder add; + multiplier mul; + geper gep; + }; } diff --git a/include/triton/codegen/transform/reassociate.h b/include/triton/codegen/transform/reassociate.h deleted file mode 100644 index 708de3c73..000000000 --- a/include/triton/codegen/transform/reassociate.h +++ /dev/null @@ -1,49 +0,0 @@ -#ifndef TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H -#define TDL_INCLUDE_IR_CODEGEN_REASSOCIATE_H - -#include -#include -#include - -namespace triton { - -// forward declaration -namespace ir { -class module; -class value; -class builder; -class instruction; -class getelementptr_inst; -} - -namespace codegen{ - -namespace analysis{ -class tiles; -class align; -} - -namespace transform{ - -class reassociate { - struct cst_info { - ir::value* dyn_ptr; - ir::getelementptr_inst* sta_ptr; - }; - -private: - ir::instruction* is_bin_add(ir::value *x); - ir::value *reassociate_idx(ir::value *value, ir::builder &builder, ir::value *&noncst, ir::value *&cst); - ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map &offsets); - -public: - void run(ir::module& module); -}; 
- -} - -} - -} - -#endif diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 89598ef61..f5dd397e1 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -821,33 +821,33 @@ private: int N_; }; -// On NVIDIA, implementation is such that -// constant_range = nv_dynamic_program_idx + nv_static_program_idx -// so as to enable re-association on nv_static_program_idx which is constant -class make_range_dyn: public instruction { -private: - make_range_dyn(type *ty, const std::string &name, instruction *next); - std::string repr_impl() const { return "nv_dynamic_program_idx"; } - _TRITON_DEFINE_CLONE(make_range_dyn) - _TRITON_DEFINE_ACCEPT(make_range_dyn) +//// On NVIDIA, implementation is such that +//// constant_range = nv_dynamic_program_idx + nv_static_program_idx +//// so as to enable re-association on nv_static_program_idx which is constant +//class make_range_dyn: public instruction { +//private: +// make_range_dyn(type *ty, const std::string &name, instruction *next); +// std::string repr_impl() const { return "nv_dynamic_program_idx"; } +// _TRITON_DEFINE_CLONE(make_range_dyn) +// _TRITON_DEFINE_ACCEPT(make_range_dyn) -public: - static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr); -}; +//public: +// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr); +//}; -class make_range_sta: public constant { -private: - make_range_sta(make_range *range); +//class make_range_sta: public constant { +//private: +// make_range_sta(make_range *range); -public: - static make_range_sta *get(make_range* range); - make_range* get_range() const; - std::string repr() const { return "nv_static_program_idx"; } - _TRITON_DEFINE_ACCEPT(make_range_sta) +//public: +// static make_range_sta *get(make_range* range); +// make_range* get_range() const; +// std::string repr() const { return "nv_static_program_idx"; } +// 
_TRITON_DEFINE_ACCEPT(make_range_sta) -private: - make_range *range_; -}; +//private: +// make_range *range_; +//}; /* constant range */ diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 7547c749f..6062f15ba 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -144,12 +144,12 @@ public: virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0; virtual void visit_barrier_inst(barrier_inst*) = 0; virtual void visit_async_wait_inst(async_wait_inst*) = 0; - virtual void visit_make_range_dyn(make_range_dyn*) = 0; +// virtual void visit_make_range_dyn(make_range_dyn*) = 0; virtual void visit_make_range(make_range*) = 0; virtual void visit_function(function*) = 0; - virtual void visit_make_range_sta(make_range_sta*) = 0; +// virtual void visit_make_range_sta(make_range_sta*) = 0; virtual void visit_undef_value(undef_value*) = 0; virtual void visit_constant_int(constant_int*) = 0; virtual void visit_constant_fp(constant_fp*) = 0; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index a31b1f24f..f12718a8d 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -174,8 +174,6 @@ std::vector align::populate_is_constant(ir::value *v) { return is_constant_.at(v); if(auto *x = dynamic_cast(v)) return add_to_cache(v, {cst_info{true, std::min(x->get_value(), 128)}}, is_constant_); - if(dynamic_cast(v)) - return add_to_cache(v, {cst_info{true, 0}}, is_constant_); if(auto *x = dynamic_cast(v)) return populate_is_constant_phi(x); if(auto *x = dynamic_cast(v)) @@ -322,8 +320,6 @@ std::vector align::populate_max_contiguous_default(ir::value* v) { auto shapes = v->get_type()->get_block_shapes(); if(dynamic_cast(v)) return add_to_cache(v, {shapes[0]}, max_contiguous_); - if(dynamic_cast(v)) - return add_to_cache(v, {shapes[0]}, max_contiguous_); return add_to_cache(v, std::vector(shapes.size(), 1), max_contiguous_); } @@ -486,10 +482,6 @@ std::vector 
align::populate_starting_multiple(ir::value *v){ return add_to_cache(x, {std::min(x->get_value(), 128)}, starting_multiple_); if(auto *x = dynamic_cast(v)) return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_); - if(auto *x = dynamic_cast(v)) - return add_to_cache(x, {128}, starting_multiple_); - if(auto *x = dynamic_cast(v)) - return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_); if(auto *x = dynamic_cast(v)) return populate_starting_multiple_gep(x); if(auto *x = dynamic_cast(v)) diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 884a89859..0af65542f 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -12,7 +12,6 @@ #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/pipeline.h" -#include "triton/codegen/transform/reassociate.h" #include "triton/driver/device.h" #include "triton/driver/kernel.h" #include "triton/driver/module.h" @@ -48,7 +47,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, codegen::transform::membar barriers(&liveness, &layouts, &allocation); codegen::transform::dce dce; codegen::transform::peephole peephole(target.get(), &layouts); - codegen::transform::reassociate reassociate; +// codegen::transform::reassociate reassociate; codegen::transform::coalesce coalesce(&align, &layouts); codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), num_warps); // run passes @@ -76,7 +75,7 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, align.run(ir); dce.run(ir); if (target->is_gpu()) { - reassociate.run(ir); +// reassociate.run(ir); cts.run(ir); } dce.run(ir); @@ -100,4 +99,4 @@ void add_passes_to_emit_bin(ir::module &ir, driver::device *dev, int num_warps, } } // namespace codegen -} // namespace triton \ No newline at end of file +} // namespace triton diff --git 
a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 328e51d05..91f940841 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -23,6 +23,63 @@ namespace codegen{ using namespace llvm; +Value* adder::operator()(Value *x, Value *y, const std::string& name) { + // (x + cst) + y -> (x + y) + cst + if(auto* bin = dyn_cast(x)) + if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) + if(dyn_cast(bin->getOperand(1))){ + return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y), + bin->getOperand(1)); + } + // (x + (y + cst)) -> (x + y) + cst + if(auto* bin = dyn_cast(y)) + if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) + if(dyn_cast(bin->getOperand(1))){ + return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)), + bin->getOperand(1)); + } + + // default + return (*builder_)->CreateAdd(x, y, name); +} + +Value* multiplier::operator()(Value *x, Value *y, const std::string &name) { + // (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2) + if(auto* bin = dyn_cast(x)) + if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) + if(dyn_cast(bin->getOperand(1))) + if(dyn_cast(y)){ + return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y), + (*builder_)->CreateMul(bin->getOperand(1), y)); + } + // default + return (*builder_)->CreateMul(x, y, name); +} + +Value* geper::operator()(Value *ptr, Value* off, const std::string& name){ + // (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2) + if(auto* gep = dyn_cast(ptr)) + if(ConstantInt* cst1 = dyn_cast(gep->idx_begin())) + if(ConstantInt* cst2 = dyn_cast(off)){ + return (*builder_)->CreateGEP(gep->getPointerOperand(), + (*builder_)->CreateAdd(cst1, cst2)); + } + // ptr + (off + cst) -> (ptr + off) + cst + if(auto* bin = dyn_cast(off)) + if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add) + if(ConstantInt* cst = dyn_cast(bin->getOperand(1))){ + return 
(*builder_)->CreateGEP((*builder_)->CreateGEP(ptr, bin->getOperand(0)), + bin->getOperand(1)); + } + // default + return (*builder_)->CreateGEP(ptr, off, name); +} + +//Value* geper::operator()(Type *ty, Value *ptr, std::vector vals, const std::string &name) { +// return (*builder_)->CreateGEP(ty, ptr, vals, name); +//} + + // types #define void_ty builder_->getVoidTy() #define f16_ty builder_->getHalfTy() @@ -34,7 +91,6 @@ using namespace llvm; // constants #define i32(...) builder_->getInt32(__VA_ARGS__) // ops -#define add(...) builder_->CreateAdd(__VA_ARGS__) #define and_(...) builder_->CreateAnd(__VA_ARGS__) #define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__) #define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__) @@ -52,7 +108,6 @@ using namespace llvm; #define fmul(...) builder_->CreateFMul(__VA_ARGS__) #define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) #define fsub(...) builder_->CreateFSub(__VA_ARGS__) -#define gep(...) builder_->CreateGEP(__VA_ARGS__) #define icmp(...) builder_->CreateICmp(__VA_ARGS__) #define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) #define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) @@ -64,7 +119,6 @@ using namespace llvm; #define lshr(...) builder_->CreateLShr(__VA_ARGS__) #define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) #define min_num(...) builder_->CreateMinNum(__VA_ARGS__) -#define mul(...) builder_->CreateMul(__VA_ARGS__) #define neg(...) builder_->CreateNeg(__VA_ARGS__) #define phi(...) builder_->CreatePHI(__VA_ARGS__) #define ret(...) 
builder_->CreateRet(__VA_ARGS__) @@ -144,7 +198,7 @@ generator::generator(analysis::axes *a_axes, target *tgt, unsigned num_warps) : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle), - tgt_(tgt), num_warps_(num_warps) { + tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) { } @@ -207,8 +261,8 @@ void generator::visit_phi_node(ir::phi_node* x) { * \brief Code Generation for `binary_operator` */ void generator::visit_binary_operator(ir::binary_operator*x) { + using ll = llvm::Instruction::BinaryOps; auto cvt = [](ir::binary_op_t op){ - using ll = llvm::Instruction::BinaryOps; using tt = ir::binary_op_t; switch(op) { case tt::Add: return ll::Add; @@ -235,7 +289,13 @@ void generator::visit_binary_operator(ir::binary_operator*x) { for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; - vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs); + auto op = cvt(x->get_op()); + if(op == ll::Add) + vals_[x][idx] = add(lhs, rhs); + else if(op == ll::Mul) + vals_[x][idx] = mul(lhs, rhs); + else + vals_[x][idx] = bin_op(op, lhs, rhs); } } @@ -248,8 +308,8 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) { std::vector vals; for(auto it= x->idx_begin(); it != x->idx_end(); it++) vals.push_back(vals_[*it][idx]); - Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty()); - vals_[x][idx] = gep(ty, ptr, vals); + assert(vals.size() == 1); + vals_[x][idx] = gep(ptr, vals[0]); } } @@ -640,7 +700,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){ Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634); std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); - InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false); + InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false); for(auto idx: idxs_.at(x)){ Value *ex2arg = 
fmul(vals_[x->get_operand(0)][idx], log2e); vals_[x][idx] = call(ex2, std::vector{ex2arg}); @@ -1576,7 +1636,7 @@ void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ GetElementPtrInst *in_gep = dyn_cast(vals_[arg][idx]); Value *in_base = in_gep->getPointerOperand(); ConstantInt* cst = dyn_cast(in_gep->idx_begin()); - size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize*in_vec : 0; + size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; in_base = cst ? in_base : in_gep; // output ptr info Value* out_base = shared[i].first; @@ -1683,34 +1743,34 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) { call(iasm); } -void generator::visit_make_range_dyn(ir::make_range_dyn* x) { - for(indices_t idx: idxs_.at(x)){ - assert(idx.size() == 1); - if(idx[0] == i32(0)) - vals_[x][idx] = idx[0]; - else{ - BinaryOperator *bin_add = dyn_cast(idx[0]); - assert(bin_add); - vals_[x][idx] = bin_add->getOperand(0); - } - } -} +//void generator::visit_make_range_dyn(ir::make_range_dyn* x) { +// for(indices_t idx: idxs_.at(x)){ +// assert(idx.size() == 1); +// if(idx[0] == i32(0)) +// vals_[x][idx] = idx[0]; +// else{ +// BinaryOperator *bin_add = dyn_cast(idx[0]); +// assert(bin_add); +// vals_[x][idx] = bin_add->getOperand(0); +// } +// } +//} -void generator::visit_make_range_sta(ir::make_range_sta* x) { - for(indices_t idx: idxs_.at(x)){ - assert(idx.size() == 1); - if(idx[0] == i32(0)){ - vals_[x][idx] = idx[0]; - } - else{ - BinaryOperator *bin_add = dyn_cast(idx[0]); - assert(bin_add); - Value *cst = bin_add->getOperand(1); - assert(isa(cst)); - vals_[x][idx] = cst; - } - }; -} +//void generator::visit_make_range_sta(ir::make_range_sta* x) { +// for(indices_t idx: idxs_.at(x)){ +// assert(idx.size() == 1); +// if(idx[0] == i32(0)){ +// vals_[x][idx] = idx[0]; +// } +// else{ +// BinaryOperator *bin_add = dyn_cast(idx[0]); +// assert(bin_add); +// Value *cst = bin_add->getOperand(1); +// assert(isa(cst)); +// vals_[x][idx] = 
cst; +// } +// }; +//} void generator::visit_make_range(ir::make_range* x) { for(indices_t idx: idxs_.at(x)){ diff --git a/lib/codegen/transform/reassociate.cc b/lib/codegen/transform/reassociate.cc deleted file mode 100644 index 0dad7a19a..000000000 --- a/lib/codegen/transform/reassociate.cc +++ /dev/null @@ -1,267 +0,0 @@ -#include -#include "triton/codegen/transform/reassociate.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/utils.h" - -namespace triton { -namespace codegen{ -namespace transform{ - - -inline ir::instruction* reassociate::is_bin_add(ir::value *x) { - ir::binary_operator *bin_op = dynamic_cast(x); - bool is_bin_add = bin_op && bin_op->get_op()== ir::binary_op_t::Add; - if(is_bin_add) - return (ir::instruction*)x; - return nullptr; -} - -inline bool is_cst(ir::value *x) { - if(dynamic_cast(x)) - return true; - if(dynamic_cast(x)) - return true; - if(auto *v = dynamic_cast(x)) - return is_cst(v->get_operand(0)); - return false; -} - -ir::value *reassociate::reassociate_idx(ir::value *old_value, - ir::builder &builder, - ir::value *&noncst, - ir::value *&cst){ - // value doesn't change by default - ir::value* new_value = old_value; - cst = nullptr; - noncst = old_value; - - // handle retiling - if(ir::instruction* op = dynamic_cast(old_value)){ - auto shapes = op->get_type()->get_block_shapes(); - ir::value *old_arg = op->get_operand(0); - ir::value *new_arg = reassociate_idx(old_arg, builder, noncst, cst); - // retile(x + y) = retile(x) + retile(y) - if(ir::instruction* bin_add = is_bin_add(new_arg)) - if(cst){ - ir::value *old_lhs = bin_add->get_operand(0); - ir::value *old_rhs = bin_add->get_operand(1); - ir::value *new_lhs = nullptr; - ir::value *new_rhs = nullptr; - if(dynamic_cast(op)){ - builder.set_insert_point(op); - new_lhs = builder.create_reshape(old_lhs, shapes); - new_rhs = builder.create_reshape(old_rhs, shapes); - new_value 
= builder.create_add(new_lhs, new_rhs); - } - if(dynamic_cast(op)){ - builder.set_insert_point(op); - new_lhs = builder.create_broadcast(old_lhs, shapes); - new_rhs = builder.create_broadcast(old_rhs, shapes); - new_value = builder.create_add(new_lhs, new_rhs); - } - if(dynamic_cast(op)){ - builder.set_insert_point(op); - new_lhs = builder.create_splat(old_lhs, shapes); - new_rhs = builder.create_splat(old_rhs, shapes); - new_value = builder.create_add(new_lhs, new_rhs); - } - } - } - - // handle binary addition - if(ir::instruction* op = is_bin_add(old_value)){ - builder.set_insert_point(op); - std::string name = op->get_name(); - ir::value *lhs = reassociate_idx(op->get_operand (0), builder, noncst, cst); - ir::value *rhs = reassociate_idx(op->get_operand(1), builder, noncst, cst); - builder.set_insert_point(op); - // (x + y) + z - if(ir::instruction* bin_lhs = is_bin_add(lhs)){ - ir::value *llhs = bin_lhs->get_operand(0); - ir::value *rlhs = bin_lhs->get_operand(1); - // (cst + x) + y -> cst + (x + y) - if(is_cst(llhs)) - new_value = builder.create_add(llhs, builder.create_add(rlhs, rhs)); - // (x + cst) + y -> cst + (x + y) - if(is_cst(rlhs)) - new_value = builder.create_add(rlhs, builder.create_add(llhs, rhs)); - } - // x + (y + z) - if(ir::instruction* bin_rhs = is_bin_add(rhs)){ - ir::value *lrhs = bin_rhs->get_operand(0); - ir::value *rrhs = bin_rhs->get_operand(1); - // x + (cst + y) -> cst + (x + y) - if(is_cst(lrhs)) - new_value = builder.create_add(lrhs, builder.create_add(rrhs, lhs), cst); - // x + (y + cst) -> cst + (x + y) - if(is_cst(rrhs)) - new_value = builder.create_add(rrhs, builder.create_add(lrhs, lhs), cst); - } - } - // extract constant and non-constant - if(ir::instruction *bin_add = is_bin_add(new_value)){ - ir::value *new_lhs = bin_add->get_operand(0); - ir::value *new_rhs = bin_add->get_operand(1); - if(is_cst(new_lhs)){ - cst = new_lhs; - noncst = new_rhs; - } - if(is_cst(new_rhs)){ - cst = new_rhs; - noncst = new_lhs; - } - } - // 
clean-up if some re-ordering happened - if(old_value != new_value) - old_value->replace_all_uses_with(new_value); - return new_value; -} - -/* run */ -void reassociate::run(ir::module &mod) { - ir::builder &builder = mod.get_builder(); - - // constant_range -> nv_dynamic_program_idx + nv_static_program_idx - for(ir::function *fn: mod.get_function_list()){ - std::vector ranges; - std::vector rpo = ir::cfg::reverse_post_order(fn); - for(ir::basic_block *block: rpo){ - // iterate through instruction - for(ir::instruction *i: block->get_inst_list()) - for(ir::value* op: i->ops()) - if(auto *range = dynamic_cast(op)) - ranges.push_back(range); - } - - builder.set_insert_point(rpo.front()->get_first_non_phi()); - for(ir::make_range* old_range: ranges){ - ir::value* dyn_range = builder.insert(ir::make_range_dyn::create(old_range->get_type())); - ir::value* static_range = ir::make_range_sta::get(old_range); - ir::value* new_range = builder.create_add(dyn_range, static_range); - old_range->replace_all_uses_with(new_range); - } - } - - // reassociate - std::map infos; - std::set replaced; - size_t n_replaced; - do{ - n_replaced = replaced.size(); - for(ir::function *fn: mod.get_function_list()){ - std::vector rpo = ir::cfg::reverse_post_order(fn); - // iterate through blocks - for(ir::basic_block *block: rpo){ - // iterate through instruction - for(ir::instruction *i: block->get_inst_list()){ - // retiling - if(ir::retile_inst *rt = dynamic_cast(i)) { - ir::value* op = rt->get_operand(0); - if(infos.find(op) != infos.end()){ - builder.set_insert_point(rt); - ir::getelementptr_inst* sta = infos.at(op).sta_ptr; - ir::value* dyn = infos.at(op).dyn_ptr; - ir::value* cst = *sta->idx_begin(); - if(dynamic_cast(rt)) { - auto shapes = rt->get_type()->get_block_shapes(); - ir::value* ndyn = builder.create_broadcast(dyn, shapes); - ir::value* broadcast = builder.create_broadcast(cst, shapes); - ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, 
{broadcast}); - infos[rt] = cst_info{ndyn, nsta}; - } - } - } - // getelementptr instruction - if(ir::getelementptr_inst *pz = dynamic_cast(i)){ - if(replaced.find(pz) != replaced.end()) - continue; - // unpack GEP instruction - ir::value* py = pz->get_pointer_operand(); - ir::value* offset = *pz->idx_begin(); - // reassociate index - ir::value *sta = nullptr; - ir::value *dyn = offset; - reassociate_idx(offset, builder, dyn, sta); - if(sta){ - builder.set_insert_point(pz); - ir::value *dyn_ptr = builder.create_gep(py, {dyn}); - ir::value *sta_ptr = builder.create_gep(dyn_ptr, {sta}); - pz->replace_all_uses_with(sta_ptr); - infos[sta_ptr].dyn_ptr = dyn_ptr; - infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr; - replaced.insert(pz); - } - // reassociate pointer argument - if(infos.find(py) != infos.end()){ - builder.set_insert_point(pz); - ir::getelementptr_inst *sta = infos[py].sta_ptr; - ir::value *dyn = infos[py].dyn_ptr; - ir::value *cst = *sta->idx_begin(); - ir::value *off = *pz->idx_begin(); - ir::value *pz_dyn = builder.create_gep(dyn, {off}); - ir::value *pz_sta = builder.create_gep(pz_dyn, {cst}); - pz->replace_all_uses_with(pz_sta); - infos[pz_sta].dyn_ptr = pz_dyn; - infos[pz_sta].sta_ptr = (ir::getelementptr_inst*)pz_sta; - replaced.insert(pz); - } - // reassociate phi-node pointer - if(ir::phi_node* phi = dynamic_cast(py)){ - // only optimize the case where py = phi pa, pz for now - std::vector ops = phi->ops(); - if(ops.size() != 2) - continue; - if(ops[0] != pz && ops[1] != pz) - continue; - // grab incoming - size_t idx_z = (ops[0] == pz) ? 0 : 1; - size_t idx_a = (ops[0] == pz) ? 
1 : 0; - // check if pa is known to have constant offset - ir::value *vpa = phi->get_incoming_value(idx_a); - auto it_a = infos.find(vpa); - if(it_a == infos.end()) - continue; - // unpack dynamically/statically offset pointer - ir::value *pa_dyn = it_a->second.dyn_ptr; - ir::getelementptr_inst *pa_sta = it_a->second.sta_ptr; - ir::value *pz = phi->get_incoming_value(idx_z); - // extract offset - ir::value *off = *pa_sta->idx_begin(); - builder.set_insert_point(phi); - ir::phi_node *phi_dyn = builder.create_phi(phi->get_type(), 2); - phi_dyn->add_incoming(pa_dyn, phi->get_incoming_block(idx_a)); - builder.set_insert_point(phi->get_parent()->get_first_non_phi()); - // re-add the offset - ir::value *phi_sta = builder.create_gep(phi_dyn, {off}); - phi_sta->set_name( phi->get_name() + "_sta"); - phi->replace_all_uses_with(phi_sta); - // remove offset from pz - if(auto *x = dynamic_cast(pz)){ - auto insts = x->get_parent()->get_inst_list(); - auto it = std::find(insts.begin(), insts.end(), x); - it++; - builder.set_insert_point(*it); - } - ir::value *_0 = builder.get_int32(0); - if(off->get_type()->is_block_ty()) - _0 = builder.create_splat(_0, off->get_type()->get_block_shapes()); - ir::value *neg_off = builder.create_sub(_0, off); - ir::value *pz_dyn = builder.create_gep(pz, {neg_off}); - phi_dyn->add_incoming(pz_dyn, phi->get_incoming_block(idx_z)); - infos[phi_sta].dyn_ptr = phi_dyn; - infos[phi_sta].sta_ptr = (ir::getelementptr_inst*)phi_sta; - replaced.insert(phi); - } - } - } - } - } - }while(replaced.size() != n_replaced); -} - -} -} -} diff --git a/lib/driver/kernel.cc b/lib/driver/kernel.cc index 5c57e01a6..4d02ac469 100755 --- a/lib/driver/kernel.cc +++ b/lib/driver/kernel.cc @@ -1,22 +1,22 @@ /* Copyright 2015-2017 Philippe Tillet -* -* Permission is hereby granted, free of charge, to any person obtaining -* a copy of this software and associated documentation files -* (the "Software"), to deal in the Software without restriction, -* including without 
limitation the rights to use, copy, modify, merge, -* publish, distribute, sublicense, and/or sell copies of the Software, -* and to permit persons to whom the Software is furnished to do so, +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, * subject to the following conditions: -* -* The above copyright notice and this permission notice shall be +* +* The above copyright notice and this permission notice shall be * included in all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ @@ -81,6 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_); dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_); dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_); + std::cout << n_reg << std::endl; if (shared_optin > 49152){ // std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl; dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static); diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 8da0be6ef..9b013f1d5 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -833,27 +833,27 @@ async_wait_inst* async_wait_inst::create(context &ctx, int N, const std::string } -// nv_dynamic_program_idx -make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) - : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { } +//// nv_dynamic_program_idx +//make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) +// : instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { } -make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) { - return new make_range_dyn(ty, name, next); -} +//make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) { +// return new make_range_dyn(ty, name, next); +//} -// nv_static_program_idx -make_range_sta::make_range_sta(make_range *range) - : constant(range->get_type(), 0), range_(range) { } +//// nv_static_program_idx +//make_range_sta::make_range_sta(make_range *range) +// : constant(range->get_type(), 0), range_(range) { } -make_range* make_range_sta::get_range() const -{ return range_; } +//make_range* make_range_sta::get_range() const +//{ return range_; } -make_range_sta* make_range_sta::get(make_range* range) { - static 
std::map cache; - if(cache.find(range) == cache.end()) - cache.insert({range, new make_range_sta(range)}); - return cache.at(range); -} +//make_range_sta* make_range_sta::get(make_range* range) { +// static std::map cache; +// if(cache.find(range) == cache.end()) +// cache.insert({range, new make_range_sta(range)}); +// return cache.at(range); +//} // make_range diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index d80f3b3a1..527d42e6a 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -137,8 +137,8 @@ def swish(x): @triton.autotune( configs=[ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8}, num_warps=4), + #triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4), ], key=['M', 'N', 'K'], ) @@ -202,11 +202,12 @@ def matmul(a, b, activation=None): c = torch.empty((M, N), device=a.device, dtype=a.dtype) # launch kernel grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) - _matmul[grid]( + pgm = _matmul[grid]( a, b, c, M, N, K, \ a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\ ACTIVATION = activation ) + #print(pgm.asm('ttir')) # return output return c @@ -218,13 +219,14 @@ def matmul(a, b, activation=None): # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS + custom element-wise swish kernel) #torch.manual_seed(0) -a = torch.randn((512, 512), device='cuda', dtype=torch.float16) -b = torch.randn((512, 512), device='cuda', dtype=torch.float16) -c_0 = matmul(a, b, activation=swish) -c_1 = torch.nn.SiLU()(torch.matmul(a, b)) -print(c_0) -print(c_1) -print(triton.testing.allclose(c_0, 
c_1)) +# a = torch.randn((512, 512), device='cuda', dtype=torch.float16) +# b = torch.randn((512, 512), device='cuda', dtype=torch.float16) +# c_0 = matmul(a, b, activation=None) +# c_1 = torch.matmul(a, b) +# print(c_0) +# print(c_1) +# print(triton.testing.allclose(c_0, c_1)) +# exit() # %% # Benchmark @@ -238,7 +240,7 @@ print(triton.testing.allclose(c_0, c_1)) @triton.testing.perf_report( triton.testing.Benchmark( x_names=['M', 'N', 'K'], # argument names to use as an x-axis for the plot - x_vals=[256 * i for i in range(2, 33)], # different possible values for `x_name` + x_vals=[8192], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot line_vals=['cublas', 'triton'], # possible values for `line_arg`` line_names=["cuBLAS", "Triton"], # label name for the lines diff --git a/tutorials/01-matmul.cc b/tutorials/01-matmul.cc deleted file mode 100644 index 08f57084d..000000000 --- a/tutorials/01-matmul.cc +++ /dev/null @@ -1,233 +0,0 @@ -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include -#include -#include -#include -#include -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "triton/tools/bench.hpp" -#include "triton/external/half.hpp" -#include "triton/runtime/function.h" -#include -#include -#include "triton/runtime/function.h" - -namespace drv = triton::driver; -namespace rt = triton::runtime; - -namespace src { - - const char *dot = -R"( -#define STM 8 -#define STN 8 - -__global__ void dot(TYPE * A __noalias __readonly __aligned(16), - TYPE * B __noalias __readonly __aligned(16), - TYPE * C __noalias __aligned(16), - float alpha, - int M __retune, - int N __retune, - int K __retune __multipleof(16), - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc __multipleof(8), - int* locks) { - // prologue - int pid = get_program_id(0); - int pidz = get_program_id(2); - int gridm = (M + TM - 1) / TM; - int gridn = (N + 
TN - 1) / TN; - int width = STM*gridn; - int stm = pid / width; - int RSTM = min(gridm - stm*STM, STM); - int stn = (pid % width) / (RSTM*STN); - int RSTN = min(gridn - stn*STN, STN); - int laneid = pid % (RSTM * RSTN); - int lanem = laneid / RSTN; - int lanen = laneid % RSTN; - int pidm = stm*STM + lanem; - int pidn = stn*STN + lanen; - int rm[TM] = pidm * TM + 0 ... TM; - int rn[TN] = pidn * TN + 0 ... TN; - - // reduction splitting - K = K / TZ; - int rk[TK] = pidz * K + 0 ... TK; - // pointers to operands - int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; - int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; - TYPE* pa[TM, TK] = A + offa; - TYPE* pb[TK, TN] = B + offb; - // prefetches operands - bool checka[TM, TK] = rk[newaxis, :] < K; - bool checkb[TK, TN] = rk[:, newaxis] < K; - TYPE a[TM, TK] = checka ? *pa : 0; - TYPE b[TK, TN] = checkb ? *pb : 0; - // reduction loop - float acc[TM, TN] = 0; - for(int k = K; k > 0; k -= TK){ - bool checka[TM, TK] = k > TK; - bool checkb[TK, TN] = k > TK; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - TYPE anext[TM, TK] = *?(checka)pa; - TYPE bnext[TK, TN] = *?(checkb)pb; - acc += a @ b; - a = anext; - b = bnext; -// __debug_barrier(); - } - acc = acc * alpha; - TYPE c[TM, TN] = acc; - - // epilogue - int rcm[TM] = pidm * TM + 0 ... TM; - int rcn[TN] = pidn * TN + 0 ... 
TN; - int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :]; - TYPE* pc[TM, TN] = C + offc; - bool checkc[TM, TN] = rcm[:, newaxis] < M && - rcn[newaxis, :] < N; -#if (TZ==1) - *?(checkc) pc = c; -#else - // accumulate partial result using spin-locks - int *plock = locks + rid; - int *pcount = plock + get_num_programs(0) * get_num_programs(1); - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % TZ); - atomic_xchg(plock, 0); -#endif -} -)"; - -} - -enum dtype_t { - FLOAT, - HALF, - DOUBLE -}; - -template -struct to_string; - -template<> struct to_string{ - static constexpr const char* value = "half"; -}; - -template<> struct to_string{ - static constexpr const char* value = "float"; -}; - -template<> struct to_string{ - static constexpr const char* value = "double"; -}; - -template -float triton_dot(drv::context* context, drv::stream* stream, - bool AT, bool BT, - int32_t M, int32_t N, int32_t K){ - std::string ty = to_string::value; - size_t dt_nbytes = sizeof(T); - drv::device* device = context->device(); - int32_t lda = AT ? K : M; - int32_t ldb = BT ? 
N : K; - int32_t ldc = N; - std::vector sa = { "1", "lda" }; - std::vector sb = { "1", "ldb" }; - // inputs - auto dc = std::shared_ptr(drv::buffer::create(context, M*N*dt_nbytes)); - auto da = std::shared_ptr(drv::buffer::create(context, M*K*dt_nbytes)); - auto db = std::shared_ptr(drv::buffer::create(context, K*N*dt_nbytes)); - auto dlocks = std::shared_ptr(drv::buffer::create(context, 1024*1024*2*4)); - // initialize buffers - std::vector hc(M*N); - std::vector ha(M*K); - std::vector hb(K*N); - for(size_t i = 0; i < ha.size(); i++) - ha[i] = (float)rand()/RAND_MAX; - for(size_t i = 0; i < hb.size(); i++) - hb[i] = (float)rand()/RAND_MAX; - stream->write(&*da, true, 0, ha); - stream->write(&*db, true, 0, hb); - // macros - rt::options_t opt; - opt.defines["STRIDE_AK"] = AT? "1" : "lda"; - opt.defines["STRIDE_AM"] = AT? "lda" : "1"; - opt.defines["STRIDE_BK"] = BT? "ldb" : "1"; - opt.defines["STRIDE_BN"] = BT? "1" : "ldb"; - opt.defines["TYPE"] = ty; - opt.defines["TM"] = "128"; - opt.defines["TN"] = "128"; - opt.defines["TK"] = "64" ; - opt.defines["TZ"] = "1"; - opt.num_warps = 4; - // arguments - std::stringstream oss; - rt::add_arg(oss, *da->cu()); - rt::add_arg(oss, *db->cu()); - rt::add_arg(oss, *dc->cu()); - rt::add_arg(oss, (float)1); - rt::add_arg(oss, M); - rt::add_arg(oss, N); - rt::add_arg(oss, K); - rt::add_arg(oss, lda); - rt::add_arg(oss, ldb); - rt::add_arg(oss, ldc); - rt::add_arg(oss, *dlocks->cu()); - // function - rt::function function(src::dot, opt, device); -// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_NV_PTX) << std::endl; - // grid - auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; }; - auto grid = [ceil, M, N](const rt::options_t& x) { - return rt::kernel::grid_t{ceil(M, x.D("TM"))* - ceil(N, x.D("TN")), - (size_t)x.D("TZ")}; - }; - - // metrics - auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function(oss.str(), grid, stream);}, 
stream); - return tflops(triton_ns); -} - -float bench_dot(drv::context* context, drv::stream* stream, - bool AT, bool BT, - int32_t M, int32_t N, int32_t K, - dtype_t dtype) { - switch(dtype){ - case HALF: return triton_dot(context, stream, AT, BT, M, N, K); - case FLOAT: return triton_dot(context, stream, AT, BT, M, N, K); - case DOUBLE: return triton_dot(context, stream, AT, BT, M, N, K); - default: return 0; - } -} - -int main() { - // initialize default compute device - auto context = triton::driver::backend::contexts::get_default(); - triton::driver::stream* stream = triton::driver::stream::create(context->backend()); - // shapes to benchmark - typedef std::tuple config_t; - std::vector configs = { - {false, false, 8192, 8192, 8192} - }; - // does the work - bool AT, BT; - int32_t M, N, K; - dtype_t dtype = HALF; - for(const auto& c: configs){ - std::tie(AT, BT, M, N, K) = c; - float tflops = bench_dot(context, stream, AT, BT, M, N, K, dtype); - std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K << ", " << tflops << std::endl; - } -}