[dnn/blocksparse] FPROP test passes!

Philippe Tillet
2019-07-29 17:06:20 -07:00
parent 17cb2db356
commit dc11f70fad
20 changed files with 360 additions and 118 deletions

View File

@@ -101,6 +101,7 @@ typedef struct bsmm_params
CUstream stream;
} bsmm_params;
template<typename T>
class BlocksparseMatmulOp : public OpKernel {
public:
explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -152,29 +153,23 @@ public:
shape_c.AddDim(params_.K);
Tensor* c = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
// grid and block
int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
blkN = 64;
gridN = (N + 63)/64;
}
// allocate locks
int gridN = (N + 63)/64;
Tensor* locks;
TensorShape shape_l;
if (params_.locks > 0)
shape_l.AddDim(gridN * params_.locks * 2);
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
// initialize default compute device
triton::runtime::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
// wrap tensorflow handles
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<T>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<T>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<T>().data(), false);
triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
// create profile
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks);
// blocksparse matmul
triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
}
private:
@@ -185,4 +180,5 @@ private:
char bench_string_[256];
};
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<float>);
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<Eigen::half>);
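For reference (not part of the diff), the locks output allocated above -- shape_l.AddDim(gridN * params_.locks * 2) -- matches the layout the Triton-C kernel later in this commit expects: gridN*nlocks spin-lock words followed by gridN*nlocks counter words. A minimal C++ sketch of that indexing, with all names hypothetical:

#include <cstdint>

// Hypothetical helper mirroring how the kernel indexes the locks tensor:
//   plock  = locks + ridx*nlocks + lockid - 1
//   pcount = plock + get_num_program(0)*nlocks
struct lock_view {
  int32_t* base;    // flat int32 tensor of size 2 * gridN * nlocks
  int32_t  gridN;   // number of program ids along N
  int32_t  nlocks;  // params_.locks

  int32_t* lock (int32_t ridx, int32_t lockid) { return base + ridx * nlocks + (lockid - 1); }
  int32_t* count(int32_t ridx, int32_t lockid) { return lock(ridx, lockid) + gridN * nlocks; }
};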

View File

@@ -23,9 +23,11 @@ public:
virtual ~target() {}
virtual void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) = 0;
virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
virtual llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
virtual llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
bool is_gpu() const;
private:
@@ -37,9 +39,11 @@ public:
amd_cl_target(): target(true){}
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};
class nvidia_cu_target: public target {
@@ -47,9 +51,11 @@ public:
nvidia_cu_target(): target(true){}
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};
class cpu_target: public target {
@@ -57,9 +63,11 @@ public:
cpu_target(): target(false){}
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
};
}

View File

@@ -28,6 +28,11 @@
#include "triton/runtime/launch_info.h"
namespace triton{
namespace runtime{
class jit;
}
namespace dnn{
@@ -37,6 +42,13 @@ enum autotuning_t{
NO_TUNING
};
class base;
struct launch_context_t{
base *op;
driver::kernel* kernel;
triton::runtime::launch_information info;
};
typedef std::vector<unsigned> params_t;
class base {
@@ -49,9 +61,9 @@ protected:
private:
// initialize
virtual void init_impl(driver::stream *, driver::cu_module *){ }
virtual void init_impl(driver::stream *, driver::cu_module *) = 0;
// deinitialize
virtual void deinit_impl(){ }
virtual void deinit_impl() = 0;
// enqueue
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
@@ -63,6 +75,8 @@ private:
// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;
// obtain execution jit
std::pair<base*, triton::runtime::jit*> get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune);
public:
// constructor
@@ -73,6 +87,8 @@ public:
virtual base* clone() const = 0;
// enqueue
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
// get profile
launch_context_t get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune = PARTIAL_TUNING);
private:
std::string name_;

View File

@@ -37,6 +37,10 @@ namespace dnn{
class batchnorm_forward: public base {
private:
// init
void init_impl(driver::stream *, driver::cu_module *) { }
void deinit_impl() { }
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
@@ -69,6 +73,9 @@ private:
class batchnorm_backward: public base{
private:
// init
void init_impl(driver::stream *, driver::cu_module *) { }
void deinit_impl() { }
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,

View File

@@ -14,27 +14,34 @@ private:
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
// number of flops
virtual size_t num_flops() const;
size_t num_flops() const;
// comparison for maps
virtual bool operator<(const base& other) const;
bool operator<(const base& other) const;
// default parameters
virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const;
std::vector<params_t> search_space() const;
params_t heuristics() const;
// init
void init_impl(driver::stream *stream, driver::cu_module *module);
// deinit
void deinit_impl();
public:
// constructor
dot(int32_t M, int32_t N, int32_t K);
dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks);
// triton-c source
virtual void triton_c_src(std::ostream &os) const;
void triton_c_src(std::ostream &os) const;
// clone
virtual base* clone() const;
base* clone() const;
private:
std::string ab_ty_;
std::string c_ty_;
int32_t M_;
int32_t N_;
int32_t S_;
int32_t C_;
int32_t K_;
int32_t BS_;
int32_t nlocks_;
driver::buffer *locks_;
};
}

View File

@@ -25,6 +25,7 @@ private:
void build_a_deltas();
void build_masks();
void init_impl(driver::stream *, driver::cu_module *);
void deinit_impl() { }
// enqueue
std::array<size_t, 3> get_grid(size_t TM, size_t TN);

View File

@@ -10,6 +10,8 @@ class dot: public base {
private:
// initialize
void init_impl(driver::stream *, driver::cu_module *);
void deinit_impl() { }
// enqueue
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,

View File

@@ -126,9 +126,10 @@ public:
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
// Built-in instruction
value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
value *create_get_range_id(unsigned axis, const std::string &name = "");
value *create_get_num_program(unsigned axis, const std::string &name = "");
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::string &name = "");

View File

@@ -514,6 +514,19 @@ private:
unsigned axis_;
};
class get_num_program_inst: public builtin_inst {
private:
get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
private:
unsigned axis_;
};
class atomic_cas_inst: public builtin_inst {
private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
@@ -523,6 +536,15 @@ public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
};
class atomic_exch_inst: public builtin_inst {
private:
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_exch"; }
public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
};
class atomic_add_inst: public builtin_inst {
private:
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);

View File

@@ -80,6 +80,15 @@ private:
const constant* axis_;
};
class get_num_program_expression: public builtin_expression{
public:
get_num_program_expression(node *axis): axis_((constant*)axis) { }
ir::value* codegen(ir::module *mod) const;
private:
const constant* axis_;
};
class atomic_cas_expression: public builtin_expression{
public:
atomic_cas_expression(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { }
@@ -91,6 +100,17 @@ private:
const node *val_;
};
class atomic_exch_expression: public builtin_expression{
public:
atomic_exch_expression(node *ptr, node *val): ptr_(ptr), val_(val) { }
ir::value* codegen(ir::module *) const;
private:
const node *ptr_;
const node *val_;
};
class atomic_add_expression: public builtin_expression{
public:
atomic_add_expression(node *ptr, node *val): ptr_(ptr), val_(val) { }

View File

@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
%token IF ELSE FOR CONTINUE WHILE
%token NEWAXIS ELLIPSIS AT
%token GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST
%start translation_unit
%%
@@ -121,6 +121,7 @@ identifier
/* Built-in */
builtin_expression
: GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
| GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); }
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
@@ -130,6 +131,7 @@ builtin_expression
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
;

View File

@@ -45,8 +45,9 @@ using triton::lang::return_void;
"fp64" { return return_impl(FP64, yytext); }
"..." { return return_impl(ELLIPSIS, yytext); }
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
"__sum" { return return_impl(REDUCE_SUM, yytext); }
"sqrt" { return return_impl(SQRT, yytext); }

View File

@@ -19,7 +19,8 @@ void optimize_dce::run(ir::module &mod) {
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)){
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)
|| dynamic_cast<ir::atomic_cas_inst*>(i) || dynamic_cast<ir::atomic_exch_inst*>(i) || dynamic_cast<ir::atomic_add_inst*>(i) ){
work_list.push_back(i);
marked.insert(i);
}

View File

@@ -319,8 +319,12 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
}
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
return (Instruction*)offset;
Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
return (Instruction*)result;
}
if(ir::get_num_program_inst* ii = dynamic_cast<ir::get_num_program_inst*>(inst)){
Value *result = tgt_->get_num_blocks(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
return (Instruction*)result;
}
if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
BasicBlock *current = builder.GetInsertBlock();
@@ -331,6 +335,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii)));
ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
tgt_->add_memfence(module, builder);
tgt_->add_barrier(module, builder);
builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder.SetInsertPoint(tid_0_bb);
@@ -342,10 +347,29 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
builder.CreateStore(old, ptr);
builder.CreateBr(tid_0_done_bb);
builder.SetInsertPoint(tid_0_done_bb);
tgt_->add_memfence(module, builder);
tgt_->add_barrier(module, builder);
Value *res = builder.CreateLoad(ptr);
return (Instruction*)res;
}
if(ir::atomic_exch_inst* ii = dynamic_cast<ir::atomic_exch_inst*>(inst)){
BasicBlock *current = builder.GetInsertBlock();
Module *module = current->getModule();
Value *rmw_ptr = value(ii->get_operand(0));
Value *rmw_val = value(ii->get_operand(1));
Value *tid = tgt_->get_local_id(module, builder, 0);
Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0));
BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent());
BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
tgt_->add_memfence(module, builder);
tgt_->add_barrier(module, builder);
builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
builder.SetInsertPoint(tid_0_bb);
Value *res = builder.CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System);
builder.CreateBr(tid_0_done_bb);
builder.SetInsertPoint(tid_0_done_bb);
return (Instruction*)res;
}
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
Value *ptr = value(ii->get_operand(0));
Value *val = value(ii->get_operand(1));
@@ -1136,17 +1160,17 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *result_then = builder.CreateLoad(ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
Value *result = nullptr;
Value *current_result = nullptr;
if(false_values){
result = builder.CreatePHI(result_then->getType(), 2);
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
current_result = builder.CreatePHI(result_then->getType(), 2);
((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
Value *result_false = false_values->get_value(idx);
if(vector_size > 1)
if(result_then->getType()->isVectorTy())
result_false = builder.CreateVectorSplat(vector_size, result_false);
((PHINode*)result)->addIncoming(result_false, current_bb);
((PHINode*)current_result)->addIncoming(result_false, current_bb);
}
else
result = result_then;
current_result = result_then;
// std::string offset = "";
// if(cst)
@@ -1160,7 +1184,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
// Value *result = builder.CreateCall(iasm, {mask, ptr});
packets[id] = result;
packets[id] = current_result;
}
});
// extract result element
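As a reading aid (illustrative only, not the generated code), the atomic_cas/atomic_exch lowering earlier in this file amounts to: fence and barrier, let lane 0 perform the read-modify-write, then fence and barrier again so every lane observes the result; for atomic_cas the old value is additionally published through a shared-memory slot. A C++ analogue under those assumptions:

#include <atomic>

// Hypothetical single-group analogue; shared_slot stands in for the shared-memory
// scratch word the real lowering allocates.
int atomic_cas_group(std::atomic<int>& target, int cmp, int val,
                     int tid, int& shared_slot) {
  std::atomic_thread_fence(std::memory_order_seq_cst);  // add_memfence + add_barrier
  if (tid == 0) {                                       // tid_0 basic block
    int expected = cmp;
    target.compare_exchange_strong(expected, val);
    shared_slot = expected;                             // publish old value
  }
  std::atomic_thread_fence(std::memory_order_seq_cst);  // add_memfence + add_barrier
  return shared_slot;                                   // every lane reloads the published value
}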

View File

@@ -32,6 +32,11 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
return result;
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
}
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
@@ -43,6 +48,16 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
return group_id;
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::r600_read_ngroups_x,
Intrinsic::r600_read_ngroups_y,
Intrinsic::r600_read_ngroups_z
};
Value* get_num_group = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_num_group, {});
}
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
@@ -70,6 +85,12 @@ Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder)
return builder.CreateCall(barrier, {});
}
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
return builder.CreateCall(barrier, {});
}
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
@@ -82,39 +103,39 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
bool z_order = true;
if(z_order && ax < 2){
static std::array<Intrinsic::ID, 3> n_cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
// global block ID
Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
// helper for minimum
auto Min = [&](Value *x, Value *y){
return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
};
// super-tile size
Value* sts = Min(builder.getInt32(16), n_cta_id_1);
// number of CTAs per super-block
Value *nscta = builder.CreateMul(n_cta_id_0, sts);
Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
if(ax == 0)
return bid0;
else
return bid1;
}
else{
// bool z_order = true;
// if(z_order && ax < 2){
// static std::array<Intrinsic::ID, 3> n_cta_ids = {
// Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
// Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
// Intrinsic::nvvm_read_ptx_sreg_nctaid_z
// };
// Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
// Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
// Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
// Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
// // global block ID
// Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
// // helper for minimum
// auto Min = [&](Value *x, Value *y){
// return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
// };
// // super-tile size
// Value* sts = Min(builder.getInt32(16), n_cta_id_1);
// // number of CTAs per super-block
// Value *nscta = builder.CreateMul(n_cta_id_0, sts);
// Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
// Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
// if(ax == 0)
// return bid0;
// else
// return bid1;
// }
// else{
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id;
}
// }
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -127,6 +148,16 @@ Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsi
return builder.CreateCall(get_local_id, {});
}
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_nctaid, {});
}
// CPU
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
@@ -138,6 +169,12 @@ Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
const Function *fn = builder.GetInsertBlock()->getParent();
size_t num_params = fn->getFunctionType()->getNumParams();
@@ -149,6 +186,11 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
return (Argument*)ids[ax];
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
return result;
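For reference, the z-order CTA swizzle that this commit comments out in nvidia_cu_target::get_block_id remaps a linear block id into super-tiles of at most 16 rows; a plain C++ restatement (illustrative only) of that now-disabled arithmetic:

#include <algorithm>
#include <utility>

// Returns the remapped (axis-0, axis-1) block ids for a given pair of CTA ids.
std::pair<unsigned, unsigned> swizzle(unsigned ctaid0, unsigned ctaid1,
                                      unsigned nctaid0, unsigned nctaid1) {
  unsigned bid   = ctaid0 + ctaid1 * nctaid0;   // global linear block id
  unsigned sts   = std::min(16u, nctaid1);      // super-tile size
  unsigned nscta = nctaid0 * sts;               // CTAs per super-tile row
  unsigned bid0  = (bid / sts) % nctaid0;
  unsigned bid1  = (bid / nscta) * sts + bid % sts;
  return {bid0, bid1};
}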

View File

@@ -6,6 +6,8 @@
namespace triton{
namespace dnn{
namespace rt = triton::runtime;
void base::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
@@ -28,8 +30,7 @@ params_t base::heuristics() const {
return *search_space().begin();
}
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
namespace rt = triton::runtime;
std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
driver::context* ctx = stream->context();
rt::jit* jit;
@@ -67,16 +68,23 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
}
/* retrieved compiled template */
else{
else {
jit = m_jit.at(this).get();
}
/* get launch parameters */
driver::kernel* kernel = jit->get_function(name_.c_str());
rt::launch_information info = jit->get_launch_info(name_.c_str());
/* launch */
auto it = m_jit.find(this);
it->first->enqueue_impl(stream, kernel, args, info);
return {it->first, jit};
}
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
launch_context_t info = get_launch_context(stream, args, autotune);
info.op->enqueue_impl(stream, info.kernel, args, info.info);
}
launch_context_t base::get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
std::pair<base*, rt::jit*> profile = get_profile_impl(stream, args, autotune);
driver::kernel* kernel = profile.second->get_function(name_.c_str());
rt::launch_information info = profile.second->get_launch_info(name_.c_str());
return {profile.first, kernel, info};
}
}
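A hedged usage sketch of the refactored API (the include path and the helper function are assumptions): get_launch_context exposes the compiled kernel and its launch_information, while enqueue keeps the old one-shot behaviour on top of the same get_profile_impl path:

#include <vector>
#include "triton/dnn/base.h"   // assumed header for triton::dnn::base

// Hypothetical caller: inspect the compiled kernel once, then launch as before.
void run(triton::dnn::base& op,
         triton::driver::stream* stream,
         const std::vector<triton::driver::buffer*>& args) {
  triton::dnn::launch_context_t lc =
      op.get_launch_context(stream, args, triton::dnn::NO_TUNING);
  auto nthreads = lc.info.num_threads;     // from rt::launch_information
  (void)nthreads; (void)lc.kernel;         // available for custom launch strategies
  op.enqueue(stream, args, triton::dnn::NO_TUNING);  // unchanged one-shot path
}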

View File

@@ -13,26 +13,42 @@ bool dot::operator <(const base& other) const {
auto *y = dynamic_cast<const dot*>(&other);
if(!y)
return true;
return std::tie(M_, N_, K_)
< std::tie(y->M_, y->N_, y->K_);
return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_)
< std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_);
}
std::vector<params_t> dot::search_space() const {
throw std::runtime_error("not implemented");
}
params_t dot::heuristics() const {
throw std::runtime_error("not implemented");
}
base * dot::clone() const {
return new dot(*this);
}
dot::dot(int32_t M, int32_t N, int32_t K):
base("bsdot"), M_(M), N_(N), K_(K) {
ab_ty_ = "fp32";
c_ty_ = "fp32";
dot::dot(int32_t N, int32_t K, int32_t S, int32_t C,
const std::string& ty, int32_t BS, int32_t nlocks):
base("bsdot"),
N_(N), K_(K), S_(S), C_(C),
ab_ty_(ty), c_ty_(ty),
BS_(BS), nlocks_(nlocks) {
}
void dot::init_impl(driver::stream *stream, driver::cu_module *module) {
// int32_t TM = info.globals["TM"];
// size_t grid_0 = (N_ + TM - 1) / TM;
// if(nlocks_){
// locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4);
// ((driver::cu_buffer*)locks_)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
// }
}
void dot::deinit_impl() {
// if(locks_)
// delete locks_;
}
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
@@ -41,64 +57,89 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
driver::buffer *b = args[1];
driver::buffer *c = args[2];
driver::buffer *lut = args[3];
int32_t lda = M_;
int32_t ldc = M_;
driver::buffer *locks = args[4];
int32_t lda = N_;
int32_t ldc = N_;
kernel->setArg(0, a);
kernel->setArg(1, b);
kernel->setArg(2, c);
kernel->setArg(3, lda);
kernel->setArg(4, ldc);
kernel->setArg(5, lut);
kernel->setArg(5, N_);
kernel->setArg(6, lut);
kernel->setArg(7, locks);
kernel->setArg(8, nlocks_);
int32_t TM = info.globals["TM"];
int32_t TN = info.globals["TN"];
size_t grid_0 = (M_ + TM - 1) / TM;
size_t grid_1 = (N_ + TN - 1) / TN;
size_t grid_0 = (N_ + TM - 1) / TM;
size_t grid_1 = S_;
std::cout << N_ << " " << grid_0 << std::endl;
if(nlocks_){
// locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4);
((driver::cu_buffer*)locks)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
}
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
stream->synchronize();
}
void dot::triton_c_src(std::ostream &os) const {
std::string result =
R"(
const tunable int32 TM = {64, 128};
const tunable int32 TN = {32};
const tunable int32 TK = {32};
const tunable int32 TM = {64};
const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
restrict read_only align(16) )" + ab_ty_ + R"( *B,
fp32* C,
int32 lda, int32 ldc,
int32* lut_base){
)" + c_ty_ + R"(* C,
int32 lda, int32 ldc, int32 N,
int32* lut, int32* locks, int32 nlocks){
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32 c[TM, TN] = 0;
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 acc[TM, TN] = 0;
int32 rxa[TM] = ridx * TM + (0 ... TM);
int32 ryb[TN] = 0 ... TN;
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda;
int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK;
int32 *header = lut_base + ridy * 4;
int32 *header = lut + ridy * 4;
int32 offset = *(header + 0);
int32 K = *(header + 1);
int32 h2 = *(header + 2);
int32 h3 = *(header + 3);
int32 *lut = lut_base + offset*2;
int32 column = *(header + 2);
int32 lockid = *(header + 3);
int32 *plut = lut + offset * 2;
for(int32 k = K; k > 0; k = k - 1){
int32 ak = *(lut + 0);
int32 bk = *(lut + 1);
fp32* pa[TM, TK] = A + offa + ak * TK * lda;
fp32* pb[TK, TN] = B + offb + bk * TK * TN;
fp32 a[TM, TK] = *pa;
fp32 b[TK, TN] = *pb;;
c = dot(a, b, c);
lut = lut + 2;
int32 ak = *(plut + 0);
int32 bk = *(plut + 1);
)" + ab_ty_ + R"(* pa[TM, TK] = A + offa + ak * TK * lda;
)" + ab_ty_ + R"(* pb[TK, TN] = B + offb + bk * TK * TN;
)" + ab_ty_ + R"( a[TM, TK] = *pa;
)" + ab_ty_ + R"( b[TK, TN] = *pb;
acc = dot(a, b, acc);
plut = plut + 2;
}
int32 rxc[TM] = ridx * TM + (0 ... TM);
int32 ryc[TN] = ridy * TN + (0 ... TN);
fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
*pc = c;
int32 ryc[TN] = column * TN + (0 ... TN);
)" + c_ty_ + R"( c[TM, TN] = acc;
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
int1 checkc[TM, TN] = (rxc < N)[:, newaxis];
if(lockid == 0){
@checkc *pc = c;
}
else{
int32 *plock = locks + ridx*nlocks + lockid - 1;
int32 *pcount = plock + get_num_program(0)*nlocks;
while(__atomic_cas(plock, 0, 1));
int32 count = *pcount;
if(count == 0) {
@checkc *pc = c;
}
else {
@checkc *pc = c + *pc;
}
*pcount = 1;
__atomic_exch(plock, 0);
}
})";
os << result;
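As a reading aid, the lock/count protocol at the tail of the kernel above (spin on __atomic_cas, store on the first visit, accumulate afterwards, then release with __atomic_exch) corresponds to the following C++ sketch; the function and its parameters are illustrative, not part of the library:

#include <atomic>

// Single-tile analogue of the kernel's split-K reduction into C.
void accumulate_tile(float* pc, const float* tile, int n,
                     std::atomic<int>& lock, int& count) {
  int expected = 0;
  while (!lock.compare_exchange_weak(expected, 1))    // while(__atomic_cas(plock, 0, 1));
    expected = 0;
  if (count == 0)
    for (int i = 0; i < n; ++i) pc[i]  = tile[i];     // first writer stores
  else
    for (int i = 0; i < n; ++i) pc[i] += tile[i];     // later writers accumulate
  count = 1;                                          // *pcount = 1
  lock.store(0, std::memory_order_release);           // __atomic_exch(plock, 0)
}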

View File

@@ -294,10 +294,18 @@ value *builder::create_get_range_id(unsigned axis, const std::string &name) {
return insert(get_range_id_inst::create(ctx_, axis, name));
}
value *builder::create_get_num_program(unsigned axis, const std::string &name) {
return insert(get_num_program_inst::create(ctx_, axis, name));
}
value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){
return insert(atomic_cas_inst::create(ptr, cmp, val, name));
}
value *builder::create_atomic_exch(value *ptr, value *val, const std::string &name){
return insert(atomic_exch_inst::create(ptr, val, name));
}
value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
return insert(atomic_add_inst::create(ptr, val, name));
}

View File

@@ -636,6 +636,17 @@ instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::s
return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next);
}
// get_num_program
get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
}
instruction* get_num_program_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
return new get_num_program_inst(type::get_int32_ty(ctx), axis, name, next);
}
// atomic cas
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
@@ -649,6 +660,18 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
return new atomic_cas_inst(ptr, cmp, val, name, next);
}
// atomic exch
atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}
instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
return new atomic_exch_inst(ptr, val, name, next);
}
// atomic add
atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)

View File

@@ -120,6 +120,11 @@ ir::value* get_range_id_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_get_range_id(axis_->value());
}
// get_num_program
ir::value* get_num_program_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_get_num_program(axis_->value());
}
// atomic cas
ir::value* atomic_cas_expression::codegen(ir::module *mod) const {
ir::value *ptr = ptr_->codegen(mod);
@@ -128,6 +133,13 @@ ir::value* atomic_cas_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
}
// atomic exch
ir::value* atomic_exch_expression::codegen(ir::module *mod) const {
ir::value *ptr = ptr_->codegen(mod);
ir::value *val = val_->codegen(mod);
return mod->get_builder().create_atomic_exch(ptr, val);
}
// atomic add
ir::value* atomic_add_expression::codegen(ir::module *mod) const {
ir::value *ptr = ptr_->codegen(mod);