[dnn/blocksparse] FPROP test passes!
This commit is contained in:
@@ -101,6 +101,7 @@ typedef struct bsmm_params
|
||||
CUstream stream;
|
||||
} bsmm_params;
|
||||
|
||||
template<typename T>
|
||||
class BlocksparseMatmulOp : public OpKernel {
|
||||
public:
|
||||
explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
@@ -152,29 +153,23 @@ public:
|
||||
shape_c.AddDim(params_.K);
|
||||
Tensor* c = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, shape_c, &c));
|
||||
// grid and block
|
||||
int blkN = 128, gridN = (N + 127)/128, modN128 = N & 127;
|
||||
if (axis_ == 1 || (modN128 > 0 && modN128 <= 64) || gridN * params_.segments < SMs_*4){
|
||||
blkN = 64;
|
||||
gridN = (N + 63)/64;
|
||||
}
|
||||
// allocate locks
|
||||
int gridN = (N + 63)/64;
|
||||
Tensor* locks;
|
||||
TensorShape shape_l;
|
||||
if (params_.locks > 0)
|
||||
shape_l.AddDim(gridN * params_.locks * 2);
|
||||
OP_REQUIRES_OK(context, context->allocate_output(1, shape_l, &locks));
|
||||
// initialize default compute device
|
||||
triton::runtime::jit jit(ctx);
|
||||
// matrix multiplication parameters
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
|
||||
// triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
|
||||
// wrap tensorflow handles
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<T>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<T>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<T>().data(), false);
|
||||
triton::driver::cu_buffer dlut(ctx, (CUdeviceptr)lut.flat<int64>().data(), false);
|
||||
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks->flat<int32>().data(), false);
|
||||
// create profile
|
||||
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp32", params_.bsize, params_.locks);
|
||||
// blocksparse matmul
|
||||
triton::dnn::blocksparse::dot dot(N, params_.K, params_.C);
|
||||
dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
|
||||
dot.enqueue(stream, {&da, &db, &dc, &dlut, &dlocks}, triton::dnn::NO_TUNING);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -185,4 +180,5 @@ private:
|
||||
char bench_string_[256];
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlocksparseMatmulOp<float>);
|
||||
REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmul").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), BlocksparseMatmulOp<Eigen::half>);
|
||||
|
@@ -23,9 +23,11 @@ public:
|
||||
virtual ~target() {}
|
||||
virtual void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn) = 0;
|
||||
virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
|
||||
virtual llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder) = 0;
|
||||
virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0;
|
||||
virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
|
||||
virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
|
||||
virtual llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0;
|
||||
bool is_gpu() const;
|
||||
|
||||
private:
|
||||
@@ -37,9 +39,11 @@ public:
|
||||
amd_cl_target(): target(true){}
|
||||
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
|
||||
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
|
||||
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
};
|
||||
|
||||
class nvidia_cu_target: public target {
|
||||
@@ -47,9 +51,11 @@ public:
|
||||
nvidia_cu_target(): target(true){}
|
||||
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
|
||||
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
|
||||
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
};
|
||||
|
||||
class cpu_target: public target {
|
||||
@@ -57,9 +63,11 @@ public:
|
||||
cpu_target(): target(false){}
|
||||
void set_kernel(llvm::IRBuilder<>& builder, llvm::LLVMContext &ctx, llvm::Module *module, llvm::Function* fn);
|
||||
llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Instruction* add_memfence(llvm::Module *module, llvm::IRBuilder<>& builder);
|
||||
llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax);
|
||||
llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
llvm::Value* get_num_blocks(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax);
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -28,6 +28,11 @@
|
||||
#include "triton/runtime/launch_info.h"
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace runtime{
|
||||
class jit;
|
||||
}
|
||||
|
||||
namespace dnn{
|
||||
|
||||
|
||||
@@ -37,6 +42,13 @@ enum autotuning_t{
|
||||
NO_TUNING
|
||||
};
|
||||
|
||||
class base;
|
||||
struct launch_context_t{
|
||||
base *op;
|
||||
driver::kernel* kernel;
|
||||
triton::runtime::launch_information info;
|
||||
};
|
||||
|
||||
typedef std::vector<unsigned> params_t;
|
||||
|
||||
class base {
|
||||
@@ -49,9 +61,9 @@ protected:
|
||||
|
||||
private:
|
||||
// initialize
|
||||
virtual void init_impl(driver::stream *, driver::cu_module *){ }
|
||||
virtual void init_impl(driver::stream *, driver::cu_module *) = 0;
|
||||
// deinitialize
|
||||
virtual void deinit_impl(){ }
|
||||
virtual void deinit_impl() = 0;
|
||||
// enqueue
|
||||
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
@@ -63,6 +75,8 @@ private:
|
||||
// default parameters
|
||||
virtual std::vector<params_t> search_space() const;
|
||||
virtual params_t heuristics() const;
|
||||
// obtain execution jit
|
||||
std::pair<base*, triton::runtime::jit*> get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune);
|
||||
|
||||
public:
|
||||
// constructor
|
||||
@@ -73,6 +87,8 @@ public:
|
||||
virtual base* clone() const = 0;
|
||||
// enqueue
|
||||
void enqueue(driver::stream* stream, std::vector<driver::buffer*> args, autotuning_t autotune = PARTIAL_TUNING);
|
||||
// get profile
|
||||
launch_context_t get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune = PARTIAL_TUNING);
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
|
@@ -37,6 +37,10 @@ namespace dnn{
|
||||
|
||||
class batchnorm_forward: public base {
|
||||
private:
|
||||
// init
|
||||
void init_impl(driver::stream *, driver::cu_module *) { }
|
||||
void deinit_impl() { }
|
||||
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
@@ -69,6 +73,9 @@ private:
|
||||
|
||||
class batchnorm_backward: public base{
|
||||
private:
|
||||
// init
|
||||
void init_impl(driver::stream *, driver::cu_module *) { }
|
||||
void deinit_impl() { }
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
|
@@ -14,27 +14,34 @@ private:
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
// number of flops
|
||||
virtual size_t num_flops() const;
|
||||
size_t num_flops() const;
|
||||
// comparison for maps
|
||||
virtual bool operator<(const base& other) const;
|
||||
bool operator<(const base& other) const;
|
||||
// default parameters
|
||||
virtual std::vector<params_t> search_space() const;
|
||||
virtual params_t heuristics() const;
|
||||
|
||||
std::vector<params_t> search_space() const;
|
||||
params_t heuristics() const;
|
||||
// init
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||
// deinit
|
||||
void deinit_impl();
|
||||
public:
|
||||
// constructor
|
||||
dot(int32_t M, int32_t N, int32_t K);
|
||||
dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks);
|
||||
// triton-c source
|
||||
virtual void triton_c_src(std::ostream &os) const;
|
||||
void triton_c_src(std::ostream &os) const;
|
||||
// clone
|
||||
virtual base* clone() const;
|
||||
base* clone() const;
|
||||
|
||||
private:
|
||||
std::string ab_ty_;
|
||||
std::string c_ty_;
|
||||
int32_t M_;
|
||||
int32_t N_;
|
||||
int32_t S_;
|
||||
int32_t C_;
|
||||
int32_t K_;
|
||||
int32_t BS_;
|
||||
int32_t nlocks_;
|
||||
driver::buffer *locks_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -25,6 +25,7 @@ private:
|
||||
void build_a_deltas();
|
||||
void build_masks();
|
||||
void init_impl(driver::stream *, driver::cu_module *);
|
||||
void deinit_impl() { }
|
||||
|
||||
// enqueue
|
||||
std::array<size_t, 3> get_grid(size_t TM, size_t TN);
|
||||
|
@@ -10,6 +10,8 @@ class dot: public base {
|
||||
private:
|
||||
// initialize
|
||||
void init_impl(driver::stream *, driver::cu_module *);
|
||||
void deinit_impl() { }
|
||||
|
||||
// enqueue
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
|
@@ -126,9 +126,10 @@ public:
|
||||
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
// Built-in instruction
|
||||
value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = "");
|
||||
value *create_get_range_id(unsigned axis, const std::string &name = "");
|
||||
value *create_get_num_program(unsigned axis, const std::string &name = "");
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::string &name = "");
|
||||
|
@@ -514,6 +514,19 @@ private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class get_num_program_inst: public builtin_inst {
|
||||
private:
|
||||
get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class atomic_cas_inst: public builtin_inst {
|
||||
private:
|
||||
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
|
||||
@@ -523,6 +536,15 @@ public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class atomic_exch_inst: public builtin_inst {
|
||||
private:
|
||||
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_exch"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class atomic_add_inst: public builtin_inst {
|
||||
private:
|
||||
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
|
@@ -80,6 +80,15 @@ private:
|
||||
const constant* axis_;
|
||||
};
|
||||
|
||||
class get_num_program_expression: public builtin_expression{
|
||||
public:
|
||||
get_num_program_expression(node *axis): axis_((constant*)axis) { }
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
const constant* axis_;
|
||||
};
|
||||
|
||||
class atomic_cas_expression: public builtin_expression{
|
||||
public:
|
||||
atomic_cas_expression(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { }
|
||||
@@ -91,6 +100,17 @@ private:
|
||||
const node *val_;
|
||||
};
|
||||
|
||||
class atomic_exch_expression: public builtin_expression{
|
||||
public:
|
||||
atomic_exch_expression(node *ptr, node *val): ptr_(ptr), val_(val) { }
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
const node *ptr_;
|
||||
const node *val_;
|
||||
};
|
||||
|
||||
|
||||
class atomic_add_expression: public builtin_expression{
|
||||
public:
|
||||
atomic_add_expression(node *ptr, node *val): ptr_(ptr), val_(val) { }
|
||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||
%token IF ELSE FOR CONTINUE WHILE
|
||||
%token NEWAXIS ELLIPSIS AT
|
||||
%token GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||
%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -121,6 +121,7 @@ identifier
|
||||
/* Built-in */
|
||||
builtin_expression
|
||||
: GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
||||
| GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); }
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
|
||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||
@@ -130,6 +131,7 @@ builtin_expression
|
||||
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
||||
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
|
||||
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
||||
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
|
||||
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
|
||||
;
|
||||
|
||||
|
@@ -45,8 +45,9 @@ using triton::lang::return_void;
|
||||
"fp64" { return return_impl(FP64, yytext); }
|
||||
"..." { return return_impl(ELLIPSIS, yytext); }
|
||||
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
|
||||
"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
|
||||
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
|
||||
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
||||
"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }
|
||||
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
|
||||
"__sum" { return return_impl(REDUCE_SUM, yytext); }
|
||||
"sqrt" { return return_impl(SQRT, yytext); }
|
||||
|
@@ -19,7 +19,8 @@ void optimize_dce::run(ir::module &mod) {
|
||||
for(ir::basic_block *block: rpo)
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
|
||||
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)){
|
||||
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)
|
||||
|| dynamic_cast<ir::atomic_cas_inst*>(i) || dynamic_cast<ir::atomic_exch_inst*>(i) || dynamic_cast<ir::atomic_add_inst*>(i) ){
|
||||
work_list.push_back(i);
|
||||
marked.insert(i);
|
||||
}
|
||||
|
@@ -319,8 +319,12 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
|
||||
}
|
||||
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
|
||||
Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
|
||||
return (Instruction*)offset;
|
||||
Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
|
||||
return (Instruction*)result;
|
||||
}
|
||||
if(ir::get_num_program_inst* ii = dynamic_cast<ir::get_num_program_inst*>(inst)){
|
||||
Value *result = tgt_->get_num_blocks(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
|
||||
return (Instruction*)result;
|
||||
}
|
||||
if(ir::atomic_cas_inst* ii = dynamic_cast<ir::atomic_cas_inst*>(inst)){
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
@@ -331,6 +335,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
|
||||
Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii)));
|
||||
ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace()));
|
||||
tgt_->add_memfence(module, builder);
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder.SetInsertPoint(tid_0_bb);
|
||||
@@ -342,10 +347,29 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
builder.CreateStore(old, ptr);
|
||||
builder.CreateBr(tid_0_done_bb);
|
||||
builder.SetInsertPoint(tid_0_done_bb);
|
||||
tgt_->add_memfence(module, builder);
|
||||
tgt_->add_barrier(module, builder);
|
||||
Value *res = builder.CreateLoad(ptr);
|
||||
return (Instruction*)res;
|
||||
}
|
||||
if(ir::atomic_exch_inst* ii = dynamic_cast<ir::atomic_exch_inst*>(inst)){
|
||||
BasicBlock *current = builder.GetInsertBlock();
|
||||
Module *module = current->getModule();
|
||||
Value *rmw_ptr = value(ii->get_operand(0));
|
||||
Value *rmw_val = value(ii->get_operand(1));
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0));
|
||||
BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent());
|
||||
BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent());
|
||||
tgt_->add_memfence(module, builder);
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb);
|
||||
builder.SetInsertPoint(tid_0_bb);
|
||||
Value *res = builder.CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System);
|
||||
builder.CreateBr(tid_0_done_bb);
|
||||
builder.SetInsertPoint(tid_0_done_bb);
|
||||
return (Instruction*)res;
|
||||
}
|
||||
if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
|
||||
Value *ptr = value(ii->get_operand(0));
|
||||
Value *val = value(ii->get_operand(1));
|
||||
@@ -1136,17 +1160,17 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
Value *result_then = builder.CreateLoad(ptr);
|
||||
builder.CreateBr(mask_done_bb);
|
||||
builder.SetInsertPoint(mask_done_bb);
|
||||
Value *result = nullptr;
|
||||
Value *current_result = nullptr;
|
||||
if(false_values){
|
||||
result = builder.CreatePHI(result_then->getType(), 2);
|
||||
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
|
||||
current_result = builder.CreatePHI(result_then->getType(), 2);
|
||||
((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
|
||||
Value *result_false = false_values->get_value(idx);
|
||||
if(vector_size > 1)
|
||||
if(result_then->getType()->isVectorTy())
|
||||
result_false = builder.CreateVectorSplat(vector_size, result_false);
|
||||
((PHINode*)result)->addIncoming(result_false, current_bb);
|
||||
((PHINode*)current_result)->addIncoming(result_false, current_bb);
|
||||
}
|
||||
else
|
||||
result = result_then;
|
||||
current_result = result_then;
|
||||
|
||||
// std::string offset = "";
|
||||
// if(cst)
|
||||
@@ -1160,7 +1184,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true);
|
||||
// Value *result = builder.CreateCall(iasm, {mask, ptr});
|
||||
|
||||
packets[id] = result;
|
||||
packets[id] = current_result;
|
||||
}
|
||||
});
|
||||
// extract result element
|
||||
|
@@ -32,6 +32,11 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
|
||||
return result;
|
||||
}
|
||||
|
||||
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
|
||||
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::amdgcn_workgroup_id_x,
|
||||
@@ -43,6 +48,16 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
|
||||
return group_id;
|
||||
}
|
||||
|
||||
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::r600_read_ngroups_x,
|
||||
Intrinsic::r600_read_ngroups_y,
|
||||
Intrinsic::r600_read_ngroups_z
|
||||
};
|
||||
Value* get_num_group = Intrinsic::getDeclaration(module, ids[ax]);
|
||||
return builder.CreateCall(get_num_group, {});
|
||||
}
|
||||
|
||||
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::amdgcn_workitem_id_x,
|
||||
@@ -70,6 +85,12 @@ Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder)
|
||||
return builder.CreateCall(barrier, {});
|
||||
}
|
||||
|
||||
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
|
||||
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
|
||||
return builder.CreateCall(barrier, {});
|
||||
}
|
||||
|
||||
|
||||
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
|
||||
Value* group_id = get_block_id(module, builder, ax);
|
||||
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
|
||||
@@ -82,39 +103,39 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
|
||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
|
||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
|
||||
};
|
||||
bool z_order = true;
|
||||
if(z_order && ax < 2){
|
||||
static std::array<Intrinsic::ID, 3> n_cta_ids = {
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
|
||||
};
|
||||
Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
|
||||
Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
|
||||
Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
|
||||
Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
|
||||
// global block ID
|
||||
Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
|
||||
// helper for minimum
|
||||
auto Min = [&](Value *x, Value *y){
|
||||
return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
|
||||
};
|
||||
// super-tile size
|
||||
Value* sts = Min(builder.getInt32(16), n_cta_id_1);
|
||||
// number of CTAs per super-block
|
||||
Value *nscta = builder.CreateMul(n_cta_id_0, sts);
|
||||
Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
|
||||
Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
|
||||
if(ax == 0)
|
||||
return bid0;
|
||||
else
|
||||
return bid1;
|
||||
}
|
||||
else{
|
||||
// bool z_order = true;
|
||||
// if(z_order && ax < 2){
|
||||
// static std::array<Intrinsic::ID, 3> n_cta_ids = {
|
||||
// Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
|
||||
// Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
|
||||
// Intrinsic::nvvm_read_ptx_sreg_nctaid_z
|
||||
// };
|
||||
// Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
|
||||
// Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
|
||||
// Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
|
||||
// Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
|
||||
// // global block ID
|
||||
// Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
|
||||
// // helper for minimum
|
||||
// auto Min = [&](Value *x, Value *y){
|
||||
// return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
|
||||
// };
|
||||
// // super-tile size
|
||||
// Value* sts = Min(builder.getInt32(16), n_cta_id_1);
|
||||
// // number of CTAs per super-block
|
||||
// Value *nscta = builder.CreateMul(n_cta_id_0, sts);
|
||||
// Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
|
||||
// Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
|
||||
// if(ax == 0)
|
||||
// return bid0;
|
||||
// else
|
||||
// return bid1;
|
||||
// }
|
||||
// else{
|
||||
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
|
||||
Value* cta_id = builder.CreateCall(get_cta_id, {});
|
||||
return cta_id;
|
||||
}
|
||||
// }
|
||||
}
|
||||
|
||||
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
@@ -127,6 +148,16 @@ Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsi
|
||||
return builder.CreateCall(get_local_id, {});
|
||||
}
|
||||
|
||||
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
static std::array<Intrinsic::ID, 3> ids = {
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
|
||||
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
|
||||
};
|
||||
Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
|
||||
return builder.CreateCall(get_nctaid, {});
|
||||
}
|
||||
|
||||
// CPU
|
||||
|
||||
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
|
||||
@@ -138,6 +169,12 @@ Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
|
||||
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
|
||||
}
|
||||
|
||||
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
|
||||
// no barrier on CPU
|
||||
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
|
||||
}
|
||||
|
||||
|
||||
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
|
||||
const Function *fn = builder.GetInsertBlock()->getParent();
|
||||
size_t num_params = fn->getFunctionType()->getNumParams();
|
||||
@@ -149,6 +186,11 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
|
||||
return (Argument*)ids[ax];
|
||||
}
|
||||
|
||||
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
|
||||
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
|
||||
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
|
||||
return result;
|
||||
|
@@ -6,6 +6,8 @@
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
|
||||
void base::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
@@ -28,8 +30,7 @@ params_t base::heuristics() const {
|
||||
return *search_space().begin();
|
||||
}
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
namespace rt = triton::runtime;
|
||||
std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
|
||||
driver::context* ctx = stream->context();
|
||||
rt::jit* jit;
|
||||
@@ -67,16 +68,23 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, a
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
/* retrieved compiled template */
|
||||
else{
|
||||
else {
|
||||
jit = m_jit.at(this).get();
|
||||
}
|
||||
|
||||
/* get launch parameters */
|
||||
driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
rt::launch_information info = jit->get_launch_info(name_.c_str());
|
||||
/* launch */
|
||||
auto it = m_jit.find(this);
|
||||
it->first->enqueue_impl(stream, kernel, args, info);
|
||||
return {it->first, jit};
|
||||
}
|
||||
|
||||
void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
launch_context_t info = get_launch_context(stream, args, autotune);
|
||||
info.op->enqueue_impl(stream, info.kernel, args, info.info);
|
||||
}
|
||||
|
||||
launch_context_t base::get_launch_context(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
|
||||
std::pair<base*, rt::jit*> profile = get_profile_impl(stream, args, autotune);
|
||||
driver::kernel* kernel = profile.second->get_function(name_.c_str());
|
||||
rt::launch_information info = profile.second->get_launch_info(name_.c_str());
|
||||
return {profile.first, kernel, info};
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -13,26 +13,42 @@ bool dot::operator <(const base& other) const {
|
||||
auto *y = dynamic_cast<const dot*>(&other);
|
||||
if(!y)
|
||||
return true;
|
||||
return std::tie(M_, N_, K_)
|
||||
< std::tie(y->M_, y->N_, y->K_);
|
||||
return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_)
|
||||
< std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_);
|
||||
}
|
||||
|
||||
std::vector<params_t> dot::search_space() const {
|
||||
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
params_t dot::heuristics() const {
|
||||
|
||||
throw std::runtime_error("not implemented");
|
||||
}
|
||||
|
||||
base * dot::clone() const {
|
||||
return new dot(*this);
|
||||
}
|
||||
|
||||
dot::dot(int32_t M, int32_t N, int32_t K):
|
||||
base("bsdot"), M_(M), N_(N), K_(K) {
|
||||
ab_ty_ = "fp32";
|
||||
c_ty_ = "fp32";
|
||||
dot::dot(int32_t N, int32_t K, int32_t S, int32_t C,
|
||||
const std::string& ty, int32_t BS, int32_t nlocks):
|
||||
base("bsdot"),
|
||||
N_(N), K_(K), S_(S), C_(C),
|
||||
ab_ty_(ty), c_ty_(ty),
|
||||
BS_(BS), nlocks_(nlocks) {
|
||||
}
|
||||
|
||||
void dot::init_impl(driver::stream *stream, driver::cu_module *module) {
|
||||
// int32_t TM = info.globals["TM"];
|
||||
// size_t grid_0 = (N_ + TM - 1) / TM;
|
||||
// if(nlocks_){
|
||||
// locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4);
|
||||
// ((driver::cu_buffer*)locks_)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
|
||||
// }
|
||||
}
|
||||
|
||||
void dot::deinit_impl() {
|
||||
// if(locks_)
|
||||
// delete locks_;
|
||||
}
|
||||
|
||||
void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
@@ -41,64 +57,89 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *b = args[1];
|
||||
driver::buffer *c = args[2];
|
||||
driver::buffer *lut = args[3];
|
||||
int32_t lda = M_;
|
||||
int32_t ldc = M_;
|
||||
driver::buffer *locks = args[4];
|
||||
int32_t lda = N_;
|
||||
int32_t ldc = N_;
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, lda);
|
||||
kernel->setArg(4, ldc);
|
||||
kernel->setArg(5, lut);
|
||||
kernel->setArg(5, N_);
|
||||
kernel->setArg(6, lut);
|
||||
kernel->setArg(7, locks);
|
||||
kernel->setArg(8, nlocks_);
|
||||
int32_t TM = info.globals["TM"];
|
||||
int32_t TN = info.globals["TN"];
|
||||
size_t grid_0 = (M_ + TM - 1) / TM;
|
||||
size_t grid_1 = (N_ + TN - 1) / TN;
|
||||
size_t grid_0 = (N_ + TM - 1) / TM;
|
||||
size_t grid_1 = S_;
|
||||
std::cout << N_ << " " << grid_0 << std::endl;
|
||||
if(nlocks_){
|
||||
// locks_ = triton::driver::buffer::create(stream->context(), grid_0 * nlocks_ * 2 * 4);
|
||||
((driver::cu_buffer*)locks)->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
|
||||
}
|
||||
stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
|
||||
stream->synchronize();
|
||||
}
|
||||
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string result =
|
||||
|
||||
R"(
|
||||
const tunable int32 TM = {64, 128};
|
||||
const tunable int32 TN = {32};
|
||||
const tunable int32 TK = {32};
|
||||
const tunable int32 TM = {64};
|
||||
const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
|
||||
const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
|
||||
|
||||
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
|
||||
restrict read_only align(16) )" + ab_ty_ + R"( *B,
|
||||
fp32* C,
|
||||
int32 lda, int32 ldc,
|
||||
int32* lut_base){
|
||||
)" + c_ty_ + R"(* C,
|
||||
int32 lda, int32 ldc, int32 N,
|
||||
int32* lut, int32* locks, int32 nlocks){
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
fp32 c[TM, TN] = 0;
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 acc[TM, TN] = 0;
|
||||
int32 rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int32 ryb[TN] = 0 ... TN;
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
int32 offa[TM, TK] = rxa[:, newaxis] + rka[newaxis, :]*lda;
|
||||
int32 offb[TK, TN] = ryb[newaxis, :] + rkb[:, newaxis]*TK;
|
||||
int32 *header = lut_base + ridy * 4;
|
||||
int32 *header = lut + ridy * 4;
|
||||
int32 offset = *(header + 0);
|
||||
int32 K = *(header + 1);
|
||||
int32 h2 = *(header + 2);
|
||||
int32 h3 = *(header + 3);
|
||||
int32 *lut = lut_base + offset*2;
|
||||
int32 column = *(header + 2);
|
||||
int32 lockid = *(header + 3);
|
||||
int32 *plut = lut + offset * 2;
|
||||
for(int32 k = K; k > 0; k = k - 1){
|
||||
int32 ak = *(lut + 0);
|
||||
int32 bk = *(lut + 1);
|
||||
fp32* pa[TM, TK] = A + offa + ak * TK * lda;
|
||||
fp32* pb[TK, TN] = B + offb + bk * TK * TN;
|
||||
fp32 a[TM, TK] = *pa;
|
||||
fp32 b[TK, TN] = *pb;;
|
||||
c = dot(a, b, c);
|
||||
lut = lut + 2;
|
||||
int32 ak = *(plut + 0);
|
||||
int32 bk = *(plut + 1);
|
||||
)" + ab_ty_ + R"(* pa[TM, TK] = A + offa + ak * TK * lda;
|
||||
)" + ab_ty_ + R"(* pb[TK, TN] = B + offb + bk * TK * TN;
|
||||
)" + ab_ty_ + R"( a[TM, TK] = *pa;
|
||||
)" + ab_ty_ + R"( b[TK, TN] = *pb;
|
||||
acc = dot(a, b, acc);
|
||||
plut = plut + 2;
|
||||
}
|
||||
int32 rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int32 ryc[TN] = ridy * TN + (0 ... TN);
|
||||
fp32* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
|
||||
*pc = c;
|
||||
int32 ryc[TN] = column * TN + (0 ... TN);
|
||||
)" + c_ty_ + R"(" c[TM, TN] = acc;
|
||||
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
|
||||
int1 checkc[TM, TN] = (rxc < N)[:, newaxis];
|
||||
if(lockid == 0){
|
||||
@checkc *pc = c;
|
||||
}
|
||||
else{
|
||||
int32 *plock = locks + ridx*nlocks + lockid - 1;
|
||||
int32 *pcount = plock + get_num_program(0)*nlocks;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int32 count = *pcount;
|
||||
if(count == 0) {
|
||||
@checkc *pc = c;
|
||||
}
|
||||
else {
|
||||
@checkc *pc = c + *pc;
|
||||
}
|
||||
*pcount = 1;
|
||||
__atomic_exch(plock, 0);
|
||||
}
|
||||
})";
|
||||
|
||||
os << result;
|
||||
|
@@ -294,10 +294,18 @@ value *builder::create_get_range_id(unsigned axis, const std::string &name) {
|
||||
return insert(get_range_id_inst::create(ctx_, axis, name));
|
||||
}
|
||||
|
||||
value *builder::create_get_num_program(unsigned axis, const std::string &name) {
|
||||
return insert(get_num_program_inst::create(ctx_, axis, name));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){
|
||||
return insert(atomic_cas_inst::create(ptr, cmp, val, name));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_exch(value *ptr, value *val, const std::string &name){
|
||||
return insert(atomic_exch_inst::create(ptr, val, name));
|
||||
}
|
||||
|
||||
value *builder::create_atomic_add(value *ptr, value *val, const std::string &name){
|
||||
return insert(atomic_add_inst::create(ptr, val, name));
|
||||
}
|
||||
|
@@ -636,6 +636,17 @@ instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::s
|
||||
return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// get_num_program
|
||||
get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
instruction* get_num_program_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new get_num_program_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
|
||||
// atomic cas
|
||||
|
||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||
@@ -649,6 +660,18 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
|
||||
return new atomic_cas_inst(ptr, cmp, val, name, next);
|
||||
}
|
||||
|
||||
// atomic exch
|
||||
|
||||
atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string &name, instruction *next) {
|
||||
return new atomic_exch_inst(ptr, val, name, next);
|
||||
}
|
||||
|
||||
// atomic add
|
||||
|
||||
atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
|
||||
|
@@ -120,6 +120,11 @@ ir::value* get_range_id_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_get_range_id(axis_->value());
|
||||
}
|
||||
|
||||
// get_num_program
|
||||
ir::value* get_num_program_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_get_num_program(axis_->value());
|
||||
}
|
||||
|
||||
// atomic cas
|
||||
ir::value* atomic_cas_expression::codegen(ir::module *mod) const {
|
||||
ir::value *ptr = ptr_->codegen(mod);
|
||||
@@ -128,6 +133,13 @@ ir::value* atomic_cas_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
|
||||
}
|
||||
|
||||
// atomic exch
|
||||
ir::value* atomic_exch_expression::codegen(ir::module *mod) const {
|
||||
ir::value *ptr = ptr_->codegen(mod);
|
||||
ir::value *val = val_->codegen(mod);
|
||||
return mod->get_builder().create_atomic_exch(ptr, val);
|
||||
}
|
||||
|
||||
// atomic add
|
||||
ir::value* atomic_add_expression::codegen(ir::module *mod) const {
|
||||
ir::value *ptr = ptr_->codegen(mod);
|
||||
|
Reference in New Issue
Block a user