[codegen][selection] more flexible instruction selection for reduce_inst
This commit is contained in:
@@ -32,6 +32,7 @@ typedef std::vector<llvm::Value*> indices_t;
|
||||
struct distributed_axis {
|
||||
size_t contiguous;
|
||||
std::vector<llvm::Value*> values;
|
||||
llvm::Value* thread_id;
|
||||
};
|
||||
|
||||
class tile {
|
||||
|
@@ -134,7 +134,7 @@ public:
|
||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::string &name = "");
|
||||
value *create_sqrt(value *A, const std::string &name = "");
|
||||
value *create_reduce(value *A, const std::string &name = "");
|
||||
value *create_reduce(value *A, unsigned axis, const std::string &name = "");
|
||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||
// Intrinsics
|
||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||
|
@@ -605,11 +605,18 @@ public:
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
private:
|
||||
reduce_inst(value* arg, const std::string& name, instruction* next);
|
||||
static type* get_type(value *arg, unsigned axis);
|
||||
|
||||
private:
|
||||
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "reduce"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class select_inst: public builtin_inst {
|
||||
|
@@ -134,6 +134,16 @@ private:
|
||||
const expression *C_;
|
||||
};
|
||||
|
||||
class reshape_expression: public builtin_expression{
|
||||
public:
|
||||
reshape_expression(node *arg, node *shapes): arg_(arg), shapes_((list<expression*>*)shapes) { }
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
const node *arg_;
|
||||
const list<expression*>* shapes_;
|
||||
};
|
||||
|
||||
class max_expression: public builtin_expression{
|
||||
public:
|
||||
max_expression(node* x, node* y)
|
||||
@@ -188,11 +198,12 @@ private:
|
||||
|
||||
class reduce_expression: public builtin_expression{
|
||||
public:
|
||||
reduce_expression(node *arg): arg_(arg) {}
|
||||
reduce_expression(node *arg, node *axis): arg_(arg), axis_((constant*)axis) {}
|
||||
ir::value* codegen(ir::module *mod) const;
|
||||
|
||||
private:
|
||||
node* arg_;
|
||||
constant* axis_;
|
||||
};
|
||||
|
||||
class indexing_expression: public postfix_expression{
|
||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||
%token IF ELSE FOR CONTINUE WHILE
|
||||
%token NEWAXIS ELLIPSIS AT
|
||||
%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST
|
||||
%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -126,13 +126,14 @@ builtin_expression
|
||||
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
|
||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
||||
| REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}
|
||||
| REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);}
|
||||
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
|
||||
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
||||
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
|
||||
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
||||
| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
|
||||
| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
|
||||
| RESHAPE '(' expression ',' primary_expression_list ')' { $$ = new reshape_expression($3, $5); }
|
||||
;
|
||||
|
||||
/* Primary */
|
||||
|
@@ -30,18 +30,18 @@ using triton::lang::return_void;
|
||||
"for" { return return_impl(FOR, yytext); }
|
||||
"while" { return return_impl(WHILE, yytext); }
|
||||
"void" { return return_impl(VOID, yytext); }
|
||||
"uchar" { return return_impl(UCHAR, yytext); }
|
||||
"ushort" { return return_impl(USHORT, yytext); }
|
||||
"uint" { return return_impl(UINT, yytext); }
|
||||
"ulong" { return return_impl(ULONG, yytext); }
|
||||
"bool" { return return_impl(BOOL, yytext); }
|
||||
"char" { return return_impl(CHAR, yytext); }
|
||||
"short" { return return_impl(SHORT, yytext); }
|
||||
"int" { return return_impl(INT, yytext); }
|
||||
"long" { return return_impl(LONG, yytext); }
|
||||
"half" { return return_impl(HALF, yytext); }
|
||||
"float" { return return_impl(FLOAT, yytext); }
|
||||
"double" { return return_impl(DOUBLE, yytext); }
|
||||
"uchar" { return return_impl(UINT8, yytext); }
|
||||
"ushort" { return return_impl(UINT16, yytext); }
|
||||
"uint" { return return_impl(UINT32, yytext); }
|
||||
"ulong" { return return_impl(UINT64, yytext); }
|
||||
"bool" { return return_impl(INT1, yytext); }
|
||||
"char" { return return_impl(INT8, yytext); }
|
||||
"short" { return return_impl(INT16, yytext); }
|
||||
"int" { return return_impl(INT32, yytext); }
|
||||
"long" { return return_impl(INT64, yytext); }
|
||||
"half" { return return_impl(FP16, yytext); }
|
||||
"float" { return return_impl(FP32, yytext); }
|
||||
"double" { return return_impl(FP64, yytext); }
|
||||
"..." { return return_impl(ELLIPSIS, yytext); }
|
||||
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
|
||||
"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
|
||||
@@ -49,6 +49,7 @@ using triton::lang::return_void;
|
||||
"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }
|
||||
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
|
||||
"__sum" { return return_impl(REDUCE_SUM, yytext); }
|
||||
"__reshape" { return return_impl(RESHAPE, yytext); }
|
||||
"sqrt" { return return_impl(SQRT, yytext); }
|
||||
"dot" { return return_impl(DOT, yytext); }
|
||||
"max" { return return_impl(MAX, yytext); }
|
||||
|
@@ -80,9 +80,10 @@ indices_t distributed_tile::get_ordered_indices(unsigned id) {
|
||||
|
||||
|
||||
void distributed_tile::for_each(std::function<void (indices_t)> fn) {
|
||||
for(unsigned i = 0; i < ordered_indices_.size(); i++)
|
||||
for(unsigned i = 0; i < ordered_indices_.size(); i++){
|
||||
if(i % vector_size_ == 0)
|
||||
fn(ordered_indices_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/* Shared Tile */
|
||||
@@ -498,15 +499,15 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
Value *warp_size_k = builder.getInt32(warp_size[k]);
|
||||
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
||||
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
||||
thread_id = builder.CreateMul(thread_id, contiguous_k);
|
||||
Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k);
|
||||
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
||||
unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
|
||||
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
||||
}
|
||||
}
|
||||
else {
|
||||
@@ -671,7 +672,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
shapes[0] += pad;
|
||||
Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
|
||||
// create shared tile
|
||||
if(buffer_info_->is_shared(v)){
|
||||
if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v)){
|
||||
// shared copy
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
|
||||
// phi-node (double-buffering)
|
||||
@@ -825,88 +826,72 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
|
||||
std::map<indices_t, Value*> partial;
|
||||
distributed_tile* op = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
size_t axis = 0;
|
||||
unsigned num_warps = params_->get_num_threads() / 32;
|
||||
std::vector<unsigned> shapes = op->get_shapes();
|
||||
shapes.erase(shapes.begin() + axis);
|
||||
if(shapes.empty())
|
||||
shapes.push_back(1);
|
||||
ir::value *op = ins->get_operand(0);
|
||||
distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
|
||||
unsigned axis = x->get_axis();
|
||||
|
||||
// reduce within thread
|
||||
op->for_each([&](indices_t idx){
|
||||
op_tile->for_each([&](indices_t idx) {
|
||||
indices_t pidx = idx;
|
||||
pidx.erase(pidx.begin() + axis);
|
||||
if(pidx.empty())
|
||||
pidx.push_back(builder.getInt32(0));
|
||||
Value *current = op->get_value(idx);
|
||||
Value *current = op_tile->get_value(idx);
|
||||
// current partial result is not initialized -- create
|
||||
if(partial.find(pidx) == partial.end())
|
||||
partial[pidx] = current;
|
||||
// current partial result is initialized -- accumulate
|
||||
else
|
||||
partial[pidx] = builder.CreateFAdd(partial[pidx], current);
|
||||
});
|
||||
|
||||
// reduce within warp
|
||||
Value *shfl = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_shfl_sync_bfly_f32);
|
||||
for (int i = 16; i > 0; i >>= 1)
|
||||
for(auto& x: partial)
|
||||
{
|
||||
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), x.second,
|
||||
builder.getInt32(i),
|
||||
builder.getInt32(0x1f)});
|
||||
x.second = builder.CreateFAdd(x.second, rhs);
|
||||
}
|
||||
|
||||
// reduce within block
|
||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||
BasicBlock *partial_reduce_do = BasicBlock::Create(ctx, "partial_reduce_do", fn);
|
||||
BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
|
||||
Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
|
||||
Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
|
||||
builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
|
||||
partial_reduce_do, partial_reduce_done);
|
||||
builder.SetInsertPoint(partial_reduce_do);
|
||||
// reduce within blocks
|
||||
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||
Type *ptr_ty = PointerType::get(builder.getFloatTy(), addr_space);
|
||||
Value *sh_mem_ptr = builder.CreateBitCast(sh_mem_ptr_, ptr_ty);
|
||||
for(auto& x: partial){
|
||||
Value *offset = shared_tile::shared_offset(builder, shapes, x.first);
|
||||
offset = builder.CreateAdd(offset, builder.CreateMul(warp_id, builder.getInt32(shapes[0])));
|
||||
Value *write_ptr = builder.CreateGEP(sh_mem_ptr, offset);
|
||||
builder.CreateStore(x.second, write_ptr);
|
||||
}
|
||||
builder.CreateBr(partial_reduce_done);
|
||||
builder.SetInsertPoint(partial_reduce_done);
|
||||
Type *res_ty = builder.getFloatTy();
|
||||
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||
unsigned depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
|
||||
for(auto& x: partial) {
|
||||
// current element being computed
|
||||
Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
|
||||
Value *&result = x.second;
|
||||
indices_t write_idx = x.first;
|
||||
write_idx.insert(write_idx.begin() + axis, lane);
|
||||
// shared memory write pointer
|
||||
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
|
||||
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
|
||||
// initialize shared memory
|
||||
builder.CreateStore(result, write_ptr);
|
||||
// build result
|
||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||
// current indices
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
current[axis] = builder.getInt32(i);
|
||||
// shared memory offset
|
||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
|
||||
Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
|
||||
read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
|
||||
// shared memory read pointer
|
||||
Value *read_ptr = builder.CreateGEP(write_ptr, read_offset);
|
||||
tgt_->add_barrier(module, builder);
|
||||
Value *next = builder.CreateLoad(read_ptr);
|
||||
// accumulate
|
||||
result = builder.CreateFAdd(result, next);
|
||||
// write back
|
||||
builder.CreateStore(result, write_ptr);
|
||||
}
|
||||
|
||||
// Final reduction with the first warp
|
||||
tgt_->add_barrier(module, builder);
|
||||
BasicBlock *final_reduce_do = BasicBlock::Create(ctx, "final_reduce_do", fn);
|
||||
BasicBlock *final_reduce_done = BasicBlock::Create(ctx, "final_reduce_done", fn);
|
||||
builder.CreateCondBr(builder.CreateICmpEQ(warp_id, builder.getInt32(0)),
|
||||
final_reduce_do, final_reduce_done);
|
||||
builder.SetInsertPoint(final_reduce_do);
|
||||
Value *read_ptr = builder.CreateGEP(sh_mem_ptr, tid);
|
||||
BasicBlock *read_shmem_do = BasicBlock::Create(ctx, "read_shmem_do", fn);
|
||||
BasicBlock *read_shmem_done = BasicBlock::Create(ctx, "read_shmem_done", fn);
|
||||
builder.CreateCondBr(builder.CreateICmpULT(id_in_warp, builder.getInt32(num_warps)),
|
||||
read_shmem_do, read_shmem_done);
|
||||
builder.SetInsertPoint(read_shmem_do);
|
||||
Value *loaded= builder.CreateLoad(read_ptr);
|
||||
builder.CreateBr(read_shmem_done);
|
||||
builder.SetInsertPoint(read_shmem_done);
|
||||
Value *result = builder.CreatePHI(loaded->getType(), 2);
|
||||
((PHINode*)result)->addIncoming(ConstantFP::get(loaded->getType(), (double)0), final_reduce_do);
|
||||
((PHINode*)result)->addIncoming(loaded, read_shmem_do);
|
||||
for (int i = params_->get_num_threads() / 64; i > 0; i >>= 1){
|
||||
Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), result,
|
||||
builder.getInt32(i), builder.getInt32(0x1f)});
|
||||
result = builder.CreateFAdd(result, rhs);
|
||||
// result is on the first lane of shared memory
|
||||
indices_t final = write_idx;
|
||||
final[axis] = builder.getInt32(0);
|
||||
Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
|
||||
Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
|
||||
tgt_->add_barrier(module, builder);
|
||||
result = builder.CreateLoad(read_ptr);
|
||||
if(tmap_.find(ins) == tmap_.end())
|
||||
vmap_[ins] = result;
|
||||
else{
|
||||
distributed_tile *ti = (distributed_tile*)tmap_[ins];
|
||||
ti->set_value(x.first, result);
|
||||
}
|
||||
}
|
||||
builder.CreateStore(result, read_ptr);
|
||||
builder.CreateBr(final_reduce_done);
|
||||
builder.SetInsertPoint(final_reduce_done);
|
||||
tgt_->add_barrier(module, builder);
|
||||
vmap_[ins] = builder.CreateLoad(sh_mem_ptr);
|
||||
return;
|
||||
}
|
||||
tile *ti = tmap_[ins];
|
||||
|
@@ -43,15 +43,16 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
|
||||
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
if(dynamic_cast<ir::reduce_inst*>(x)){
|
||||
size_t shape = 1;
|
||||
if(x->get_type()->is_tile_ty()){
|
||||
auto shapes = x->get_type()->get_tile_shapes();
|
||||
for(auto x: shapes)
|
||||
shape *= x->get_value();
|
||||
}
|
||||
size_t n_warps = params_->get_num_threads() / 32;
|
||||
return shape * num_bytes * n_warps;
|
||||
if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
|
||||
size_t axis = red->get_axis();
|
||||
ir::value *op = red->get_operand(0);
|
||||
auto shapes = op->get_type()->get_tile_shapes();
|
||||
shapes.erase(shapes.begin() + axis);
|
||||
size_t num_elements = 1;
|
||||
for(auto x: shapes)
|
||||
num_elements *= x->get_value();
|
||||
size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
|
||||
return num_elements * num_bytes * depth;
|
||||
}
|
||||
unsigned pad = is_ld_padded(x);
|
||||
if(pad > 0){
|
||||
|
@@ -58,8 +58,19 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
||||
else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
|
||||
return;
|
||||
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v))
|
||||
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
|
||||
unsigned axis = reduce->get_axis();
|
||||
ir::value *arg = reduce->get_operand(0);
|
||||
auto in_shapes = arg->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
for(unsigned i = 0; i < in_shapes.size(); i++){
|
||||
if(i == axis)
|
||||
continue;
|
||||
// std::cout << arg->get_name() << " " << v->get_name() << std::endl;
|
||||
add_constraint({reduce, current++}, {arg, i});
|
||||
}
|
||||
return;
|
||||
}
|
||||
else
|
||||
shapes = v->get_type()->get_tile_shapes();
|
||||
// Reshape
|
||||
@@ -74,8 +85,10 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
static_params_.insert({{v, i}, 1});
|
||||
else if(!is_skewed && is_same)
|
||||
add_constraint({v, i}, {op, current++});
|
||||
else
|
||||
else{
|
||||
is_skewed = true;
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Splat
|
||||
@@ -137,6 +150,7 @@ tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
|
||||
}
|
||||
|
||||
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
// std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
|
||||
groups_[x.first].insert({x.second, group_id});
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
@@ -190,6 +204,7 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
|
||||
}
|
||||
|
||||
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
|
||||
// std::cout << "group? " << value->get_name() << " " << ax << std::endl;
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
}
|
||||
|
@@ -71,7 +71,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
|
||||
void batchnorm_forward::triton_c_src(std::ostream &os) const {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int TM = {32, 64, 128};
|
||||
const tunable int TM = {128};
|
||||
|
||||
void batchnorm_forward(float *Y, float *M, float *V,
|
||||
restrict read_only float *X,
|
||||
@@ -94,7 +94,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
|
||||
px = px + TM;
|
||||
}
|
||||
float *pm = M + c;
|
||||
float m = __sum(mean) * rcpDHWN;
|
||||
float m = __sum(mean, 0) * rcpDHWN;
|
||||
*pm = m;
|
||||
|
||||
float var[TM] = 0;
|
||||
@@ -105,7 +105,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
|
||||
var = var + x*x;
|
||||
px = px + TM;
|
||||
}
|
||||
float v = __sum(var) * rcpDHWN;
|
||||
float v = __sum(var, 0) * rcpDHWN;
|
||||
float *pv = V + c;
|
||||
*pv = v;
|
||||
float rstdg = 1 / sqrt(v + eps) * g;
|
||||
@@ -167,7 +167,7 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
|
||||
void batchnorm_backward::triton_c_src(std::ostream &os) const {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int TM = {32, 64, 128};
|
||||
const tunable int TM = {128};
|
||||
|
||||
void batchnorm_backward(float *DX, float *DG, float *DB,
|
||||
restrict read_only float *DY,
|
||||
@@ -199,8 +199,8 @@ void batchnorm_backward(float *DX, float *DG, float *DB,
|
||||
px = px + TM;
|
||||
pdy = pdy + TM;
|
||||
}
|
||||
float sdg = __sum(dg);
|
||||
float sdb = __sum(db);
|
||||
float sdg = __sum(dg, 0);
|
||||
float sdb = __sum(db, 0);
|
||||
float *pdg = DG + c;
|
||||
float *pdb = DB + c;
|
||||
*pdg = sdg;
|
||||
|
@@ -322,8 +322,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
|
||||
return insert(sqrt_inst::create(A, name));
|
||||
}
|
||||
|
||||
value *builder::create_reduce(value *A, const std::string &name) {
|
||||
return insert(reduce_inst::create(A, name));
|
||||
value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
|
||||
return insert(reduce_inst::create(A, axis, name));
|
||||
}
|
||||
|
||||
value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
|
||||
|
@@ -597,13 +597,24 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
|
||||
//===----------------------------------------------------------------------===//
|
||||
// reduce instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
reduce_inst::reduce_inst(value *arg, const std::string &name, instruction *next)
|
||||
: builtin_inst(arg->get_type()->get_scalar_ty(), 1, 1, name, next) {
|
||||
type* reduce_inst::get_type(value *arg, unsigned axis) {
|
||||
ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
|
||||
shapes.erase(shapes.begin() + axis);
|
||||
type *scalar_ty = arg->get_type()->get_scalar_ty();
|
||||
if(shapes.size() == 0)
|
||||
return scalar_ty;
|
||||
else
|
||||
return tile_type::get(scalar_ty, shapes);
|
||||
}
|
||||
|
||||
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_type(arg, axis), 1, 1, name, next),
|
||||
axis_(axis){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* reduce_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new reduce_inst(arg, name, next);
|
||||
instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new reduce_inst(arg, axis, name, next);
|
||||
}
|
||||
|
||||
|
||||
|
@@ -161,6 +161,21 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_dot(A, B, C);
|
||||
}
|
||||
|
||||
// reshape
|
||||
ir::value* reshape_expression::codegen(ir::module *mod) const {
|
||||
// arg
|
||||
ir::value *arg = arg_->codegen(mod);
|
||||
// shapes
|
||||
ir::type::tile_shapes_t shapes;
|
||||
for(expression *expr: shapes_->values()){
|
||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||
assert(shape);
|
||||
shapes.push_back(shape);
|
||||
}
|
||||
// return
|
||||
return mod->get_builder().create_reshape(arg, shapes);
|
||||
}
|
||||
|
||||
// min
|
||||
ir::value* min_expression::codegen(ir::module *mod) const {
|
||||
ir::value* cmp = binary_expression(LT, (node*)x_, (node*)y_).codegen(mod);
|
||||
@@ -198,7 +213,7 @@ ir::value* sqrt_expression::codegen(ir::module *mod) const {
|
||||
|
||||
// reduce
|
||||
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_reduce(arg_->codegen(mod));
|
||||
return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());
|
||||
}
|
||||
|
||||
/* Postfix expression */
|
||||
|
@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
|
||||
size_t D = ranges.size();
|
||||
std::vector<size_t> values(D, 0);
|
||||
// thread pools
|
||||
ThreadPool pool(nthreads);
|
||||
// ThreadPool pool(nthreads);
|
||||
// Start with innermost loop
|
||||
size_t i = D - 1;
|
||||
while(true){
|
||||
// Execute function
|
||||
pool.enqueue(f,values);
|
||||
// f(values);
|
||||
// pool.enqueue(f,values);
|
||||
f(values);
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
return;
|
||||
|
Reference in New Issue
Block a user