[dnn/batchnorm]: added some more code in Triton-C batchnorm implementations
This commit is contained in:
@@ -84,6 +84,21 @@ def run_shift():
|
||||
b: hb})[0]
|
||||
#print(result)
|
||||
|
||||
|
||||
def batch_norm(x, g, b, epsilon=1e-6):
    """Apply the Triton batch-norm forward op to ``x``.

    Args:
        x: input tensor; its leading dimension is taken as the channel count.
            NOTE(review): this assumes the channel-major layout (C, ..., B)
            that run_batchnorm feeds to the op -- confirm against callers.
        g: per-channel scale; must hold exactly C elements.
        b: per-channel shift; must hold exactly C elements.
        epsilon: forwarded to the op as its ``eps`` attribute.

    Returns:
        Whatever ``module.batchnorm_forward`` returns for these inputs.

    Raises:
        AssertionError: if ``g`` or ``b`` does not hold C elements.
    """
    # Channels are the outermost axis (axis 0); axis 1 is a spatial
    # dimension, so reading it here made the checks below compare against
    # the wrong size.
    C = int(x.shape[0])
    assert g.get_shape().num_elements() == C
    assert b.get_shape().num_elements() == C
    return module.batchnorm_forward(x, g, b, eps=epsilon)
|
||||
|
||||
@ops.RegisterGradient("BatchnormForward")
def batch_norm_grad(op, dy, mean, var):
    """Gradient function registered for the "BatchnormForward" op.

    Args:
        op: the forward operation being differentiated.
        dy: incoming gradient for the op's first output.
        mean, var: incoming gradients for the second and third outputs;
            they are not used here.

    Returns:
        The result of ``module.batchnorm_backward``, called with ``dy``, the
        first two forward inputs, the second and third forward outputs and
        the forward op's ``eps`` attribute.
    """
    fwd_x, fwd_g = op.inputs[0], op.inputs[1]
    fwd_mean, fwd_var = op.outputs[1], op.outputs[2]
    return module.batchnorm_backward(
        dy, fwd_x, fwd_g, fwd_mean, fwd_var, eps=op.get_attr("eps"))
