[dnn/batchnorm]: added some more code in Triton-C batchnorm implementations

Philippe Tillet
2019-07-08 20:18:20 -07:00
parent fa3270dcf2
commit f74dcb7e30
13 changed files with 103 additions and 11 deletions

View File

@@ -84,6 +84,21 @@ def run_shift():
b: hb})[0]
#print(result)
def batch_norm(x, g, b, epsilon=1e-6):
    shape = x.shape
    C = int(shape[1])
    assert g.get_shape().num_elements() == C
    assert b.get_shape().num_elements() == C
    return module.batchnorm_forward(x, g, b, eps=epsilon)

@ops.RegisterGradient("BatchnormForward")
def batch_norm_grad(op, dy, mean, var):
    eps = op.get_attr("eps")
    return module.batchnorm_backward(dy, op.inputs[0], op.inputs[1],
                                     op.outputs[1], op.outputs[2], eps=eps)

def run_batchnorm():
    C, H, W, B = 32, 16, 16, 16
    np.random.seed(0)
@@ -101,7 +116,7 @@ def run_batchnorm():
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
    print(hx.sum(axis=(1,2,3)))
    print(result[1])

run_batchnorm()
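For reference, the per-channel statistics that the forward kernel writes to M and V can be cross-checked in NumPy. This is a minimal sketch under two assumptions not stated in the diff: hx is laid out channel-first as (C, H, W, B), so the reduction runs over all non-channel axes, and the kernel computes the biased (population) variance.

import numpy as np

def batchnorm_stats_ref(hx):
    # per-channel mean and biased variance over the H, W, B axes
    mean = hx.mean(axis=(1, 2, 3))
    var = hx.var(axis=(1, 2, 3))
    return mean, var

# hypothetical usage after sess.run(...):
#   m_ref, v_ref = batchnorm_stats_ref(hx)
#   np.allclose(result[1], m_ref, atol=1e-4)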

View File

@@ -53,6 +53,9 @@ private:
int32_t W_;
int32_t B_;
std::string ty_;
float eps_;
int32_t DHWB_;
float rcpDHWB_;
};
class batchnorm_backward {

View File

@@ -131,6 +131,7 @@ public:
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::string &name = "");
value *create_sqrt(value *A, const std::string &name = "");
value *create_reduce(value *A, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
// Intrinsics

View File

@@ -581,6 +581,14 @@ public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};
class sqrt_inst: public builtin_inst {
private:
sqrt_inst(value *arg, const std::string& name, instruction* next);
std::string repr_impl() const { return "sqrt"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};
class reduce_inst: public builtin_inst {
private:
reduce_inst(value* arg, const std::string& name, instruction* next);

View File

@@ -167,6 +167,15 @@ private:
node* arg_;
};
class sqrt_expression: public builtin_expression{
public:
sqrt_expression(node *arg): arg_(arg) {}
ir::value* codegen(ir::module *) const;
private:
node* arg_;
};
class reduce_expression: public builtin_expression{
public:
reduce_expression(node *arg): arg_(arg) {}

View File

@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
%token IF ELSE FOR CONTINUE WHILE
%token NEWAXIS ELLIPSIS AT
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
%start translation_unit
%%
@@ -123,6 +123,7 @@ builtin_expression
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); }
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
| REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}

View File

@@ -50,6 +50,7 @@ using triton::lang::return_void;
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
"__sum" { return return_impl(REDUCE_SUM, yytext); }
"sqrt" { return return_impl(SQRT, yytext); }
"dot" { return return_impl(DOT, yytext); }
"max" { return return_impl(MAX, yytext); }
"min" { return return_impl(MIN, yytext); }

View File

@@ -380,6 +380,12 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
Value *res = builder.CreateCall(atom_f_add, {ptr, val});
return (Instruction*)res;
}
if(ir::sqrt_inst* ii = dynamic_cast<ir::sqrt_inst*>(inst)){
Value *val = value(ii->get_operand(0));
Value *sqrt = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::sqrt, {val->getType()});
Value *res = builder.CreateCall(sqrt, {val});
return (Instruction*)res;
}
// unknown instruction
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
}
@@ -797,7 +803,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
partial_reduce_do, partial_reduce_done);
builder.SetInsertPoint(partial_reduce_do);

View File

@@ -30,7 +30,10 @@ namespace dnn{
* --------------- */
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty) { }
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5) {
DHWB_ = D_*H_*W_*B_;
rcpDHWB_ = (float)1 / DHWB_;
}
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *y, driver::buffer *m, driver::buffer *v,
@@ -44,7 +47,9 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(3, x);
kernel->setArg(4, g);
kernel->setArg(5, b);
kernel->setArg(6, (int32_t)(D_*H_*W_*B_));
kernel->setArg(6, DHWB_);
kernel->setArg(7, rcpDHWB_);
kernel->setArg(8, eps_);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
}
@@ -57,7 +62,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
restrict read_only fp32 *X,
restrict read_only fp32 *G,
restrict read_only fp32 *B,
int32 DHWN) {
int32 DHWN,
fp32 rcpDHWN, fp32 eps) {
int32 rx[TM] = 0 ... TM;
fp32 *px[TM];
fp32 x[TM];
@@ -72,9 +78,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  mean = mean + x;
  px = px + TM;
}
fp32 m = __sum(mean);
fp32 *pm = M + c;
*pm = m;
*pm = __sum(mean) * rcpDHWN;
fp32 var[TM] = 0;
px = X + rx + c*DHWN;
@@ -84,9 +89,21 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  var = var + x*x;
  px = px + TM;
}
fp32 v = __sum(var);
fp32 v = __sum(var) * rcpDHWN;
fp32 *pv = V + c;
*pv = v;
fp32 rstdg = 1 / sqrt(v + eps) * g;
px = X + rx + c*DHWN;
fp32* py[TM] = Y + rx + c*DHWN;
for(int32 i = 0; i < DHWN; i = i + TM){
  x = *px;
  fp32 y[TM] = (x - mean)*rstdg + b;
  *py = y;
  px = px + TM;
  py = py + TM;
}
})";
}
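The new output loop in the forward kernel folds the per-channel gain and reciprocal standard deviation into a single factor, rstdg = g / sqrt(v + eps), so each element needs only one multiply-add: y = (x - mean) * rstdg + b. The same folding is sketched below in NumPy; the array names are illustrative, the (C, H, W, B) layout is assumed as above, and eps defaults to the 1e-5 hard-coded in the batchnorm_forward constructor.

import numpy as np

def batchnorm_apply_ref(hx, hg, hb, mean, var, eps=1e-5):
    # fold gain and reciprocal std-dev into one per-channel factor, as the kernel does
    rstdg = hg / np.sqrt(var + eps)            # shape (C,)
    c = lambda a: a[:, None, None, None]       # broadcast (C,) over (C, H, W, B)
    return (hx - c(mean)) * c(rstdg) + c(hb)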
@@ -148,16 +165,25 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
  fp32 dy[TM] = *pdy;
  dg = dg + dy*(x - mean)*rstd;
  db = db + dy;
  px = px + TM;
  pdy = pdy + TM;
}
fp32 sdg = __sum(dg);
fp32 sdb = __sum(db);
px = X + rx + offset;
pdy = DY + rx + offset;
pdx = DX + rx + offset;
for(int32 i = 0; i < DHWN; i += TM){
  fp32 x[TM] = *px;
  fp32 dy[TM] = *pdy;
  fp32 xhat[TM] = (x - mean) * rstd;
  fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
  fp32 dx[TM] = (dy - xtmp) * rstd * g;
  *pdx = dx;
  px = px + TM;
  pdy = pdy + TM;
  pdx = pdx + TM;
}
})";
}
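The backward kernel's reductions and final dx expression follow the standard batch-norm gradient. A NumPy reference is sketched below under the same assumptions as above ((C, H, W, B) layout, biased variance); it additionally assumes that the NDHW factor in the kernel plays the role of a reciprocal 1/(D*H*W*B), analogous to rcpDHWN in the forward pass, so the per-channel sums dg and db are averaged before being folded into dx.

import numpy as np

def batchnorm_backward_ref(dy, x, g, mean, var, eps=1e-5):
    c = lambda a: a[:, None, None, None]      # broadcast (C,) over (C, H, W, B)
    rstd = 1.0 / np.sqrt(var + eps)
    xhat = (x - c(mean)) * c(rstd)
    dg = (dy * xhat).sum(axis=(1, 2, 3))      # gradient w.r.t. the per-channel gain
    db = dy.sum(axis=(1, 2, 3))               # gradient w.r.t. the per-channel bias
    n = x[0].size                             # number of elements per channel (D*H*W*B)
    dx = (dy - (xhat * c(dg) + c(db)) / n) * c(rstd) * c(g)
    return dx, dg, db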

View File

@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
std::cout << source << std::endl;
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -320,6 +320,10 @@ value *builder::create_trans(value *A, const std::string &name) {
return insert(trans_inst::create(A, name));
}
value *builder::create_sqrt(value *A, const std::string &name) {
return insert(sqrt_inst::create(A, name));
}
value *builder::create_reduce(value *A, const std::string &name) {
return insert(reduce_inst::create(A, name));
}

View File

@@ -566,6 +566,19 @@ instruction* trans_inst::create(value *arg, const std::string &name, instruction
return new trans_inst(arg, name, next);
}
//===----------------------------------------------------------------------===//
// sqrt instructions
//===----------------------------------------------------------------------===//
sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
: builtin_inst(arg->get_type(), 1, 1, name, next){
set_operand(0, arg);
}
instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) {
return new sqrt_inst(arg, name, next);
}
//===----------------------------------------------------------------------===//
// reduce instructions
//===----------------------------------------------------------------------===//

View File

@@ -180,6 +180,12 @@ ir::value* trans_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_trans(arg_->codegen(mod));
}
// sqrt
ir::value* sqrt_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_sqrt(arg_->codegen(mod));
}
// reduce
ir::value* reduce_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_reduce(arg_->codegen(mod));