[codegen][selection] more flexible instruction selection for reduce_inst
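__sum now takes an explicit reduction axis: __sum(x, axis) reduces a tile along that axis and erases it from the result shape (a rank-0 result degenerates to a scalar). The lowering no longer hardcodes axis 0, and replaces the warp-shuffle + first-warp scheme with a shared-memory tree reduction whose depth is the "mts" tuning parameter along the reduced axis; shared-memory allocation is sized accordingly. Also adds a __reshape(x, shapes) builtin, updates the batchnorm kernels to the new __sum signature, fixes the lexer to return the type tokens the parser actually declares (UINT8 ... FP64 instead of the undeclared UCHAR ... DOUBLE), and makes the benchmarking loop nest run serially.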

Author: Philippe Tillet
Date:   2019-08-04 16:34:36 -07:00
Commit: d869d9a924 (parent 6be532c6a2)
14 changed files with 167 additions and 119 deletions

View File

@@ -32,6 +32,7 @@ typedef std::vector<llvm::Value*> indices_t;
 struct distributed_axis {
   size_t contiguous;
   std::vector<llvm::Value*> values;
+  llvm::Value* thread_id;
 };
 class tile {

View File

@@ -134,7 +134,7 @@ public:
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::string &name = "");
   value *create_sqrt(value *A, const std::string &name = "");
-  value *create_reduce(value *A, const std::string &name = "");
+  value *create_reduce(value *A, unsigned axis, const std::string &name = "");
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
   // Intrinsics
   value *create_copy_to_shared(value *arg, const std::string &name = "");

View File

@@ -605,11 +605,18 @@ public:
 class reduce_inst: public builtin_inst {
 private:
-  reduce_inst(value* arg, const std::string& name, instruction* next);
+  static type* get_type(value *arg, unsigned axis);
+private:
+  reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
   std::string repr_impl() const { return "reduce"; }
 public:
-  static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
+  unsigned get_axis() const { return axis_; }
+private:
+  unsigned axis_;
 };
 class select_inst: public builtin_inst {

View File

@@ -134,6 +134,16 @@ private:
   const expression *C_;
 };
+class reshape_expression: public builtin_expression{
+public:
+  reshape_expression(node *arg, node *shapes): arg_(arg), shapes_((list<expression*>*)shapes) { }
+  ir::value* codegen(ir::module *) const;
+private:
+  const node *arg_;
+  const list<expression*>* shapes_;
+};
 class max_expression: public builtin_expression{
 public:
   max_expression(node* x, node* y)
@@ -188,11 +198,12 @@ private:
 class reduce_expression: public builtin_expression{
 public:
-  reduce_expression(node *arg): arg_(arg) {}
+  reduce_expression(node *arg, node *axis): arg_(arg), axis_((constant*)axis) {}
   ir::value* codegen(ir::module *mod) const;
 private:
   node* arg_;
+  constant* axis_;
 };
 class indexing_expression: public postfix_expression{

View File

@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
 %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
 %token IF ELSE FOR CONTINUE WHILE
 %token NEWAXIS ELLIPSIS AT
-%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST
+%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
 %start translation_unit
 %%
@@ -126,13 +126,14 @@ builtin_expression
   | SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
   | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
   | TRANS '(' expression ')' { $$ = new trans_expression($3); }
-  | REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}
+  | REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);}
   | MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
   | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
   | SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
   | ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
   | ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
   | ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
+  | RESHAPE '(' expression ',' primary_expression_list ')' { $$ = new reshape_expression($3, $5); }
   ;
 /* Primary */

View File

@@ -30,18 +30,18 @@ using triton::lang::return_void;
"for" { return return_impl(FOR, yytext); }
"while" { return return_impl(WHILE, yytext); }
"void" { return return_impl(VOID, yytext); }
"uchar" { return return_impl(UCHAR, yytext); }
"ushort" { return return_impl(USHORT, yytext); }
"uint" { return return_impl(UINT, yytext); }
"ulong" { return return_impl(ULONG, yytext); }
"bool" { return return_impl(BOOL, yytext); }
"char" { return return_impl(CHAR, yytext); }
"short" { return return_impl(SHORT, yytext); }
"int" { return return_impl(INT, yytext); }
"long" { return return_impl(LONG, yytext); }
"half" { return return_impl(HALF, yytext); }
"float" { return return_impl(FLOAT, yytext); }
"double" { return return_impl(DOUBLE, yytext); }
"uchar" { return return_impl(UINT8, yytext); }
"ushort" { return return_impl(UINT16, yytext); }
"uint" { return return_impl(UINT32, yytext); }
"ulong" { return return_impl(UINT64, yytext); }
"bool" { return return_impl(INT1, yytext); }
"char" { return return_impl(INT8, yytext); }
"short" { return return_impl(INT16, yytext); }
"int" { return return_impl(INT32, yytext); }
"long" { return return_impl(INT64, yytext); }
"half" { return return_impl(FP16, yytext); }
"float" { return return_impl(FP32, yytext); }
"double" { return return_impl(FP64, yytext); }
"..." { return return_impl(ELLIPSIS, yytext); }
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
@@ -49,6 +49,7 @@ using triton::lang::return_void;
"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
"__sum" { return return_impl(REDUCE_SUM, yytext); }
"__reshape" { return return_impl(RESHAPE, yytext); }
"sqrt" { return return_impl(SQRT, yytext); }
"dot" { return return_impl(DOT, yytext); }
"max" { return return_impl(MAX, yytext); }

View File

@@ -80,9 +80,10 @@ indices_t distributed_tile::get_ordered_indices(unsigned id) {
 void distributed_tile::for_each(std::function<void (indices_t)> fn) {
-  for(unsigned i = 0; i < ordered_indices_.size(); i++)
+  for(unsigned i = 0; i < ordered_indices_.size(); i++){
     if(i % vector_size_ == 0)
       fn(ordered_indices_[i]);
+  }
 }
 /* Shared Tile */
@@ -498,15 +499,15 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
       Value *warp_size_k = builder.getInt32(warp_size[k]);
       Value *contiguous_k = builder.getInt32(contiguous[k]);
       Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
-      thread_id = builder.CreateMul(thread_id, contiguous_k);
+      Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k);
       unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
       unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
       std::vector<Value*> idx_list(per_thread);
       for(unsigned n = 0 ; n < per_thread; n++){
         unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
-        idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
+        idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
       }
-      axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
+      axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
     }
   }
   else {
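For readers following the index math: below is a standalone CPU sketch, not part of this commit, of what init_axes computes per thread (assumed simplification: one axis, plain unsigned integers; names mirror the variables above). It also shows why the unscaled thread_id is now kept in distributed_axis: it is the lane index the reduction lowering needs, while only the scaled id enters the per-thread indices.

```cpp
#include <cassert>
#include <vector>

// `contiguous` consecutive elements stay on one thread, then the pattern
// strides by `per_block`; scaled_thread_id offsets each thread's indices.
std::vector<unsigned> thread_indices(unsigned thread_id, unsigned contiguous,
                                     unsigned warp_size, unsigned n_warps,
                                     unsigned shape) {
  unsigned per_block = contiguous * warp_size * n_warps;
  unsigned per_thread = contiguous * shape / per_block;
  unsigned scaled_thread_id = thread_id * contiguous;
  std::vector<unsigned> idx(per_thread);
  for (unsigned n = 0; n < per_thread; ++n)
    idx[n] = scaled_thread_id + n / contiguous * per_block + n % contiguous;
  return idx;
}

int main() {
  // contiguous=2, warp_size=4, n_warps=2, shape=32 -> thread 0 owns {0,1,16,17}
  assert(thread_indices(0, 2, 4, 2, 32) == std::vector<unsigned>({0, 1, 16, 17}));
}
```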
@@ -671,7 +672,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
     shapes[0] += pad;
   Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
   // create shared tile
-  if(buffer_info_->is_shared(v)){
+  if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v)){
     // shared copy
     PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
     // phi-node (double-buffering)
@@ -825,88 +826,72 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
   }
   if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
     std::map<indices_t, Value*> partial;
-    distributed_tile* op = (distributed_tile*)tmap_.at(ins->get_operand(0));
-    size_t axis = 0;
-    unsigned num_warps = params_->get_num_threads() / 32;
-    std::vector<unsigned> shapes = op->get_shapes();
-    shapes.erase(shapes.begin() + axis);
-    if(shapes.empty())
-      shapes.push_back(1);
+    ir::value *op = ins->get_operand(0);
+    distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
+    unsigned axis = x->get_axis();
     // reduce within thread
-    op->for_each([&](indices_t idx){
+    op_tile->for_each([&](indices_t idx) {
       indices_t pidx = idx;
       pidx.erase(pidx.begin() + axis);
       if(pidx.empty())
         pidx.push_back(builder.getInt32(0));
-      Value *current = op->get_value(idx);
+      Value *current = op_tile->get_value(idx);
       // current partial result is not initialized -- create
       if(partial.find(pidx) == partial.end())
         partial[pidx] = current;
       // current partial result is initialized -- accumulate
       else
         partial[pidx] = builder.CreateFAdd(partial[pidx], current);
     });
-    // reduce within warp
-    Value *shfl = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_shfl_sync_bfly_f32);
-    for (int i = 16; i > 0; i >>= 1)
-      for(auto& x: partial)
-      {
-        Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), x.second,
-                                               builder.getInt32(i),
-                                               builder.getInt32(0x1f)});
-        x.second = builder.CreateFAdd(x.second, rhs);
-      }
-    // reduce within block
-    Value *tid = tgt_->get_local_id(module, builder, 0);
-    BasicBlock *partial_reduce_do = BasicBlock::Create(ctx, "partial_reduce_do", fn);
-    BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
-    Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
-    Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
-    builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
-                         partial_reduce_do, partial_reduce_done);
-    builder.SetInsertPoint(partial_reduce_do);
+    // reduce within blocks
     unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
-    Type *ptr_ty = PointerType::get(builder.getFloatTy(), addr_space);
-    Value *sh_mem_ptr = builder.CreateBitCast(sh_mem_ptr_, ptr_ty);
-    for(auto& x: partial){
-      Value *offset = shared_tile::shared_offset(builder, shapes, x.first);
-      offset = builder.CreateAdd(offset, builder.CreateMul(warp_id, builder.getInt32(shapes[0])));
-      Value *write_ptr = builder.CreateGEP(sh_mem_ptr, offset);
-      builder.CreateStore(x.second, write_ptr);
-    }
-    builder.CreateBr(partial_reduce_done);
-    builder.SetInsertPoint(partial_reduce_done);
+    Type *res_ty = builder.getFloatTy();
+    Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
+    unsigned depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
+    for(auto& x: partial) {
+      // current element being computed
+      Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
+      Value *&result = x.second;
+      indices_t write_idx = x.first;
+      write_idx.insert(write_idx.begin() + axis, lane);
+      // shared memory write pointer
+      Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
+      Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
+      // initialize shared memory
+      builder.CreateStore(result, write_ptr);
+      // build result
+      for(unsigned i = depth/2; i > 0; i >>= 1){
+        // current indices
+        indices_t current(write_idx.size(), builder.getInt32(0));
+        current[axis] = builder.getInt32(i);
+        // shared memory offset
+        Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
+        Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
+        read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
+        // shared memory read pointer
+        Value *read_ptr = builder.CreateGEP(write_ptr, read_offset);
+        tgt_->add_barrier(module, builder);
+        Value *next = builder.CreateLoad(read_ptr);
+        // accumulate
+        result = builder.CreateFAdd(result, next);
+        // write back
+        builder.CreateStore(result, write_ptr);
+      }
-    // Final reduction with the first warp
-    tgt_->add_barrier(module, builder);
-    BasicBlock *final_reduce_do = BasicBlock::Create(ctx, "final_reduce_do", fn);
-    BasicBlock *final_reduce_done = BasicBlock::Create(ctx, "final_reduce_done", fn);
-    builder.CreateCondBr(builder.CreateICmpEQ(warp_id, builder.getInt32(0)),
-                         final_reduce_do, final_reduce_done);
-    builder.SetInsertPoint(final_reduce_do);
-    Value *read_ptr = builder.CreateGEP(sh_mem_ptr, tid);
-    BasicBlock *read_shmem_do = BasicBlock::Create(ctx, "read_shmem_do", fn);
-    BasicBlock *read_shmem_done = BasicBlock::Create(ctx, "read_shmem_done", fn);
-    builder.CreateCondBr(builder.CreateICmpULT(id_in_warp, builder.getInt32(num_warps)),
-                         read_shmem_do, read_shmem_done);
-    builder.SetInsertPoint(read_shmem_do);
-    Value *loaded = builder.CreateLoad(read_ptr);
-    builder.CreateBr(read_shmem_done);
-    builder.SetInsertPoint(read_shmem_done);
-    Value *result = builder.CreatePHI(loaded->getType(), 2);
-    ((PHINode*)result)->addIncoming(ConstantFP::get(loaded->getType(), (double)0), final_reduce_do);
-    ((PHINode*)result)->addIncoming(loaded, read_shmem_do);
-    for (int i = params_->get_num_threads() / 64; i > 0; i >>= 1){
-      Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), result,
-                                             builder.getInt32(i), builder.getInt32(0x1f)});
-      result = builder.CreateFAdd(result, rhs);
+      // result is on the first lane of shared memory
+      indices_t final = write_idx;
+      final[axis] = builder.getInt32(0);
+      Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
+      Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
+      tgt_->add_barrier(module, builder);
+      result = builder.CreateLoad(read_ptr);
+      if(tmap_.find(ins) == tmap_.end())
+        vmap_[ins] = result;
+      else{
+        distributed_tile *ti = (distributed_tile*)tmap_[ins];
+        ti->set_value(x.first, result);
+      }
     }
-    builder.CreateStore(result, read_ptr);
-    builder.CreateBr(final_reduce_done);
-    builder.SetInsertPoint(final_reduce_done);
-    tgt_->add_barrier(module, builder);
-    vmap_[ins] = builder.CreateLoad(sh_mem_ptr);
     return;
   }
   tile *ti = tmap_[ins];
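For reference, a serial CPU model, not part of this commit, of the reduction scheme the new lowering emits: each thread first folds its private elements into a partial, partials are staged in shared memory at the thread's lane along the reduced axis, and lanes then combine pairwise with a halving stride until lane 0 holds the result. In the emitted code the is_active select clamps inactive lanes' read offset to zero instead of branching; later steps never read those slots, so the clamp is harmless.

```cpp
#include <cassert>
#include <vector>

// `partials[lane]` plays the role of the per-lane shared-memory slot;
// the barriers of the GPU version are implicit in the serial loop.
float tree_reduce(std::vector<float> partials) {
  size_t depth = partials.size();           // power of two, the "mts" along the axis
  for (size_t i = depth / 2; i > 0; i >>= 1)
    for (size_t lane = 0; lane < i; ++lane) // lanes with lane < i are active
      partials[lane] += partials[lane + i]; // read at stride i, write back
  return partials[0];                       // result lives on the first lane
}

int main() {
  assert(tree_reduce({1, 2, 3, 4, 5, 6, 7, 8}) == 36.0f);
}
```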

View File

@@ -43,15 +43,16 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
 unsigned shmem_allocation::get_num_bytes(ir::value *x) {
   unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
-  if(dynamic_cast<ir::reduce_inst*>(x)){
-    size_t shape = 1;
-    if(x->get_type()->is_tile_ty()){
-      auto shapes = x->get_type()->get_tile_shapes();
-      for(auto x: shapes)
-        shape *= x->get_value();
-    }
-    size_t n_warps = params_->get_num_threads() / 32;
-    return shape * num_bytes * n_warps;
+  if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
+    size_t axis = red->get_axis();
+    ir::value *op = red->get_operand(0);
+    auto shapes = op->get_type()->get_tile_shapes();
+    shapes.erase(shapes.begin() + axis);
+    size_t num_elements = 1;
+    for(auto x: shapes)
+      num_elements *= x->get_value();
+    size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
+    return num_elements * num_bytes * depth;
   }
   unsigned pad = is_ld_padded(x);
   if(pad > 0){
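Restated as a standalone helper (hypothetical names, not the allocator's API): the scratch buffer now holds one partial per reducing thread for every surviving element, rather than one slot per warp as before.

```cpp
#include <cstddef>
#include <vector>

// Erase the reduced axis, then multiply the remaining extent by the
// element size and by `depth`, the number of threads ("mts") cooperating
// along the reduced axis.
std::size_t reduce_scratch_bytes(std::vector<std::size_t> shape, unsigned axis,
                                 std::size_t depth, std::size_t elem_bytes) {
  shape.erase(shape.begin() + axis);
  std::size_t num_elements = 1;
  for (std::size_t s : shape)
    num_elements *= s;
  return num_elements * elem_bytes * depth;
}

int main() {
  // e.g. reducing fp32[128, 64] along axis 0 with depth 8: 64 * 4 * 8 bytes
  return reduce_scratch_bytes({128, 64}, 0, 8, 4) == 2048 ? 0 : 1;
}
```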

View File

@@ -58,8 +58,19 @@ void tune::init_c_graph(ir::instruction *v) {
     shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
   else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
     return;
-  else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v))
+  else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
+    unsigned axis = reduce->get_axis();
+    ir::value *arg = reduce->get_operand(0);
+    auto in_shapes = arg->get_type()->get_tile_shapes();
+    unsigned current = 0;
+    for(unsigned i = 0; i < in_shapes.size(); i++){
+      if(i == axis)
+        continue;
+      // std::cout << arg->get_name() << " " << v->get_name() << std::endl;
+      add_constraint({reduce, current++}, {arg, i});
+    }
     return;
+  }
   else
     shapes = v->get_type()->get_tile_shapes();
   // Reshape
@@ -74,8 +85,10 @@ void tune::init_c_graph(ir::instruction *v) {
       static_params_.insert({{v, i}, 1});
     else if(!is_skewed && is_same)
       add_constraint({v, i}, {op, current++});
-    else
+    else{
       is_skewed = true;
+      add_constraint({v, i}, {v, i});
+    }
   }
 }
 // Splat
@@ -137,6 +150,7 @@ tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
 }
 void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
+  // std::cout << "connected component: " << x.first->get_name() << " " << x.second << std::endl;
   groups_[x.first].insert({x.second, group_id});
   if(nodes.find(x) != nodes.end()){
     nodes.erase(x);
@@ -190,6 +204,7 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
 }
 unsigned tune::get_param_group(ir::value *value, unsigned ax) {
+  // std::cout << "group? " << value->get_name() << " " << ax << std::endl;
   unsigned result = groups_.at(value).at(ax);
   return result;
 }
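The constraint wiring for reduce_inst is easy to misread, so here is a standalone sketch (hypothetical helper, not tune's API) of the pairing it creates: output axis k is tied to the k-th input axis that survives the reduction.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Walk the input axes in order, skipping the reduced one, and pair each
// survivor with the next output axis.
std::vector<std::pair<unsigned, unsigned>> reduce_axis_pairs(unsigned rank,
                                                             unsigned axis) {
  std::vector<std::pair<unsigned, unsigned>> pairs; // {output axis, input axis}
  unsigned current = 0;
  for (unsigned i = 0; i < rank; ++i)
    if (i != axis)
      pairs.push_back({current++, i});
  return pairs;
}

int main() {
  // reducing a rank-3 tile along axis 1 ties out-axis 0 to in-axis 0
  // and out-axis 1 to in-axis 2
  auto p = reduce_axis_pairs(3, 1);
  assert(p == (std::vector<std::pair<unsigned, unsigned>>{{0, 0}, {1, 2}}));
}
```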

View File

@@ -71,7 +71,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
 void batchnorm_forward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int TM = {32, 64, 128};
+const tunable int TM = {128};
 void batchnorm_forward(float *Y, float *M, float *V,
                        restrict read_only float *X,
@@ -94,7 +94,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
     px = px + TM;
   }
   float *pm = M + c;
-  float m = __sum(mean) * rcpDHWN;
+  float m = __sum(mean, 0) * rcpDHWN;
   *pm = m;
   float var[TM] = 0;
@@ -105,7 +105,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
     var = var + x*x;
     px = px + TM;
   }
-  float v = __sum(var) * rcpDHWN;
+  float v = __sum(var, 0) * rcpDHWN;
   float *pv = V + c;
   *pv = v;
   float rstdg = 1 / sqrt(v + eps) * g;
@@ -167,7 +167,7 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
 void batchnorm_backward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int TM = {32, 64, 128};
+const tunable int TM = {128};
 void batchnorm_backward(float *DX, float *DG, float *DB,
                         restrict read_only float *DY,
@@ -199,8 +199,8 @@ void batchnorm_backward(float *DX, float *DG, float *DB,
     px = px + TM;
     pdy = pdy + TM;
   }
-  float sdg = __sum(dg);
-  float sdb = __sum(db);
+  float sdg = __sum(dg, 0);
+  float sdb = __sum(db, 0);
   float *pdg = DG + c;
   float *pdb = DB + c;
   *pdg = sdg;

View File

@@ -322,8 +322,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
   return insert(sqrt_inst::create(A, name));
 }
-value *builder::create_reduce(value *A, const std::string &name) {
-  return insert(reduce_inst::create(A, name));
+value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
+  return insert(reduce_inst::create(A, axis, name));
 }
 value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){

View File

@@ -597,13 +597,24 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
 //===----------------------------------------------------------------------===//
 //                               reduce instructions
 //===----------------------------------------------------------------------===//
-reduce_inst::reduce_inst(value *arg, const std::string &name, instruction *next)
-  : builtin_inst(arg->get_type()->get_scalar_ty(), 1, 1, name, next) {
+type* reduce_inst::get_type(value *arg, unsigned axis) {
+  ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
+  shapes.erase(shapes.begin() + axis);
+  type *scalar_ty = arg->get_type()->get_scalar_ty();
+  if(shapes.size() == 0)
+    return scalar_ty;
+  else
+    return tile_type::get(scalar_ty, shapes);
+}
+reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
+  : builtin_inst(get_type(arg, axis), 1, 1, name, next),
+    axis_(axis){
   set_operand(0, arg);
 }
-instruction* reduce_inst::create(value *arg, const std::string &name, instruction *next) {
-  return new reduce_inst(arg, name, next);
+instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
+  return new reduce_inst(arg, axis, name, next);
 }
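A minimal standalone check, assuming plain ints in place of the constant_int shapes, of the type rule get_type implements: erase the reduced axis, and let a rank-0 result degenerate to a scalar.

```cpp
#include <cassert>
#include <vector>

// An empty shape vector stands in for "scalar" here.
std::vector<int> reduced_shape(std::vector<int> shape, unsigned axis) {
  assert(axis < shape.size());
  shape.erase(shape.begin() + axis);
  return shape;
}

int main() {
  assert(reduced_shape({128, 64}, 0) == std::vector<int>({64})); // tile result
  assert(reduced_shape({128}, 0).empty());                       // scalar result
}
```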

View File

@@ -161,6 +161,21 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
   return mod->get_builder().create_dot(A, B, C);
 }
+// reshape
+ir::value* reshape_expression::codegen(ir::module *mod) const {
+  // arg
+  ir::value *arg = arg_->codegen(mod);
+  // shapes
+  ir::type::tile_shapes_t shapes;
+  for(expression *expr: shapes_->values()){
+    ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
+    assert(shape);
+    shapes.push_back(shape);
+  }
+  // return
+  return mod->get_builder().create_reshape(arg, shapes);
+}
 // min
 ir::value* min_expression::codegen(ir::module *mod) const {
   ir::value* cmp = binary_expression(LT, (node*)x_, (node*)y_).codegen(mod);
@@ -198,7 +213,7 @@ ir::value* sqrt_expression::codegen(ir::module *mod) const {
 // reduce
 ir::value* reduce_expression::codegen(ir::module *mod) const {
-  return mod->get_builder().create_reduce(arg_->codegen(mod));
+  return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());
 }
 /* Postfix expression */

View File

@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-  ThreadPool pool(nthreads);
+//  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-    pool.enqueue(f,values);
-    // f(values);
+//    pool.enqueue(f,values);
+    f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;