|
||||
|
||||
|
||||
def run_batchnorm():
|
||||
C, H, W, B = 32, 16, 16, 16
|
||||
np.random.seed(0)
|
||||
@@ -101,7 +116,7 @@ def run_batchnorm():
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
|
||||
print(hx.sum(axis=(1,2,3)))
|
||||
print(result[1])
|
||||
|
||||
|
||||
|
||||
run_batchnorm()
|
||||
|
@@ -53,6 +53,9 @@ private:
|
||||
int32_t W_;
|
||||
int32_t B_;
|
||||
std::string ty_;
|
||||
float eps_;
|
||||
int32_t DHWB_;
|
||||
float rcpDHWB_;
|
||||
};
|
||||
|
||||
class batchnorm_backward {
|
||||
|
@@ -131,6 +131,7 @@ public:
|
||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::string &name = "");
|
||||
value *create_sqrt(value *A, const std::string &name = "");
|
||||
value *create_reduce(value *A, const std::string &name = "");
|
||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||
// Intrinsics
|
||||
|
@@ -581,6 +581,14 @@ public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class sqrt_inst: public builtin_inst {
|
||||
private:
|
||||
sqrt_inst(value *arg, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "sqrt"; }
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
private:
|
||||
reduce_inst(value* arg, const std::string& name, instruction* next);
|
||||
|
@@ -167,6 +167,15 @@ private:
|
||||
node* arg_;
|
||||
};
|
||||
|
||||
class sqrt_expression: public builtin_expression{
|
||||
public:
|
||||
sqrt_expression(node *arg): arg_(arg) {}
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
node* arg_;
|
||||
};
|
||||
|
||||
class reduce_expression: public builtin_expression{
|
||||
public:
|
||||
reduce_expression(node *arg): arg_(arg) {}
|
||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||
%token IF ELSE FOR CONTINUE WHILE
|
||||
%token NEWAXIS ELLIPSIS AT
|
||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -123,6 +123,7 @@ builtin_expression
|
||||
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); }
|
||||
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
|
||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
||||
| REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}
|
||||
|
@@ -50,6 +50,7 @@ using triton::lang::return_void;
|
||||
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
||||
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
|
||||
"__sum" { return return_impl(REDUCE_SUM, yytext); }
|
||||
"sqrt" { return return_impl(SQRT, yytext); }
|
||||
"dot" { return return_impl(DOT, yytext); }
|
||||
"max" { return return_impl(MAX, yytext); }
|
||||
"min" { return return_impl(MIN, yytext); }
|
||||
|
@@ -380,6 +380,12 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
Value *res = builder.CreateCall(atom_f_add, {ptr, val});
|
||||
return (Instruction*)res;
|
||||
}
|
||||
if(ir::sqrt_inst* ii = dynamic_cast<ir::sqrt_inst*>(inst)){
|
||||
Value *val = value(ii->get_operand(0));
|
||||
Value *sqrt = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::sqrt, {val->getType()});
|
||||
Value *res = builder.CreateCall(sqrt, {val});
|
||||
return (Instruction*)res;
|
||||
}
|
||||
// unknown instruction
|
||||
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
|
||||
}
|
||||
@@ -797,7 +803,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
|
||||
Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
|
||||
Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
|
||||
|
||||
builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
|
||||
partial_reduce_do, partial_reduce_done);
|
||||
builder.SetInsertPoint(partial_reduce_do);
|
||||
|
@@ -30,7 +30,10 @@ namespace dnn{
|
||||
* --------------- */
|
||||
|
||||
// Records the problem shape (C, D, H, W, B) and the element type name, fixes
// eps_ to 1e-5, and precomputes the per-channel element count together with
// its reciprocal so enqueue() can hand both to the kernel.
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
  : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5f) {
  // Number of elements spanned by one channel.
  DHWB_ = D_*H_*W_*B_;
  // Stored as float: the kernel multiplies by this instead of dividing.
  rcpDHWB_ = 1.0f / static_cast<float>(DHWB_);
}
|
||||
|
||||
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *y, driver::buffer *m, driver::buffer *v,
|
||||
@@ -44,7 +47,9 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(3, x);
|
||||
kernel->setArg(4, g);
|
||||
kernel->setArg(5, b);
|
||||
kernel->setArg(6, (int32_t)(D_*H_*W_*B_));
|
||||
kernel->setArg(6, DHWB_);
|
||||
kernel->setArg(7, rcpDHWB_);
|
||||
kernel->setArg(8, eps_);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
@@ -57,7 +62,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
||||
restrict read_only fp32 *X,
|
||||
restrict read_only fp32 *G,
|
||||
restrict read_only fp32 *B,
|
||||
int32 DHWN) {
|
||||
int32 DHWN,
|
||||
fp32 rcpDHWN, fp32 eps) {
|
||||
int32 rx[TM] = 0 ... TM;
|
||||
fp32 *px[TM];
|
||||
fp32 x[TM];
|
||||
@@ -72,9 +78,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
||||
mean = mean + x;
|
||||
px = px + TM;
|
||||
}
|
||||
fp32 m = __sum(mean);
|
||||
fp32 *pm = M + c;
|
||||
*pm = m;
|
||||
*pm = __sum(mean) * rcpDHWN;
|
||||
|
||||
fp32 var[TM] = 0;
|
||||
px = X + rx + c*DHWN;
|
||||
@@ -84,9 +89,21 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
||||
var = var + x*x;
|
||||
px = px + TM;
|
||||
}
|
||||
fp32 v = __sum(var);
|
||||
fp32 v = __sum(var) * rcpDHWN;
|
||||
fp32 *pv = V + c;
|
||||
*pv = v;
|
||||
|
||||
fp32 rstdg = 1 / sqrt(v + eps) * g;
|
||||
|
||||
px = X + rx + c*DHWN;
|
||||
fp32* py[TM] = Y + rx + c*DHWN;
|
||||
for(int32 i = 0; i < DHWN; i = i + TM){
|
||||
x = *px;
|
||||
fp32 y[TM] = (x - mean)*rstdg + b;
|
||||
*py = y;
|
||||
px = px + TM;
|
||||
py = py + TM;
|
||||
}
|
||||
})";
|
||||
}
|
||||
|
||||
@@ -148,16 +165,25 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
|
||||
fp32 dy[TM] = *pdy;
|
||||
dg = dg + dy*(x - mean)*rstd;
|
||||
db = db + dy;
|
||||
px = px + TM;
|
||||
pdy = pdy + TM;
|
||||
}
|
||||
fp32 sdg = __sum(dg);
|
||||
fp32 sdb = __sum(db);
|
||||
|
||||
px = X + rx + offset;
|
||||
pdy = DY + rx + offset;
|
||||
pdx = DX + rx + offset;
|
||||
for(int32 i = 0; i < DHWN; i += TM){
|
||||
fp32 x[TM] = *px;
|
||||
fp32 dy[TM] = *pdy;
|
||||
fp32 xhat[TM] = (x - mean) * rstd;
|
||||
fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
|
||||
fp32 dx[TM] = (dy - xtmp) * rstd * g;
|
||||
*pdx = dx;
|
||||
px = px + TM;
|
||||
pdy = pdy + TM;
|
||||
pdx = pdx + TM;
|
||||
}
|
||||
})";
|
||||
}
|
||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
// Builds a module from LLVM IR by delegating to the source-string constructor
// with the output of compile_llvm_module(ll_module).
cu_module::cu_module(driver::context * context, llvm::Module* ll_module)
  : cu_module(context, compile_llvm_module(ll_module)) {
}
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
std::cout << source << std::endl;
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -320,6 +320,10 @@ value *builder::create_trans(value *A, const std::string &name) {
|
||||
return insert(trans_inst::create(A, name));
|
||||
}
|
||||
|
||||
// Creates a sqrt instruction over A and passes it through insert().
value *builder::create_sqrt(value *A, const std::string &name) {
  auto *sqrt_op = sqrt_inst::create(A, name);
  return insert(sqrt_op);
}
|
||||
|
||||
// Creates a reduce instruction over A and passes it through insert().
value *builder::create_reduce(value *A, const std::string &name) {
  auto *reduce_op = reduce_inst::create(A, name);
  return insert(reduce_op);
}
|
||||
|
@@ -566,6 +566,19 @@ instruction* trans_inst::create(value *arg, const std::string &name, instruction
|
||||
return new trans_inst(arg, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// sqrt instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// The instruction takes the type of its argument; the argument is installed
// as operand 0.
// NOTE(review): the two literal 1s are forwarded to builtin_inst unchanged --
// presumably the result and operand counts; confirm against builtin_inst.
sqrt_inst::sqrt_inst(value *arg,
                     const std::string &name,
                     instruction *next)
    : builtin_inst(arg->get_type(), 1, 1, name, next)
{
  set_operand(0, arg);
}
|
||||
|
||||
// Factory for sqrt_inst: allocates the instruction with `new` and returns it
// as an instruction*. Nothing here frees it.
instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) {
  sqrt_inst *created = new sqrt_inst(arg, name, next);
  return created;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// reduce instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -180,6 +180,12 @@ ir::value* trans_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_trans(arg_->codegen(mod));
|
||||
}
|
||||
|
||||
// sqrt
|
||||
// Generates IR for the argument, then emits a sqrt over that value through
// the module's builder.
ir::value* sqrt_expression::codegen(ir::module *mod) const {
  ir::value *operand = arg_->codegen(mod);
  return mod->get_builder().create_sqrt(operand);
}
|
||||
|
||||
|
||||
// reduce
|
||||
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_reduce(arg_->codegen(mod));
|
||||
|
Reference in New Issue
Block a user