diff --git a/examples/python/tensorflow/blocksparse.cpp b/examples/python/tensorflow/blocksparse.cpp index b86c6bcab..d6b305fcf 100644 --- a/examples/python/tensorflow/blocksparse.cpp +++ b/examples/python/tensorflow/blocksparse.cpp @@ -101,6 +101,7 @@ typedef struct bsmm_params CUstream stream; } bsmm_params; +template class BlocksparseMatmulOp : public OpKernel { public: explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) { @@ -152,29 +153,23 @@ public: shape_c.AddDim(params_.K); Tensor* c = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c)); - // grid and block - int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127; - if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){ - blkN = 64; - gridN = (N + 63)/64; - } // allocate locks + int gridN = (N + 63)/64; Tensor* locks; TensorShape shape_l; if (params_.locks > 0) shape_l.AddDim(gridN * params_.locks * 2); OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks)); - // initialize default compute device - triton::runtime::jit jit(ctx); - // matrix multiplication parameters - triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat().data(), false); - triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat().data(), false); - triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat().data(), false); -// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat().data(), false); + // wrap tensorflow handles + triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat().data(), false); + triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat().data(), false); + triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat().data(), false); triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat().data(), false); + triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat().data(), false); + // create profile + triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks); // 
blocksparse matmul - triton::dnn::blocksparse::dot dot(N, params_.K, params_.C); - dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING); + dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING); } private: @@ -185,4 +180,5 @@ private: char bench_string_[256]; }; -REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint("T"), BlocksparseMatmulOp); +REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint("T"), BlocksparseMatmulOp); +REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint("T"), BlocksparseMatmulOp); diff --git a/include/triton/codegen/target.h b/include/triton/codegen/target.h index 118ee919f..c080d1c07 100644 --- a/include/triton/codegen/target.h +++ b/include/triton/codegen/target.h @@ -23,9 +23,11 @@ public: virtual ~target() {} virtual void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) = 0; virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0; + virtual llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder) = 0; virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0; virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0; virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0; + virtual llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0; bool is_gpu() const; private: @@ -37,9 +39,11 @@ public: amd_cl_target(): target(true){} void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn); llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder); + llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& 
builder); llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax); llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); + llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); }; class nvidia_cu_target: public target { @@ -47,9 +51,11 @@ public: nvidia_cu_target(): target(true){} void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn); llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder); + llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder); llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax); llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); + llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); }; class cpu_target: public target { @@ -57,9 +63,11 @@ public: cpu_target(): target(false){} void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn); llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder); + llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder); llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax); llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); + llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); }; } diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h 
index 1fbded42c..266f29803 100644 --- a/include/triton/dnn/base.h +++ b/include/triton/dnn/base.h @@ -28,6 +28,11 @@ #include "triton/runtime/launch_info.h" namespace triton{ + +namespace runtime{ + class jit; +} + namespace dnn{ @@ -37,6 +42,13 @@ enum autotuning_t{ NO_TUNING }; +class base; +struct launch_context_t{ + base *op; + driver::kernel* kernel; + triton::runtime::launch_information info; +}; + typedef std::vector params_t; class base { @@ -49,9 +61,9 @@ protected: private: // initialize - virtual void init_impl(driver::stream *, driver::cu_module *){ } + virtual void init_impl(driver::stream *, driver::cu_module *) = 0; // deinitialize - virtual void deinit_impl(){ } + virtual void deinit_impl() = 0; // enqueue virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, @@ -63,6 +75,8 @@ private: // default parameters virtual std::vector search_space() const; virtual params_t heuristics() const; + // obtain execution jit + std::pair get_profile_impl(driver::stream *stream, std::vector args, autotuning_t autotune); public: // constructor @@ -73,6 +87,8 @@ public: virtual base* clone() const = 0; // enqueue void enqueue(driver::stream* stream, std::vector args, autotuning_t autotune = PARTIAL_TUNING); + // get profile + launch_context_t get_launch_context(driver::stream *stream, std::vector args, autotuning_t autotune = PARTIAL_TUNING); private: std::string name_; diff --git a/include/triton/dnn/batchnorm.h b/include/triton/dnn/batchnorm.h index 496e19ae4..8f9053225 100644 --- a/include/triton/dnn/batchnorm.h +++ b/include/triton/dnn/batchnorm.h @@ -37,6 +37,10 @@ namespace dnn{ class batchnorm_forward: public base { private: + // init + void init_impl(driver::stream *, driver::cu_module *) { } + void deinit_impl() { } + // enqueue void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, @@ -69,6 +73,9 @@ private: class batchnorm_backward: public base{ private: + // init + void 
init_impl(driver::stream *, driver::cu_module *) { } + void deinit_impl() { } // enqueue void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, diff --git a/include/triton/dnn/blocksparse/dot.h b/include/triton/dnn/blocksparse/dot.h index fbd388937..a1df146fe 100644 --- a/include/triton/dnn/blocksparse/dot.h +++ b/include/triton/dnn/blocksparse/dot.h @@ -14,27 +14,34 @@ private: std::vector args, triton::runtime::launch_information info); // number of flops - virtual size_t num_flops() const; + size_t num_flops() const; // comparison for maps - virtual bool operator<(const base& other) const; + bool operator<(const base& other) const; // default parameters - virtual std::vector search_space() const; - virtual params_t heuristics() const; - + std::vector search_space() const; + params_t heuristics() const; + // init + void init_impl(driver::stream *stream, driver::cu_module *module); + // deinit + void deinit_impl(); public: // constructor - dot(int32_t M, int32_t N, int32_t K); + dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks); // triton-c source - virtual void triton_c_src(std::ostream &os) const; + void triton_c_src(std::ostream &os) const; // clone - virtual base* clone() const; + base* clone() const; private: std::string ab_ty_; std::string c_ty_; - int32_t M_; int32_t N_; + int32_t S_; + int32_t C_; int32_t K_; + int32_t BS_; + int32_t nlocks_; + driver::buffer *locks_; }; } diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h index 1b6f2d778..d81ff872d 100644 --- a/include/triton/dnn/conv.h +++ b/include/triton/dnn/conv.h @@ -25,6 +25,7 @@ private: void build_a_deltas(); void build_masks(); void init_impl(driver::stream *, driver::cu_module *); + void deinit_impl() { } // enqueue std::array get_grid(size_t TM, size_t TN); diff --git a/include/triton/dnn/dot.h b/include/triton/dnn/dot.h index 3df8a13a6..6ba3f0b24 100644 --- a/include/triton/dnn/dot.h +++ 
b/include/triton/dnn/dot.h @@ -10,6 +10,8 @@ class dot: public base { private: // initialize void init_impl(driver::stream *, driver::cu_module *); + void deinit_impl() { } + // enqueue void enqueue_impl(driver::stream *stream, driver::kernel *kernel, std::vector args, diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 9cee12c68..1921814c9 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -126,9 +126,10 @@ public: value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); // Built-in instruction - value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = ""); value *create_get_range_id(unsigned axis, const std::string &name = ""); + value *create_get_num_program(unsigned axis, const std::string &name = ""); value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = ""); + value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); value *create_atomic_add(value *ptr, value *val, const std::string &name = ""); value *create_dot(value *A, value *B, value *C, const std::string &name = ""); value *create_trans(value *A, const std::string &name = ""); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index d76ebf719..37692d617 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -514,6 +514,19 @@ private: unsigned axis_; }; +class get_num_program_inst: public builtin_inst { +private: + get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next); + std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; } + +public: + static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); + unsigned 
get_axis() const { return axis_; } + +private: + unsigned axis_; +}; + class atomic_cas_inst: public builtin_inst { private: atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next); @@ -523,6 +536,15 @@ public: static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr); }; +class atomic_exch_inst: public builtin_inst { +private: + atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); + std::string repr_impl() const { return "atomic_exch"; } + +public: + static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); +}; + class atomic_add_inst: public builtin_inst { private: atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h index 538485366..13894d18a 100644 --- a/include/triton/lang/expression.h +++ b/include/triton/lang/expression.h @@ -80,6 +80,15 @@ private: const constant* axis_; }; +class get_num_program_expression: public builtin_expression{ +public: + get_num_program_expression(node *axis): axis_((constant*)axis) { } + ir::value* codegen(ir::module *mod) const; + +private: + const constant* axis_; +}; + class atomic_cas_expression: public builtin_expression{ public: atomic_cas_expression(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { } @@ -91,6 +100,17 @@ private: const node *val_; }; +class atomic_exch_expression: public builtin_expression{ +public: + atomic_exch_expression(node *ptr, node *val): ptr_(ptr), val_(val) { } + ir::value* codegen(ir::module *) const; + +private: + const node *ptr_; + const node *val_; +}; + + class atomic_add_expression: public builtin_expression{ public: atomic_add_expression(node *ptr, node *val): ptr_(ptr), val_(val) { } diff --git a/include/triton/lang/parser.y 
b/include/triton/lang/parser.y index 645b0b51f..cd2c8941b 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64 %token IF ELSE FOR CONTINUE WHILE %token NEWAXIS ELLIPSIS AT -%token GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST +%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST %start translation_unit %% @@ -121,6 +121,7 @@ identifier /* Built-in */ builtin_expression : GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); } + | GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } | SQRT '(' expression ')' { $$ = new sqrt_expression($3); } | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); } @@ -130,6 +131,7 @@ builtin_expression | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); } | SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); } | ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); } + | ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); } | ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); } ; diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l index 83d11035d..af691349d 100644 --- a/include/triton/lang/scanner.l +++ b/include/triton/lang/scanner.l @@ -45,8 +45,9 @@ using triton::lang::return_void; "fp64" { return return_impl(FP64, yytext); } "..." 
{ return return_impl(ELLIPSIS, yytext); } "get_range_id" { return return_impl(GET_RANGE_ID, yytext); } +"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); } "__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); } -"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); } +"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); } "__atomic_add" { return return_impl(ATOMIC_ADD, yytext); } "__sum" { return return_impl(REDUCE_SUM, yytext); } "sqrt" { return return_impl(SQRT, yytext); } diff --git a/lib/codegen/optimize_dce.cpp b/lib/codegen/optimize_dce.cpp index d30bf4c1d..9508cfa2e 100644 --- a/lib/codegen/optimize_dce.cpp +++ b/lib/codegen/optimize_dce.cpp @@ -19,7 +19,8 @@ void optimize_dce::run(ir::module &mod) { for(ir::basic_block *block: rpo) for(ir::instruction *i: block->get_inst_list()){ if(dynamic_cast(i) || dynamic_cast(i) || dynamic_cast(i) - || dynamic_cast(i) || dynamic_cast(i)){ + || dynamic_cast(i) || dynamic_cast(i) + || dynamic_cast(i) || dynamic_cast(i) || dynamic_cast(i) ){ work_list.push_back(i); marked.insert(i); } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 5ab9c55f8..ad7e395b1 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -319,8 +319,12 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function(inst)){ - Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis()); - return (Instruction*)offset; + Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis()); + return (Instruction*)result; + } + if(ir::get_num_program_inst* ii = dynamic_cast(inst)){ + Value *result = tgt_->get_num_blocks(builder.GetInsertBlock()->getModule(), builder, ii->get_axis()); + return (Instruction*)result; } if(ir::atomic_cas_inst* ii = dynamic_cast(inst)){ BasicBlock *current = builder.GetInsertBlock(); @@ -331,6 +335,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, 
std::functiongetParent()); Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii))); ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace())); + tgt_->add_memfence(module, builder); tgt_->add_barrier(module, builder); builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb); builder.SetInsertPoint(tid_0_bb); @@ -342,10 +347,29 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::functionadd_memfence(module, builder); tgt_->add_barrier(module, builder); Value *res = builder.CreateLoad(ptr); return (Instruction*)res; } + if(ir::atomic_exch_inst* ii = dynamic_cast(inst)){ + BasicBlock *current = builder.GetInsertBlock(); + Module *module = current->getModule(); + Value *rmw_ptr = value(ii->get_operand(0)); + Value *rmw_val = value(ii->get_operand(1)); + Value *tid = tgt_->get_local_id(module, builder, 0); + Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0)); + BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent()); + BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent()); + tgt_->add_memfence(module, builder); + tgt_->add_barrier(module, builder); + builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb); + builder.SetInsertPoint(tid_0_bb); + Value *res = builder.CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System); + builder.CreateBr(tid_0_done_bb); + builder.SetInsertPoint(tid_0_done_bb); + return (Instruction*)res; + } if(ir::atomic_add_inst* ii = dynamic_cast(inst)){ Value *ptr = value(ii->get_operand(0)); Value *val = value(ii->get_operand(1)); @@ -1136,17 +1160,17 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & Value *result_then = builder.CreateLoad(ptr); builder.CreateBr(mask_done_bb); builder.SetInsertPoint(mask_done_bb); - Value *result = nullptr; + Value *current_result = nullptr; if(false_values){ - result = 
builder.CreatePHI(result_then->getType(), 2); - ((PHINode*)result)->addIncoming(result_then, mask_then_bb); + current_result = builder.CreatePHI(result_then->getType(), 2); + ((PHINode*)current_result)->addIncoming(result_then, mask_then_bb); Value *result_false = false_values->get_value(idx); - if(vector_size > 1) + if(result_then->getType()->isVectorTy()) result_false = builder.CreateVectorSplat(vector_size, result_false); - ((PHINode*)result)->addIncoming(result_false, current_bb); + ((PHINode*)current_result)->addIncoming(result_false, current_bb); } else - result = result_then; + current_result = result_then; // std::string offset = ""; // if(cst) @@ -1160,7 +1184,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & // InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true); // Value *result = builder.CreateCall(iasm, {mask, ptr}); - packets[id] = result; + packets[id] = current_result; } }); // extract result element diff --git a/lib/codegen/target.cpp b/lib/codegen/target.cpp index 2e20839d9..4116bcca7 100644 --- a/lib/codegen/target.cpp +++ b/lib/codegen/target.cpp @@ -32,6 +32,11 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un return result; } +Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) { + throw std::runtime_error("not implemented"); +} + + Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) { static std::array ids = { Intrinsic::amdgcn_workgroup_id_x, @@ -43,6 +48,16 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne return group_id; } +Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { + static std::array ids = { + Intrinsic::r600_read_ngroups_x, + Intrinsic::r600_read_ngroups_y, + Intrinsic::r600_read_ngroups_z + }; + Value* get_num_group = Intrinsic::getDeclaration(module, ids[ax]); + return builder.CreateCall(get_num_group, {}); 
+} + Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { static std::array ids = { Intrinsic::amdgcn_workitem_id_x, @@ -70,6 +85,12 @@ Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) return builder.CreateCall(barrier, {}); } +Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) { + Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl); + return builder.CreateCall(barrier, {}); +} + + Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { Value* group_id = get_block_id(module, builder, ax); Value* result = builder.CreateMul(builder.getInt32(stride), group_id); @@ -82,39 +103,39 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi Intrinsic::nvvm_read_ptx_sreg_ctaid_y, Intrinsic::nvvm_read_ptx_sreg_ctaid_z }; - bool z_order = true; - if(z_order && ax < 2){ - static std::array n_cta_ids = { - Intrinsic::nvvm_read_ptx_sreg_nctaid_x, - Intrinsic::nvvm_read_ptx_sreg_nctaid_y, - Intrinsic::nvvm_read_ptx_sreg_nctaid_z - }; - Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {}); - Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {}); - Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {}); - Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {}); - // global block ID - Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0)); - // helper for minimum - auto Min = [&](Value *x, Value *y){ - return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x); - }; - // super-tile size - Value* sts = Min(builder.getInt32(16), n_cta_id_1); - // number of CTAs per super-block - Value *nscta = builder.CreateMul(n_cta_id_0, sts); - Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0); - Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), 
sts),builder.CreateURem(bid, sts)); - if(ax == 0) - return bid0; - else - return bid1; - } - else{ +// bool z_order = true; +// if(z_order && ax < 2){ +// static std::array n_cta_ids = { +// Intrinsic::nvvm_read_ptx_sreg_nctaid_x, +// Intrinsic::nvvm_read_ptx_sreg_nctaid_y, +// Intrinsic::nvvm_read_ptx_sreg_nctaid_z +// }; +// Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {}); +// Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {}); +// Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {}); +// Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {}); +// // global block ID +// Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0)); +// // helper for minimum +// auto Min = [&](Value *x, Value *y){ +// return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x); +// }; +// // super-tile size +// Value* sts = Min(builder.getInt32(16), n_cta_id_1); +// // number of CTAs per super-block +// Value *nscta = builder.CreateMul(n_cta_id_0, sts); +// Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0); +// Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts)); +// if(ax == 0) +// return bid0; +// else +// return bid1; +// } +// else{ Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]); Value* cta_id = builder.CreateCall(get_cta_id, {}); return cta_id; - } +// } } Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { @@ -127,6 +148,16 @@ Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsi return builder.CreateCall(get_local_id, {}); } +Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { + static std::array ids = { + Intrinsic::nvvm_read_ptx_sreg_nctaid_x, + Intrinsic::nvvm_read_ptx_sreg_nctaid_y, + Intrinsic::nvvm_read_ptx_sreg_nctaid_z + }; + Value* get_nctaid = Intrinsic::getDeclaration(module, 
ids[ax]); + return builder.CreateCall(get_nctaid, {}); +} + // CPU void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) { @@ -138,6 +169,12 @@ Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) { return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0)); } +Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) { + // no barrier on CPU + return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0)); +} + + Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) { const Function *fn = builder.GetInsertBlock()->getParent(); size_t num_params = fn->getFunctionType()->getNumParams(); @@ -149,6 +186,11 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi return (Argument*)ids[ax]; } +Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) { + throw std::runtime_error("not implemented"); +} + + Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax)); return result; diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index 72e0d340e..e5aa7ad45 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -6,6 +6,8 @@ namespace triton{ namespace dnn{ +namespace rt = triton::runtime; + void base::set_ld(const std::vector& shapes, std::vector& ld) { @@ -28,8 +30,7 @@ params_t base::heuristics() const { return *search_space().begin(); } -void base::enqueue(driver::stream *stream, std::vector args, autotuning_t autotune) { - namespace rt = triton::runtime; +std::pair base::get_profile_impl(driver::stream *stream, std::vector args, autotuning_t autotune) { static std::map, cmp_recompile> m_jit; driver::context* ctx = stream->context(); rt::jit* jit; @@ -67,16 +68,23 @@ void base::enqueue(driver::stream *stream, std::vector 
args, a clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); } /* retrieved compiled template */ - else{ + else { jit = m_jit.at(this).get(); } - - /* get launch parameters */ - driver::kernel* kernel = jit->get_function(name_.c_str()); - rt::launch_information info = jit->get_launch_info(name_.c_str()); - /* launch */ auto it = m_jit.find(this); - it->first->enqueue_impl(stream, kernel, args, info); + return {it->first, jit}; +} + +void base::enqueue(driver::stream *stream, std::vector args, autotuning_t autotune) { + launch_context_t info = get_launch_context(stream, args, autotune); + info.op->enqueue_impl(stream, info.kernel, args, info.info); +} + +launch_context_t base::get_launch_context(driver::stream *stream, std::vector args, autotuning_t autotune) { + std::pair profile = get_profile_impl(stream, args, autotune); + driver::kernel* kernel = profile.second->get_function(name_.c_str()); + rt::launch_information info = profile.second->get_launch_info(name_.c_str()); + return {profile.first, kernel, info}; } } diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp index 9ddb2514b..46a706498 100644 --- a/lib/dnn/blocksparse/dot.cpp +++ b/lib/dnn/blocksparse/dot.cpp @@ -13,26 +13,42 @@ bool dot::operator <(const base& other) const { auto *y = dynamic_cast(&other); if(!y) return true; - return std::tie(M_, N_, K_) - < std::tie(y->M_, y->N_, y->K_); + return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_) + < std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_); } std::vector dot::search_space() const { - + throw std::runtime_error("not implemented"); } params_t dot::heuristics() const { - + throw std::runtime_error("not implemented"); } base * dot::clone() const { return new dot(*this); } -dot::dot(int32_t M, int32_t N, int32_t K): - base("bsdot"), M_(M), N_(N), K_(K) { - ab_ty_ = "fp32"; - c_ty_ = "fp32"; +dot::dot(int32_t N, int32_t K, int32_t S, int32_t C, + const std::string& ty, int32_t BS, int32_t 
nlocks): + base("bsdot"), + N_(N), K_(K), S_(S), C_(C), + ab_ty_(ty), c_ty_(ty), + BS_(BS), nlocks_(nlocks) { +} + +void dot::init_impl(driver::stream *stream, driver::cu_module *module) { +// int32_t TM = info.globals["TM"]; +// size_t grid_0 = (N_ + TM - 1) / TM; +// if(nlocks_){ +// locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4); +// ((driver::cu_buffer*)locks_)->set_zero(stream, grid_0 * nlocks_ * 2 * 4); +// } +} + +void dot::deinit_impl() { +// if(locks_) +// delete locks_; } void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel, @@ -41,64 +57,88 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel, driver::buffer *b = args[1]; driver::buffer *c = args[2]; driver::buffer *lut = args[3]; - int32_t lda = M_; - int32_t ldc = M_; + driver::buffer *locks = args[4]; + int32_t lda = N_; + int32_t ldc = N_; kernel->setArg(0, a); kernel->setArg(1, b); kernel->setArg(2, c); kernel->setArg(3, lda); kernel->setArg(4, ldc); - kernel->setArg(5, lut); + kernel->setArg(5, N_); + kernel->setArg(6, lut); + kernel->setArg(7, locks); + kernel->setArg(8, nlocks_); int32_t TM = info.globals["TM"]; - int32_t TN = info.globals["TN"]; - size_t grid_0 = (M_ + TM - 1) / TM; - size_t grid_1 = (N_ + TN - 1) / TN; + size_t grid_0 = (N_ + TM - 1) / TM; + size_t grid_1 = S_; + if(nlocks_){ +// locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4); + ((driver::cu_buffer*)locks)->set_zero(stream, grid_0 * nlocks_ * 2 * 4); + } stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1}); - stream->synchronize(); } void dot::triton_c_src(std::ostream &os) const { std::string result = R"( - const tunable int32 TM = {64, 128}; - const tunable int32 TN = {32}; - const tunable int32 TK = {32}; + const tunable int32 TM = {64}; + const tunable int32 TN = {)" + std::to_string(BS_) + R"(}; + const tunable int32 TK = {)" + 
std::to_string(BS_) + R"(}; void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A, restrict read_only align(16) )" + ab_ty_ + R"( *B, - fp32* C, - int32 lda, int32 ldc, - int32* lut_base){ + )" + c_ty_ + R"(* C, + int32 lda, int32 ldc, int32 N, + int32* lut, int32* locks, int32 nlocks){ int32 ridx = get_range_id(0); int32 ridy = get_range_id(1); - fp32 c[TM, TN] = 0; - int32 rka[TK] = 0 ... TK; - int32 rkb[TK] = 0 ... TK; + fp32 acc[TM, TN] = 0; int32 rxa[TM] = ridx * TM + (0 ... TM); int32 ryb[TN] = 0 ... TN; + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda; int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK; - int32 *header = lut_base + ridy * 4; + int32 *header = lut + ridy * 4; int32 offset = *(header + 0); int32 K = *(header + 1); - int32 h2 = *(header + 2); - int32 h3 = *(header + 3); - int32 *lut = lut_base + offset*2; + int32 column = *(header + 2); + int32 lockid = *(header + 3); + int32 *plut = lut + offset * 2; for(int32 k = K; k > 0; k = k - 1){ - int32 ak = *(lut + 0); - int32 bk = *(lut + 1); - fp32* pa[TM, TK] = A + offa + ak * TK * lda; - fp32* pb[TK, TN] = B + offb + bk * TK * TN; - fp32 a[TM, TK] = *pa; - fp32 b[TK, TN] = *pb;; - c = dot(a, b, c); - lut = lut + 2; + int32 ak = *(plut + 0); + int32 bk = *(plut + 1); + )" + ab_ty_ + R"(* pa[TM, TK] = A + offa + ak * TK * lda; + )" + ab_ty_ + R"(* pb[TK, TN] = B + offb + bk * TK * TN; + )" + ab_ty_ + R"( a[TM, TK] = *pa; + )" + ab_ty_ + R"( b[TK, TN] = *pb; + acc = dot(a, b, acc); + plut = plut + 2; } int32 rxc[TM] = ridx * TM + (0 ... TM); - int32 ryc[TN] = ridy * TN + (0 ... TN); - fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc; - *pc = c; + int32 ryc[TN] = column * TN + (0 ... 
TN); + )" + c_ty_ + R"( c[TM, TN] = acc; + )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc; + int1 checkc[TM, TN] = (rxc < N)[:, newaxis]; + if(lockid == 0){ + @checkc *pc = c; + } + else{ + int32 *plock = locks + ridx*nlocks + lockid - 1; + int32 *pcount = plock + get_num_program(0)*nlocks; + while(__atomic_cas(plock, 0, 1)); + int32 count = *pcount; + if(count == 0) { + @checkc *pc = c; + } + else { + @checkc *pc = c + *pc; + } + *pcount = 1; + __atomic_exch(plock, 0); + } })"; os << result; diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index e58fd9924..77c099827 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -294,10 +294,18 @@ value *builder::create_get_range_id(unsigned axis, const std::string &name) { return insert(get_range_id_inst::create(ctx_, axis, name)); } +value *builder::create_get_num_program(unsigned axis, const std::string &name) { + return insert(get_num_program_inst::create(ctx_, axis, name)); +} + value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){ return insert(atomic_cas_inst::create(ptr, cmp, val, name)); } +value *builder::create_atomic_exch(value *ptr, value *val, const std::string &name){ + return insert(atomic_exch_inst::create(ptr, val, name)); +} + value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){ return insert(atomic_add_inst::create(ptr, val, name)); } diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index 9537336fb..a29c11914 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -636,6 +636,17 @@ instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::s return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next); } +// get_num_program +get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next) + : builtin_inst(ty, 0, 1, name, next), axis_(axis){ + +} + +instruction* 
get_num_program_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { + return new get_num_program_inst(type::get_int32_ty(ctx), axis, name, next); +} + + // atomic cas atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) @@ -649,6 +660,18 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s return new atomic_cas_inst(ptr, cmp, val, name, next); } +// atomic exch + +atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next) + : builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) { + set_operand(0, ptr); + set_operand(1, val); +} + +instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string &name, instruction *next) { + return new atomic_exch_inst(ptr, val, name, next); +} + // atomic add atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next) diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp index 6baa1f3b2..15e66607a 100644 --- a/lib/lang/expression.cpp +++ b/lib/lang/expression.cpp @@ -120,6 +120,11 @@ ir::value* get_range_id_expression::codegen(ir::module *mod) const { return mod->get_builder().create_get_range_id(axis_->value()); } +// get_num_program +ir::value* get_num_program_expression::codegen(ir::module *mod) const { + return mod->get_builder().create_get_num_program(axis_->value()); +} + // atomic cas ir::value* atomic_cas_expression::codegen(ir::module *mod) const { ir::value *ptr = ptr_->codegen(mod); @@ -128,6 +133,13 @@ ir::value* atomic_cas_expression::codegen(ir::module *mod) const { return mod->get_builder().create_atomic_cas(ptr, cmp, val); } +// atomic exch +ir::value* atomic_exch_expression::codegen(ir::module *mod) const { + ir::value *ptr = ptr_->codegen(mod); + ir::value *val = val_->codegen(mod); + return mod->get_builder().create_atomic_exch(ptr, val); +} + // atomic add 
ir::value* atomic_add_expression::codegen(ir::module *mod) const { ir::value *ptr = ptr_->codegen(mod);