From 3304629de99ae92e1a171bae907e14c157d1ddd5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 31 Mar 2020 18:55:31 -0400 Subject: [PATCH] [CORE] Fixed several issues that arose in the development of the torch-blocksparse package: * Now using warp shuffle in reductions when possible * Various bugfixes in layout inference * Added INFINITY, exponential and select * Better error messages for unimplemented constructs --- CMakeLists.txt | 2 +- include/triton/codegen/analysis/layout.h | 2 +- include/triton/codegen/selection/generator.h | 2 + include/triton/codegen/transform/peephole.h | 1 + include/triton/ir/builder.h | 5 +- include/triton/ir/enums.h | 2 + include/triton/ir/instructions.h | 11 ++ include/triton/ir/type.h | 1 + include/triton/ir/visitor.h | 4 + include/triton/lang/ast.h | 1 + include/triton/lang/code_gen.h | 39 +++--- include/triton/lang/parser.h | 1 + include/triton/lang/token.h | 2 + lib/codegen/analysis/axes.cc | 2 +- lib/codegen/analysis/layout.cc | 23 ++-- lib/codegen/selection/generator.cc | 132 +++++++++++++++++-- lib/codegen/selection/machine_layout.cc | 1 - lib/codegen/selection/machine_value.cc | 4 +- lib/codegen/transform/cts.cc | 5 +- lib/codegen/transform/peephole.cc | 14 ++ lib/driver/module.cc | 2 +- lib/ir/builder.cc | 13 +- lib/ir/constant.cc | 2 +- lib/ir/instructions.cc | 12 ++ lib/lang/ast.cc | 7 + lib/lang/code_gen.cc | 113 +++++++++------- lib/lang/parser.cc | 8 +- lib/lang/token.cc | 2 + lib/runtime/function.cc | 12 +- python/examples/einsum.py | 68 ++++++---- python/triton/ops/einsum.py | 8 +- tests/bench/dot.cc | 2 +- tests/common/dot.h | 6 +- 33 files changed, 374 insertions(+), 135 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bdb9e1ce7..1d0ad62eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ option(BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) # LLVM find_package(LLVM REQUIRED) +link_directories(${LLVM_LIBRARY_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) @@ -40,6 +41,5 @@ endif() # Triton file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) -link_directories(${LLVM_LIBRARY_DIRS}) target_link_libraries(triton ${LLVM_LIBRARIES}) diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 13ddfafb4..289bd7874 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -119,7 +119,7 @@ struct scanline_layout: public data_layout { int mts(size_t k) { return mts_.at(k); } int nts(size_t k) { return nts_.at(k); } -private: +public: std::vector mts_; std::vector nts_; }; diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 8b8c5bf64..6df877486 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -111,6 +111,8 @@ public: void visit_broadcast_inst(ir::broadcast_inst*); void visit_downcast_inst(ir::downcast_inst*); + void visit_exp_inst(ir::exp_inst*); + void visit_get_program_id_inst(ir::get_program_id_inst*); void visit_get_num_program_inst(ir::get_num_program_inst*); void visit_atomic_cas_inst(ir::atomic_cas_inst*); diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h index 9382b968d..ce0d1f34a 100644 --- a/include/triton/codegen/transform/peephole.h +++ b/include/triton/codegen/transform/peephole.h @@ -19,6 +19,7 @@ namespace transform{ class peephole { private: + bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder); bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder); bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D); bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index d5782e3a1..1c87997f5 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -35,7 +35,8 @@ public: basic_block* get_insert_block() { return block_; } iterator get_insert_point() { return insert_point_;} // Constants - value *get_int32(unsigned val); + value *get_int32(int32_t val); + value *get_int64(int64_t val); // Types type *get_void_ty(); type *get_int1_ty(); @@ -63,6 +64,7 @@ public: value* create_ret_void(); // Cast instructions value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = ""); + value* create_ptr_to_int(value *src, type *dst_ty, const std::string &name = ""); value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = ""); value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = ""); value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = ""); @@ -135,6 +137,7 @@ public: value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = ""); value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); value *create_atomic_add(value *ptr, value *val, const std::string &name = ""); + value *create_exp(value* arg, const std::string &name = ""); value *create_dot(value *A, value *B, value *C, const std::string &name = ""); value *create_trans(value *A, const std::vector &perm = {}, const std::string &name = ""); value *create_sqrt(value *A, const std::string &name = ""); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 491d37edf..10b017919 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -127,6 +127,8 @@ enum value_id_t: unsigned { INST_ATOMIC_CAS, INST_ATOMIC_EXCH, INST_ATOMIC_ADD, + // math + INST_EXP, // array arithmetic INST_TRANS, INST_REDUCE, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 41eb98eb3..54dd2e736 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -612,6 +612,17 @@ public: static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); }; +class exp_inst: public builtin_inst { +private: + exp_inst(value *val, const std::string &name = "", instruction *next = nullptr); + std::string repr_impl() const { return "exp"; } + _TRITON_DEFINE_CLONE(exp_inst) + _TRITON_DEFINE_ACCEPT(exp_inst) + +public: + static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); +}; + class dot_inst: public builtin_inst { public: enum TransT { NoTrans, Trans }; diff --git a/include/triton/ir/type.h b/include/triton/ir/type.h index dedc8ea8c..05fb795c4 100644 --- a/include/triton/ir/type.h +++ b/include/triton/ir/type.h @@ -81,6 +81,7 @@ public: bool is_integer_ty() const { return id_ == IntegerTyID; } bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() && get_integer_bitwidth() == bitwidth;} + bool is_bool_ty() const { return is_integer_ty(1); } bool is_pointer_ty() const { return id_ == PointerTyID; } bool is_tile_ty() const { return id_ == TileTyID; } diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index b5941b88f..e53331af4 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -48,6 +48,8 @@ class splat_inst; class broadcast_inst; class downcast_inst; +class exp_inst; + class get_program_id_inst; class get_num_program_inst; class atomic_cas_inst; @@ -114,6 +116,8 @@ public: virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0; virtual void visit_masked_store_inst(masked_store_inst*) = 0; + virtual void visit_exp_inst(exp_inst*) = 0; + virtual void visit_reshape_inst(reshape_inst*) = 0; virtual void visit_splat_inst(splat_inst*) = 0; virtual void visit_broadcast_inst(broadcast_inst*) = 0; diff --git a/include/triton/lang/ast.h b/include/triton/lang/ast.h index eb442de11..2d888efc2 100644 --- a/include/triton/lang/ast.h +++ b/include/triton/lang/ast.h @@ -433,6 +433,7 @@ public: void UnaryArithmOpTypeChecking(); void BitcastOpTypeChecking(); void CastOpTypeChecking(); + void IntrinsicOpTypeChecking(); protected: UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0) diff --git a/include/triton/lang/code_gen.h b/include/triton/lang/code_gen.h index 155706ddc..5347421bb 100644 --- a/include/triton/lang/code_gen.h +++ b/include/triton/lang/code_gen.h @@ -33,8 +33,8 @@ using LocationList = std::vector; using StaticInitList = std::vector; // Error -inline void should_not_happen() { throw std::runtime_error("should not happen"); } -inline void error_not_implemented() { throw std::runtime_error("not implemented"); } +inline void should_not_happen(const std::string& suffix) { throw std::runtime_error("internal compiler error: " + suffix); } +inline void error_not_implemented(const std::string& msg) { throw std::runtime_error(msg); } class Generator: public Visitor { friend class Evaluator; @@ -87,6 +87,9 @@ protected: // Triton-IR attributes ir::attribute GenIRAttr(ASTNode::Attr attr); + // Triton-IR metadata + void SetIRMetadata(ASTNode::Attr attr, ir::value *rhs); + // Triton-IR values ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs); ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty); @@ -131,22 +134,22 @@ public: void VisitObject(Object* obj); void VisitIdentifier(Identifier* ident); - void VisitConditionalOp(ConditionalOp*) { should_not_happen(); } - void VisitFuncCall(FuncCall*) { should_not_happen(); } - void VisitTransOp(TransOp*) { should_not_happen(); } - void VisitEnumerator(Enumerator*) { should_not_happen(); } - void VisitConstant(Constant*) { should_not_happen(); } - void VisitTempVar(TempVar*) { should_not_happen(); } - void VisitDeclaration(Declaration*) { should_not_happen(); } - void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); } - void VisitIfStmt(IfStmt*) { should_not_happen(); } - void VisitForStmt(ForStmt*) { should_not_happen(); } - void VisitJumpStmt(JumpStmt*) { should_not_happen(); } - void VisitReturnStmt(ReturnStmt*) { should_not_happen(); } - void VisitLabelStmt(LabelStmt*) { should_not_happen(); } - void VisitCompoundStmt(CompoundStmt*) { should_not_happen(); } - void VisitFuncDef(FuncDef*) { should_not_happen(); } - void VisitTranslationUnit(TranslationUnit*) { should_not_happen(); } + void VisitConditionalOp(ConditionalOp*) { should_not_happen("conditional cannot be lvalue"); } + void VisitFuncCall(FuncCall*) { should_not_happen("funccall cannot be lvalue"); } + void VisitTransOp(TransOp*) { should_not_happen("transop cannot be lvalue"); } + void VisitEnumerator(Enumerator*) { should_not_happen("enumerator cannot be lvalue"); } + void VisitConstant(Constant*) { should_not_happen("constant cannot be lvalue"); } + void VisitTempVar(TempVar*) { should_not_happen("tempvar cannot be lvalue"); } + void VisitDeclaration(Declaration*) { should_not_happen("declaration cannot be lvalue"); } + void VisitEmptyStmt(EmptyStmt*) { should_not_happen("empty statement cannot be lvalue"); } + void VisitIfStmt(IfStmt*) { should_not_happen("if statement cannot be lvalue"); } + void VisitForStmt(ForStmt*) { should_not_happen("for statement cannot be lvalue"); } + void VisitJumpStmt(JumpStmt*) { should_not_happen("jump statement cannot be lvalue"); } + void VisitReturnStmt(ReturnStmt*) { should_not_happen("return statement cannot be lvalue"); } + void VisitLabelStmt(LabelStmt*) { should_not_happen("label statement cannot be lvalue"); } + void VisitCompoundStmt(CompoundStmt*) { should_not_happen("compound statement cannot be lvalue"); } + void VisitFuncDef(FuncDef*) { should_not_happen("function definition cannot be lvalue"); } + void VisitTranslationUnit(TranslationUnit*) { should_not_happen("translation unit cannot be lvalue"); } ir::value* GenExpr(Expr* expr, ir::value* rhs) { rhs_ = rhs; diff --git a/include/triton/lang/parser.h b/include/triton/lang/parser.h index 05edbb159..f8ea89cf7 100644 --- a/include/triton/lang/parser.h +++ b/include/triton/lang/parser.h @@ -83,6 +83,7 @@ public: Constant* ParseSizeof(); Constant* ParseAlignof(); UnaryOp* ParsePrefixIncDec(const Token* tok); + UnaryOp* ParseUnaryIntrinsicOp(const Token* tok, int op); UnaryOp* ParseUnaryOp(const Token* tok, int op); Expr* ParseDerefOp(const Token* tok); diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h index 2552f1769..621c555fb 100644 --- a/include/triton/lang/token.h +++ b/include/triton/lang/token.h @@ -164,7 +164,9 @@ public: ALIGNOF, // _Alignof GENERIC, // _Generic IMAGINARY, // _Imaginary + // function keywords BITCAST, + EXP, // KEYWORD END IDENTIFIER, diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index a01ef9aa1..a83471bdb 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -2,7 +2,7 @@ #include "triton/ir/utils.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" - +#include namespace triton{ diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 8b4a3242a..560198511 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -16,7 +16,9 @@ namespace analysis{ * Helper Functions * * -------------------------------- */ -inline unsigned clamp(unsigned x, unsigned lo, unsigned hi) { +inline unsigned clamp(unsigned x, unsigned a, unsigned b) { + unsigned lo = std::min(a, b); + unsigned hi = std::max(a, b); return std::min(std::max(x, lo), hi); } @@ -97,7 +99,9 @@ data_layout::data_layout(id_t id, order_.resize(axes_.size()); std::iota(order_.begin(), order_.end(), 0); auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){ - return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank(); + std::pair xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; + std::pair yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; + return xx < yy; }); if(*largest){ auto max_contiguous = align->contiguous(*largest); @@ -201,8 +205,9 @@ scanline_layout::scanline_layout(size_t num_warps, for(size_t d = 0; d < shape_.size(); d++) effective_num_threads *= mts_[d]; - if(num_warps * 32 != effective_num_threads) - throw std::runtime_error("cannot create a kernel with this amount of warps"); +// std::cout <& values) { auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); auto cmp = [](ir::value* x, ir::value *y) { - return x->get_type()->get_tile_ranks1() < - y->get_type()->get_tile_ranks1(); + std::pair xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; + std::pair yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()}; + return xx < yy; }; std::vector lvalue = values; std::remove_if(lvalue.begin(), lvalue.end(), [&](ir::value* v) { return dynamic_cast(v); }); @@ -402,11 +408,8 @@ void layouts::run(ir::module &mod) { unsigned axis = red->get_axis(); // shape auto shapes = arg->get_type()->get_tile_shapes(); - unsigned shape_ax = shapes[axis]; scanline_layout *layout = get(arg)->to_scanline(); - unsigned per_thread = layout->nts(axis); - unsigned depth = shape_ax / per_thread; - shapes[axis] = depth; + shapes[axis] = layout->mts(axis); // create layout layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_); tmp_[red] = id; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 8321a9948..e60e2567e 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -196,8 +196,9 @@ void generator::visit_value(ir::value* v) { BasicBlock *current = builder_->GetInsertBlock(); auto *inst = dynamic_cast(v); if(inst && !dynamic_cast(v)) - for(ir::value *op: inst->ops()) + for(ir::value *op: inst->ops()){ visit_value(op); + } // change insert point for phi node builder_->SetInsertPoint(current); auto *phi = dynamic_cast(v); @@ -547,6 +548,24 @@ void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) { vmap_[np] = ret; } +void generator::visit_exp_inst(ir::exp_inst* x){ + distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0)); +// Function *fn = builder_->GetInsertBlock()->getParent(); +// Module *module = fn->getParent(); +// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_); +// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty}); + Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634); + + FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false); + InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false); + + + for_each(x, [&](indices_t idx){ + Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e); + set_value(x, idx, builder_->CreateCall(ex2, {ex2arg})); + }); +} + void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); @@ -587,6 +606,7 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) { BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); tgt_->add_memfence(module, *builder_); + tgt_->add_barrier(module, *builder_); builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, @@ -825,24 +845,111 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { ir::value *arg = x->get_operand(0); distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg); ir::reduce_inst::op_t op = x->get_op(); + unsigned axis = x->get_axis(); + + Type *fp32_ty = builder_->getFloatTy(); + FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false); + InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false); + InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false); + auto accumulate = [&](Value* x, Value *y) -> Value* { switch(op) { case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y); case ir::reduce_inst::SUB: return builder_->CreateSub(x, y); - case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y); - case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y); + case ir::reduce_inst::MAX:{ + if(x->getType()->isIntegerTy()) + return builder_->CreateSelect(builder_->CreateICmpSGE(x, y), x, y); + else + return builder_->CreateMaxNum(x, y); + } + case ir::reduce_inst::MIN:{ + if(x->getType()->isIntegerTy()) + return builder_->CreateSelect(builder_->CreateICmpSLE(x, y), x, y); + else + return builder_->CreateMinNum(x, y); + } case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y); case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y); - case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y); - case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y); - default: break; + case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y}); + case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y}); + default: assert(false); return nullptr; } - assert(false); - return nullptr; }; + Value *neutral; + switch(op) { + case ir::reduce_inst::ADD: neutral = builder_->getInt32(0); break; + case ir::reduce_inst::SUB: neutral = builder_->getInt32(0); break; + case ir::reduce_inst::MAX: neutral = builder_->getInt32(INT32_MIN); break; + case ir::reduce_inst::MIN: neutral = builder_->getInt32(INT32_MAX); break; + case ir::reduce_inst::FADD: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break; + case ir::reduce_inst::FSUB: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break; + case ir::reduce_inst::FMAX: neutral = ConstantFP::get(arg_tile->get_ty(), -INFINITY); break; + case ir::reduce_inst::FMIN: neutral = ConstantFP::get(arg_tile->get_ty(), INFINITY); break; + default: assert(false); break; + } + + + + analysis::data_layout* arg_layout = layouts_->get(arg); + if(auto* L = dynamic_cast(arg_layout)){ + bool can_optimize = true; + for(size_t r = 0; r < L->get_rank(); r++){ + if(r != axis) + can_optimize = can_optimize && (L->mts(r) == L->get_shape()[r]); + } + if(can_optimize){ + Value *thread_acc = nullptr; + // reduce within thread + arg_tile->for_each([&](indices_t idx) { + Value *current = arg_tile->get_value(idx); + if(thread_acc == nullptr) + thread_acc = current; + else + thread_acc = accumulate(thread_acc, current); + }); + // reduce within wrap + FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false); + InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false); + Value *warp_acc = thread_acc; + for(int i = 16; i > 0; i >>= 1) + warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)})); + // shared memory pointer + unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace(); + Type *res_ty = arg_tile->get_ty(); + Value *sh_mem_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space)); + Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0); + Value* warp_id = builder_->CreateUDiv(u_thread_id, builder_->getInt32(32)); + Value *write_ptr = builder_->CreateGEP(sh_mem_ptr, warp_id); + // store warp result in shared memory + tgt_->add_barrier(mod_, *builder_); + builder_->CreateStore(warp_acc, write_ptr); + tgt_->add_barrier(mod_, *builder_); + // accumulate all warps + Value *load_ptr = builder_->CreateGEP(sh_mem_ptr, u_thread_id); + Value* is_first_warp = builder_->CreateICmpEQ(warp_id, builder_->getInt32(0)); + BasicBlock* bb_final_acc = BasicBlock::Create(*ctx_, "bb_final_acc", builder_->GetInsertBlock()->getParent()); + BasicBlock* bb_final_acc_done = BasicBlock::Create(*ctx_, "bb_final_acc_done", builder_->GetInsertBlock()->getParent()); + builder_->CreateCondBr(is_first_warp, bb_final_acc, bb_final_acc_done); + builder_->SetInsertPoint(bb_final_acc); + Value* final_val = builder_->CreateLoad(load_ptr); + for(int i = (num_warps_+1)/2; i > 0; i >>= 1) + final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)})); + builder_->CreateStore(final_val, load_ptr); + builder_->CreateBr(bb_final_acc_done); +// // store first warp done + builder_->SetInsertPoint(bb_final_acc_done); + // write back + tgt_->add_barrier(mod_, *builder_); + final_val = builder_->CreateLoad(sh_mem_ptr); + for_each(x, [&](indices_t idx) { + set_value(x, idx, final_val); + }); + return; + } + } + // reduce within thread - unsigned axis = x->get_axis(); arg_tile->for_each([&](indices_t idx) { indices_t pidx = idx; pidx[axis] = builder_->getInt32(0); @@ -861,7 +968,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { unsigned depth = stile->get_shapes()[axis]; unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace(); - Type *res_ty = builder_->getFloatTy(); + Type *res_ty = arg_tile->get_ty(); Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space)); for(auto& x: partial) { // current element being computed @@ -891,10 +998,12 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { // accumulate result = accumulate(result, next); // write back + tgt_->add_barrier(mod_, *builder_); builder_->CreateStore(result, write_ptr); } } tgt_->add_barrier(mod_, *builder_); + // write back for_each(x, [&](indices_t idx) { indices_t red_idx = idx; @@ -1169,8 +1278,9 @@ void generator::visit_function(ir::function* fn) { } builder_->SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]); // initialize layouts - for(auto x: layouts_->get_all()) + for(auto x: layouts_->get_all()){ visit_layout(x.second); + } // generate LLVM-IR code for(ir::basic_block *block: fn->blocks()) visit_basic_block(block); diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc index 7ed00c586..9ee45db1c 100644 --- a/lib/codegen/selection/machine_layout.cc +++ b/lib/codegen/selection/machine_layout.cc @@ -158,7 +158,6 @@ tile *machine_distributed_layout::create(ir::value *v) { return false; }; std::sort(order.begin(), order.end(), cmp); - return new distributed_tile(ty, shapes, order, axes, *builder_); } diff --git a/lib/codegen/selection/machine_value.cc b/lib/codegen/selection/machine_value.cc index c70ba85b0..dbff237d1 100644 --- a/lib/codegen/selection/machine_value.cc +++ b/lib/codegen/selection/machine_value.cc @@ -135,13 +135,13 @@ Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& sh const std::vector& perm, const std::vector& order, indices_t idx) { // strides - std::vector strides(order.size()); + std::vector strides(shapes.size(), builder.getInt32(0)); strides[order[0]] = builder.getInt32(1); for(size_t i = 1; i < idx.size(); i++) strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]])); // result Value *result = builder.getInt32(0); - for(size_t i = 0; i < strides.size(); i++) + for(size_t i = 0; i < idx.size(); i++) result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i])); return result; } diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc index ae2791cc8..4b2aadb99 100644 --- a/lib/codegen/transform/cts.cc +++ b/lib/codegen/transform/cts.cc @@ -26,8 +26,6 @@ inline bool is_shmem_res(ir::value* v){ return false; if(i->get_id() == ir::INST_TRANS) return true; - if(i->get_id() == ir::INST_REDUCE) - return true; if(i->get_id() == ir::INST_COPY_TO_SHARED) return true; return false; @@ -76,8 +74,9 @@ void cts::run(ir::module &mod) { size_t num_op = i->get_num_operands(); // copy to shared operands for(size_t k = 0; k < num_op; k++) - if(is_shmem_op(i, k)) + if(is_shmem_op(i, k)){ add_copy(i, i->get_operand(k), builder, true); + } // copy from shared operands for(size_t k = 0; k < num_op; k++) if(!dynamic_cast(i) && diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index 0720318db..69498b127 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -83,6 +83,19 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ } } +bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){ + auto cfs = dynamic_cast(value); + if(cfs) { + ir::value *arg = cfs->get_operand(0); + ir::copy_to_shared_inst* cts = dynamic_cast(arg); + if(!cts) + return false; + cfs->replace_all_uses_with(cts->get_operand(0)); + return true; + } + +} + bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ auto x = dynamic_cast(value); if(!x) @@ -183,6 +196,7 @@ void peephole::run(ir::module &mod) { continue; bool was_modified = false; was_modified = was_modified || rewrite_mult(i, builder); + was_modified = was_modified || rewrite_cts_cfs(i, builder); was_modified = was_modified || rewrite_trans_phi(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); diff --git a/lib/driver/module.cc b/lib/driver/module.cc index f117c7f29..3a48dfdd1 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -91,7 +91,7 @@ void module::compile_llvm_module(std::unique_ptr module, const std const std::string& features, file_type_t ft) { init_llvm(); - // debug +// // debug // llvm::legacy::PassManager pm; // pm.add(llvm::createPrintModulePass(llvm::outs())); // pm.add(llvm::createVerifierPass()); diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index b1e417e5f..5f67b93af 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -44,9 +44,11 @@ void builder::set_insert_point(basic_block *block){ // convenience functions //===----------------------------------------------------------------------===// -value *builder::get_int32(unsigned val) { - return constant_int::get(type::get_int32_ty(ctx_), val); -} +value *builder::get_int32(int32_t val) +{ return constant_int::get(type::get_int32_ty(ctx_), val);} + +value *builder::get_int64(int64_t val) +{ return constant_int::get(type::get_int64_ty(ctx_), val);} type *builder::get_void_ty() { return type::get_void_ty(ctx_); } @@ -103,6 +105,7 @@ value *builder::create_ret_void() { return create_cast(OPCODE, src, dst_ty, name);\ } +DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt) DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP) DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP) DEFINE_CAST_INSTR(fp_to_si, cast_op_t::FPToSI) @@ -308,6 +311,10 @@ value *builder::create_atomic_add(value *ptr, value *val, const std::string &nam return insert(atomic_add_inst::create(ptr, val, name)); } +value *builder::create_exp(value *arg, const std::string &name){ + return insert(exp_inst::create(arg, name)); +} + value *builder::create_dot(value *A, value *B, value *C, const std::string &name) { return insert(dot_inst::create_nn(A, B, C, name)); } diff --git a/lib/ir/constant.cc b/lib/ir/constant.cc index 8a3f1a343..825a130f4 100644 --- a/lib/ir/constant.cc +++ b/lib/ir/constant.cc @@ -64,7 +64,7 @@ constant *constant_fp::get_negative_zero(type *ty){ constant *constant_fp::get_zero_value_for_negation(type *ty) { if(ty->get_scalar_ty()->is_floating_point_ty()) - return get_negative_zero(ty); + return constant_fp::get(ty, 0); return constant::get_null_value(ty); } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 930c4a116..2c14a1e83 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -746,6 +746,18 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string & return new atomic_add_inst(ptr, val, name, next); } +// exp + +exp_inst::exp_inst(value *val, const std::string &name, instruction *next) + : builtin_inst(val->get_type(), INST_EXP, 1, name, next) { + set_operand(0, val); +} + +instruction* exp_inst::create(value *val, const std::string& name, instruction *next) { + return new exp_inst(val, name, next); +} + + //===----------------------------------------------------------------------===// // intrinsic instructions //===----------------------------------------------------------------------===// diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc index 30031c757..b6169b78f 100644 --- a/lib/lang/ast.cc +++ b/lib/lang/ast.cc @@ -655,6 +655,9 @@ void UnaryOp::TypeChecking() { case Token::REDUCE: return ReduceOpTypeChecking(); + case Token::EXP: + return IntrinsicOpTypeChecking(); + default: assert(false); } @@ -769,6 +772,10 @@ void UnaryOp::CastOpTypeChecking() { } } +void UnaryOp::IntrinsicOpTypeChecking() { + type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_FLOAT)); +} + /* * Transposition Operator */ diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 41e8afe98..6b82b8b26 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -42,8 +42,8 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { case '^': return set_ret(bld_->create_xor(lhs, rhs)); case Token::LEFT: return set_ret(bld_->create_shl(lhs, rhs)); case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs)); - case '.': return error_not_implemented(); - case ',': return error_not_implemented(); + case '.': return error_not_implemented(". binary operator not implemented"); + case ',': return error_not_implemented(", binary operator not implemented"); case '@' : { ir::type* ret_ty = GenIRType(binary->Type(), *ctx_); ir::type* ret_scal_ty = ret_ty->get_scalar_ty(); @@ -66,7 +66,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { auto clhs = dynamic_cast(lhs); auto crhs = dynamic_cast(rhs); if(!clhs || !crhs) - should_not_happen(); + error_not_implemented("ellipsis between variables not implemented"); return set_ret(bld_->insert(ir::make_range::create(clhs, crhs))); } case '+': @@ -97,7 +97,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { else if(!sign) return set_ret(bld_->create_udiv(lhs, rhs)); else - return should_not_happen(); + return should_not_happen("/ should not encounter type not in {float, int}"); case '%': if(flt) return set_ret(bld_->create_frem(lhs, rhs)); @@ -113,7 +113,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { else if(!sign) return set_ret(bld_->create_icmpULT(lhs, rhs)); else - return should_not_happen(); + return should_not_happen("< should not encounter type not in {float, int}"); case '>': if(flt) return set_ret(bld_->create_fcmpOGT(lhs, rhs)); @@ -122,7 +122,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { else if(!sign) return set_ret(bld_->create_icmpUGT(lhs, rhs)); else - return should_not_happen(); + return should_not_happen("> should not encounter type not in {float, int}"); case Token::LE: if(flt) return set_ret(bld_->create_fcmpOLE(lhs, rhs)); @@ -131,7 +131,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { else if(!sign) return set_ret(bld_->create_icmpULE(lhs, rhs)); else - return should_not_happen(); + return should_not_happen("<= should not encounter type not in {float, int}"); case Token::GE: if(flt) return set_ret(bld_->create_fcmpOGE(lhs, rhs)); @@ -140,7 +140,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { else if(!sign) return set_ret(bld_->create_icmpUGE(lhs, rhs)); else - return should_not_happen(); + return should_not_happen(">= should not encounter type not in {float, int}"); case Token::EQ: if(flt) return set_ret(bld_->create_fcmpOEQ(lhs, rhs)); @@ -152,9 +152,9 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { else return set_ret(bld_->create_icmpNE(lhs, rhs)); default: - error_not_implemented(); + return error_not_implemented("binary operator " + std::to_string(binary->op_) + " not implemented"); } - error_not_implemented(); + should_not_happen(""); } ir::reduce_inst::op_t reduce_op(int tag, bool is_float) { @@ -166,7 +166,7 @@ ir::reduce_inst::op_t reduce_op(int tag, bool is_float) { case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN; default: break; } - should_not_happen(); + error_not_implemented("reduction operator " + std::to_string(tag) + " not implemented"); return reduce_inst::op_t(); } @@ -176,7 +176,10 @@ ir::value* Generator::GenUnaryMinus(ir::value* arg) { ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty); if(ty->is_tile_ty()) _0 = bld_->create_splat(_0, ty->get_tile_shapes()); - return bld_->create_sub(_0, arg); + if(sca_ty->is_floating_point_ty()) + return bld_->create_fsub(_0, arg); + else + return bld_->create_sub(_0, arg); } void Generator::VisitUnaryOp(UnaryOp* unary) { @@ -187,18 +190,19 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { ir::type *arg_scal_ty = arg_ty->get_scalar_ty(); // return switch (unary->op_) { - case Token::PREFIX_INC: return error_not_implemented(); - case Token::PREFIX_DEC: return error_not_implemented(); - case Token::POSTFIX_INC: return error_not_implemented(); - case Token::POSTFIX_DEC: return error_not_implemented(); - case Token::ADDR: return error_not_implemented(); + case Token::PREFIX_INC: return error_not_implemented("prefix increment not implemented"); + case Token::PREFIX_DEC: return error_not_implemented("prefix decrement not implemented"); + case Token::POSTFIX_INC: return error_not_implemented("postfix increment not implemented"); + case Token::POSTFIX_DEC: return error_not_implemented("postfix decrement not implemented"); + case Token::ADDR: return error_not_implemented("unary & not implemented"); case Token::DEREF: return set_ret(bld_->create_load(arg)); - case Token::PLUS: return error_not_implemented(); + case Token::PLUS: return error_not_implemented("unary + not implemented"); case Token::MINUS: return set_ret(GenUnaryMinus(arg)); - case '~': return error_not_implemented(); - case '!': return error_not_implemented(); + case '~': return error_not_implemented("unary ~ not implemented"); + case '!': return error_not_implemented("unary ! not implemented"); case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_))); case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_))); + case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast case Token::REDUCE: { int ax, tag; UnaryOp::decodeRed(unary->info_, ax, tag); @@ -206,9 +210,9 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { ir::reduce_inst::op_t op = reduce_op(tag, is_float); return set_ret(bld_->create_reduce(arg, op, ax)); } - default: error_not_implemented(); + default: error_not_implemented("unary " + std::to_string(unary->op_) + " not implemented"); } - return error_not_implemented(); + return should_not_happen(""); } void Generator::VisitTransOp(TransOp *trans) { @@ -225,7 +229,9 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) { ir::value* true_val = ret_; VisitExpr(condOp->exprFalse_); ir::value* false_val = ret_; - if(ir::load_inst* ld = dynamic_cast(true_val)) { + if(ir::unmasked_load_inst* ld = dynamic_cast(true_val)) { + if(!false_val->get_type()->is_tile_ty()) + false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes()); ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val); @@ -233,7 +239,8 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) { ld->erase_from_parent(); return set_ret(new_ld); } - return error_not_implemented(); + return set_ret(bld_->create_select(cond, true_val, false_val)); +// return error_not_implemented(); } void Generator::VisitFuncCall(FuncCall* funcCall) { @@ -244,7 +251,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { if(auto axis = dynamic_cast(ret)) return set_ret(bld_->create_get_program_id(axis->get_value())); else - return should_not_happen(); + return should_not_happen("get_program_id argument should be constant"); } if(name == "get_num_programs"){ VisitExpr(funcCall->Args()->at(0)); @@ -252,7 +259,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { if(auto axis = dynamic_cast(ret)) return set_ret(bld_->create_get_num_program(axis->get_value())); else - return should_not_happen(); + return should_not_happen("get_num_programs argument should be constant"); } if(name == "atomic_cas"){ VisitExpr(funcCall->Args()->at(0)); @@ -294,7 +301,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { ir::value* false_val = ret_; return set_ret(bld_->create_select(cond, true_val, false_val)); } - return error_not_implemented(); + return error_not_implemented("function calls not implemented"); } void Generator::VisitObject(Object* obj) { @@ -302,7 +309,7 @@ void Generator::VisitObject(Object* obj) { } void Generator::VisitEnumerator(Enumerator* enumer) { - return error_not_implemented(); + return error_not_implemented("enumeration not implemented"); } void Generator::VisitIdentifier(Identifier* ident) { @@ -316,31 +323,36 @@ void Generator::VisitConstant(Constant* cons) { return set_ret(ir::constant_int::get(type, cons->IVal())); if(ctype->IsFloat() && ctype->IsReal()) return set_ret(ir::constant_fp::get(type, cons->FVal())); - return error_not_implemented(); + return error_not_implemented("constant of type not in {int, float} not implemented"); } void Generator::VisitTempVar(TempVar* tempVar) { - return error_not_implemented(); + return error_not_implemented("temporary variable not implemented"); } // Statement void Generator::VisitDeclaration(Declaration* decl) { auto obj = decl->obj_; // initialize to undef + ir::type* ty = GenIRType(obj->Type(), *ctx_); ir::value* val = ir::undef_value::get(ty); +//obj->GetAttrList() // compute initializers std::vector inits; for (const Initializer& init: decl->Inits()) { VisitExpr(init.expr_); - inits.push_back(ret_); + ir::value *val = ret_; + for(const auto& attr: obj->GetAttrList()) + SetIRMetadata(attr, val); + inits.push_back(val); } // initialize declaration ir::type::id_t id = ty->get_type_id(); if(id == ir::type::StructTyID) - should_not_happen(); + error_not_implemented("struct not implemented"); if(inits.size() > 1) - should_not_happen(); + error_not_implemented("initializer list > 1 element not implemented"); if(inits.size() > 0) val = inits[0]; assert(val->get_type() == ty); @@ -427,20 +439,20 @@ void Generator::VisitForStmt(ForStmt *forStmt) { } void Generator::VisitJumpStmt(JumpStmt* jumpStmt) { - return error_not_implemented(); + return error_not_implemented("jump not implemented"); } void Generator::VisitReturnStmt(ReturnStmt* returnStmt) { ir::value *ret; if(returnStmt->expr_) - return error_not_implemented(); + return error_not_implemented("non-void return not implemented"); else ret = bld_->create_ret_void(); return set_ret(ret); } void Generator::VisitLabelStmt(LabelStmt* labelStmt) { - return error_not_implemented(); + return error_not_implemented("label not implemented"); } void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) { @@ -458,7 +470,7 @@ void Generator::VisitFuncDef(FuncDef* funcDef) { FuncType* type = funcDef->FuncType(); auto prototype = dynamic_cast(GenIRType(type, *ctx_)); if(!prototype) - should_not_happen(); + should_not_happen("could not parse function prototype"); ir::function *fn = mod_->get_or_insert_function(name, prototype); std::vector args = fn->args(); size_t i = 0; @@ -529,7 +541,7 @@ ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) { for(size_t d = 0; d < padded_shapes.size(); d++){ if(dst_shapes[d] != padded_shapes[d] && padded_shapes[d] != 1) - should_not_happen(); + should_not_happen("broadcast should not happen between these shapes"); } // pad and broadcast ir::value *padded = bld_->create_reshape(src, padded_shapes); @@ -555,6 +567,9 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) { bool dst_signed = false; if(src_scalar_ty == dst_scalar_ty) return src; + else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_bool_ty()) + return bld_->create_icmpNE(bld_->create_ptr_to_int(src, ir::tile_type::get_same_shapes(bld_->get_int64_ty(), src->get_type())), + bld_->create_splat(bld_->get_int64(0), src->get_type()->get_tile_shapes())); else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty()) return bld_->create_si_to_fp(src, dst_ty); else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty()) @@ -575,7 +590,7 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) { else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty()) return bld_->create_cast(ir::BitCast, src, dst_ty); else{ - should_not_happen(); + error_not_implemented("cast type not implemented"); return nullptr; } } @@ -594,7 +609,7 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) { if(attr.kind == ASTNode::Attr::MULTIPLEOF) { VisitExpr(attr.vals[0]); auto cst = dynamic_cast(ret_); - if(!cst) should_not_happen(); + if(!cst) should_not_happen("multipleof only works on constants"); return ir::attribute(ir::multiple_of, cst->get_value()); } if(attr.kind == ASTNode::Attr::ALIGNED) { @@ -608,7 +623,15 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) { return ir::attribute(ir::readonly); if(attr.kind == ASTNode::Attr::WRITEONLY) return ir::attribute(ir::writeonly); - should_not_happen(); + error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented"); +} + +void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) { + auto *i = dynamic_cast(v); + if(!i) + return; + if(attr.kind == ASTNode::Attr::MULTIPLEOF) + i->set_metadata(ir::metadata::multiple_of, GenIRAttr(attr).get_value()); } // Triton-IR Types @@ -684,12 +707,12 @@ ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) { } ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) { - error_not_implemented(); + error_not_implemented("struct not implemented"); return nullptr; } void Generator::AllocObjects(Scope* scope, const FuncDef::ParamList& params) { - return error_not_implemented(); + return error_not_implemented("alloc not implemented"); } // SSA @@ -704,7 +727,7 @@ void Generator::popScope() { // LValue Generator void LValAssigner::VisitBinaryOp(BinaryOp* binary) { if(binary->op_ != Token::MASKED_DEREF) - error_not_implemented(); + error_not_implemented("lvalue for binary non masked-deref not implemented"); gen_->VisitExpr(binary->lhs_); ir::value* mask = gen_->ret_; gen_->VisitExpr(binary->rhs_); @@ -714,7 +737,7 @@ void LValAssigner::VisitBinaryOp(BinaryOp* binary) { void LValAssigner::VisitUnaryOp(UnaryOp* unary) { if(unary->op_ != Token::DEREF) - should_not_happen(); + error_not_implemented("lvalue for unary non deref not implemented"); gen_->VisitExpr(unary->operand_); ir::value* addr = gen_->ret_; ret_ = gen_->bld_->create_store(addr, rhs_); diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index acf9167b1..ba85a04cf 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -258,7 +258,6 @@ Constant* Parser::ParseFloat(const Token* tok) { } if (str[end] != 0) Error(tok, "invalid suffix"); - return Constant::New(tok, tag, val); } @@ -571,6 +570,7 @@ Expr* Parser::ParseUnaryExpr() { case Token::SIZEOF: return ParseSizeof(); case Token::INC: return ParsePrefixIncDec(tok); case Token::DEC: return ParsePrefixIncDec(tok); + case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions case '&': return ParseUnaryOp(tok, Token::ADDR); case '*': return ParseDerefOp(tok); case '+': return ParseUnaryOp(tok, Token::PLUS); @@ -634,6 +634,12 @@ UnaryOp* Parser::ParsePrefixIncDec(const Token* tok) { return UnaryOp::New(op, operand); } +UnaryOp* Parser::ParseUnaryIntrinsicOp(const Token* tok, int op) { + ts_.Expect('('); + auto operand = ParseExpr(); + ts_.Expect(')'); + return UnaryOp::New(op, operand); +} UnaryOp* Parser::ParseUnaryOp(const Token* tok, int op) { auto operand = ParseCastExpr(); diff --git a/lib/lang/token.cc b/lib/lang/token.cc index e5d395f8b..d158062e8 100644 --- a/lib/lang/token.cc +++ b/lib/lang/token.cc @@ -45,6 +45,7 @@ const std::unordered_map Token::kwTypeMap_ { { "volatile", Token::VOLATILE }, { "while", Token::WHILE }, { "bitcast", Token::BITCAST }, + { "exp", Token::EXP }, { "_Alignas", Token::ALIGNAS }, { "_Alignof", Token::ALIGNOF }, { "_Atomic", Token::ATOMIC }, @@ -147,6 +148,7 @@ const std::unordered_map Token::tagLexemeMap_ { { Token::VOLATILE, "volatile" }, { Token::WHILE, "while" }, { Token::BITCAST, "bitcast" }, + { Token::EXP, "exp" }, { Token::ALIGNAS, "_Alignas" }, { Token::ALIGNOF, "_Alignof" }, { Token::ATOMIC, "_Atomic" }, diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 85424acde..cc3c395c7 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -165,8 +165,10 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, arg_type ty = arg_i.type(); if(ty != param_tys_.at(i)) throw std::runtime_error("invalid type for argument " + std::to_string(i)); - if(ty == BUFFER_T) - bin_->setArg(i, *((driver::buffer**)arg_i.data())); + if(ty == BUFFER_T){ + driver::buffer* buf = *((driver::buffer**)arg_i.data()); + bin_->setArg(i, buf->size() == 0 ? nullptr : buf); + } else bin_->setArg(i, size_of(ty), arg_i.data()); } @@ -216,6 +218,7 @@ std::unique_ptr function::make_bin(ir::module &module, codegen::transform::cts cts; codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps); // run passes +// ir::print(module, std::cout); dce.run(module); disassociate.run(module); dce.run(module); @@ -231,6 +234,7 @@ std::unique_ptr function::make_bin(ir::module &module, dce.run(module); reassociate.run(module); cts.run(module); + peephole.run(module); dce.run(module); align.run(module); axes.run(module); @@ -238,7 +242,7 @@ std::unique_ptr function::make_bin(ir::module &module, liveness.run(module); allocation.run(module); if(allocation.allocated_size() > context->device()->max_shared_memory()) - return std::unique_ptr(); + throw std::runtime_error("using too much shared memory"); barriers.run(module); isel.visit(module, *llvm); std::unique_ptr res(driver::module::create(context, std::move(llvm))); @@ -391,6 +395,8 @@ std::string function::preheader() { #define __aligned(A) __attribute__((aligned(A))) #define __multipleof(A) __attribute__((multipleof(A))) +#define INFINITY bitcast(0x7F800000) + extern int atomic_cas(int*, int, int); extern int atomic_xchg(int*, int); extern int get_program_id(int); diff --git a/python/examples/einsum.py b/python/examples/einsum.py index ce6d49210..4c8b4ac56 100644 --- a/python/examples/einsum.py +++ b/python/examples/einsum.py @@ -13,18 +13,18 @@ configs = [] # Matrix multiplication MNK = [ - (512, 512 ,512), + (1024, 1024, 1024), (2048, 2048, 2048), - #(8192, 8192, 8192), + (8192, 8192, 8192), - (64, 64, 64000), - (64, 64, 128000), - (256, 256, 64000), - (256, 256, 128000), + #(64, 64, 64000), + #(64, 64, 128000), + #(256, 256, 64000), + #(256, 256, 128000), - (1536, 16, 1536), - (1536, 32, 1536), - (1536, 64, 1536), + #(1536, 16, 1536), + #(1536, 32, 1536), + #(1536, 64, 1536), # (1536, 128, 1536), # (4096, 16, 4096), # (4096, 32, 4096), @@ -33,9 +33,9 @@ MNK = [ # (127008, 768, 576) ] -for M, N, K in MNK: - matmul = lambda a, b: torch.matmul(a, b) - configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] +#for M, N, K in MNK: +# matmul = lambda a, b: torch.matmul(a, b) +# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] #for M, N, K in MNK: # matmul = lambda a, b: torch.matmul(a.t(), b) # configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())] @@ -94,8 +94,8 @@ for N, C, H, K, R in NCHKR: # 2D Dense convolution NCHWKRS = [ - #(8, 64, 128, 128, 768, 3, 3), - (128, 3, 32, 32, 64, 3, 3), + (8, 64, 128, 128, 768, 3, 3), + #(128, 3, 32, 32, 64, 3, 3), #(8, 256, 32, 32, 512, 3, 3), #(8, 512, 32, 32, 1024, 3, 3) ] @@ -160,22 +160,39 @@ for N, C, H, W, K, R, S in NCHWKRS: b = b.permute(1, 0) b = b.reshape(b.shape[0], b.shape[1], 1, 1) return torch.nn.functional.conv2d(a, b) - configs += [([N, C, H, W], - [C, K], - [N, K, H, W], - shift_conv, - 'nc(h + sh[c])(w + sw[c]),ck->nkhw', - {'sh': shift_h, 'sw': shift_w})] + configs += [([N, C, H, W], [C, K], [N, K, H, W], + shift_conv, + 'nc(h + sh[c])(w + sw[c]),ck->nkhw', + {'sh': shift_h, 'sw': shift_w})] + +NCHWKX = [ + #(8, 64, 128, 128, 128, 7) + ] +for N, C, H, W, K, X in NCHWKX: + off_h = np.array([0, 0, 0, 1, 2, 3, 4], dtype=np.int32) + off_w = np.array([0, 1, 3, 1, 3, 0, 4], dtype=np.int32) + R, S = 5, 5 + def sparse_conv(a, b, **kwargs): + off_h, off_w = kwargs['off_h'], kwargs['off_w'] + K, C, X = b.shape + cvtb = torch.zeros([K, C, R, S], dtype=b.dtype, device=b.device) + cvtb[:, :, off_h, off_w] = b + return torch.nn.functional.conv2d(a, cvtb) + configs += [([N, C, H, W], [K, C, X], [N, K, H - R + 1, W - S + 1], + sparse_conv, + 'nc(h + off_h[x])(w + off_w[x]),kcx->nkhw', + {'off_h': off_h, 'off_w': off_w})] + # Benchmark torch.set_num_threads(1) for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: - dtype = torch.cuda.FloatTensor + dtype = torch.cuda.HalfTensor # initialize input tensors a = torch.rand(*a_shape).type(dtype).cuda() b = torch.rand(*b_shape).type(dtype).cuda() # triton output - tc = torch.empty(c_shape, device=a.device) + tc = torch.zeros(c_shape, dtype=a.dtype, device=a.device) triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True) # reference output if torch_fn: @@ -185,12 +202,13 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: # performance relative to equivalent matrix multiplication ctx = triton.ops._einsum.registry[tc] B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K - cmp_eqbmm = False + cmp_eqbmm = True if cmp_eqbmm: a = torch.rand(B, M, K).type(dtype).cuda() b = torch.rand(B, K, N).type(dtype).cuda() - tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True) - ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms + tmmc = torch.empty([B, M, N]).type(dtype).cuda() + triton.ops.einsum('bmk,bkn->bmn', a, b, tmmc, bench = True) + ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms cmp_str = f'({ratio:4.2f})' else: cmp_str = '' diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 464d588ca..9cee27d70 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -329,14 +329,16 @@ __global__ void {name}( #endif } """ + # print(src) # compilation options - TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16] - TK = 16 if dtype==torch.float16 else 8 + #TM, TN, TB, TZ = [16, 32, 64, 128], [16, 32, 64, 128], 1, [1, 4, 16] + #TK = 16 if dtype==torch.float16 else 8 + TM, TN, TB, TZ, TK = 128, 128, 1, 1, 16 defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} if mask is not None: defines['MASK'] = '{0:#0{1}x}'.format(mask, 10) # create kernel - ret = triton.kernel(src, defines=defines) + ret = triton.kernel(src, defines=defines, num_warps=[4]) # set constant if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: ret.set_constant('AD', delta_a) diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 67a09cfef..fd4a96622 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -13,7 +13,7 @@ int main() { for(auto x: std::vector>{{false, false}}){ std::vector tmp = { // config_t{ord, x[0], x[1], 512, 512, 512}, - config_t{ord, x[0], x[1], 2048, 2048, 2048}, + config_t{ord, x[0], x[1], 1024, 1024, 1024}, // config_t{ord, x[0], x[1], 127008, 768, 576}, // config_t{ord, x[0], x[1], 8192, 8192, 8192} // config_t{ord, x[0], x[1], 16, 2048, 2048}, diff --git a/tests/common/dot.h b/tests/common/dot.h index 427e7ca04..5556f750f 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -109,10 +109,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, opt.num_warps = {nwarp}; } if(mode == BENCH) { - opt.defines.push_back({"TM", {"32", "64", "128"}}); - opt.defines.push_back({"TN", {"32", "64", "128"}}); + opt.defines.push_back({"TM", {"128"}}); + opt.defines.push_back({"TN", {"32"}}); opt.defines.push_back({"TK", {to_string::value == "half" ? "16" : "8"}}); - opt.num_warps = {2, 4, 8}; + opt.num_warps = {4}; } // kernels