[dnn/batchnorm]: added some more code in Triton-C batchnorm implementations
This commit is contained in:
@@ -84,6 +84,21 @@ def run_shift():
|
|||||||
b: hb})[0]
|
b: hb})[0]
|
||||||
#print(result)
|
#print(result)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_norm(x, g, b, epsilon=1e-6):
|
||||||
|
shape = x.shape
|
||||||
|
C = int(shape[1])
|
||||||
|
assert g.get_shape().num_elements() == C
|
||||||
|
assert b.get_shape().num_elements() == C
|
||||||
|
return module.batchnorm_forward(x, g, b, eps=epsilon)
|
||||||
|
|
||||||
|
@ops.RegisterGradient("BatchnormForward")
|
||||||
|
def batch_norm_grad(op, dy, mean, var):
|
||||||
|
eps = op.get_attr("eps")
|
||||||
|
return module.batchnorm_backward(dy, op.inputs[0], op.inputs[1],
|
||||||
|
op.outputs[1], op.outputs[2], eps=eps)
|
||||||
|
|
||||||
|
|
||||||
def run_batchnorm():
|
def run_batchnorm():
|
||||||
C, H, W, B = 32, 16, 16, 16
|
C, H, W, B = 32, 16, 16, 16
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
@@ -101,7 +116,7 @@ def run_batchnorm():
|
|||||||
sess = tf.InteractiveSession()
|
sess = tf.InteractiveSession()
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
|
result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
|
||||||
print(hx.sum(axis=(1,2,3)))
|
|
||||||
print(result[1])
|
|
||||||
|
|
||||||
run_batchnorm()
|
run_batchnorm()
|
||||||
|
@@ -53,6 +53,9 @@ private:
|
|||||||
int32_t W_;
|
int32_t W_;
|
||||||
int32_t B_;
|
int32_t B_;
|
||||||
std::string ty_;
|
std::string ty_;
|
||||||
|
float eps_;
|
||||||
|
int32_t DHWB_;
|
||||||
|
float rcpDHWB_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class batchnorm_backward {
|
class batchnorm_backward {
|
||||||
|
@@ -131,6 +131,7 @@ public:
|
|||||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||||
value *create_trans(value *A, const std::string &name = "");
|
value *create_trans(value *A, const std::string &name = "");
|
||||||
|
value *create_sqrt(value *A, const std::string &name = "");
|
||||||
value *create_reduce(value *A, const std::string &name = "");
|
value *create_reduce(value *A, const std::string &name = "");
|
||||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||||
// Intrinsics
|
// Intrinsics
|
||||||
|
@@ -581,6 +581,14 @@ public:
|
|||||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class sqrt_inst: public builtin_inst {
|
||||||
|
private:
|
||||||
|
sqrt_inst(value *arg, const std::string& name, instruction* next);
|
||||||
|
std::string repr_impl() const { return "sqrt"; }
|
||||||
|
public:
|
||||||
|
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
class reduce_inst: public builtin_inst {
|
class reduce_inst: public builtin_inst {
|
||||||
private:
|
private:
|
||||||
reduce_inst(value* arg, const std::string& name, instruction* next);
|
reduce_inst(value* arg, const std::string& name, instruction* next);
|
||||||
|
@@ -167,6 +167,15 @@ private:
|
|||||||
node* arg_;
|
node* arg_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class sqrt_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
sqrt_expression(node *arg): arg_(arg) {}
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
node* arg_;
|
||||||
|
};
|
||||||
|
|
||||||
class reduce_expression: public builtin_expression{
|
class reduce_expression: public builtin_expression{
|
||||||
public:
|
public:
|
||||||
reduce_expression(node *arg): arg_(arg) {}
|
reduce_expression(node *arg): arg_(arg) {}
|
||||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
|||||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||||
%token IF ELSE FOR CONTINUE WHILE
|
%token IF ELSE FOR CONTINUE WHILE
|
||||||
%token NEWAXIS ELLIPSIS AT
|
%token NEWAXIS ELLIPSIS AT
|
||||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST
|
||||||
|
|
||||||
%start translation_unit
|
%start translation_unit
|
||||||
%%
|
%%
|
||||||
@@ -123,6 +123,7 @@ builtin_expression
|
|||||||
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); }
|
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); }
|
||||||
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
||||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||||
|
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
|
||||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||||
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
||||||
| REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}
|
| REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}
|
||||||
|
@@ -50,6 +50,7 @@ using triton::lang::return_void;
|
|||||||
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
"__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); }
|
||||||
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
|
"__atomic_add" { return return_impl(ATOMIC_ADD, yytext); }
|
||||||
"__sum" { return return_impl(REDUCE_SUM, yytext); }
|
"__sum" { return return_impl(REDUCE_SUM, yytext); }
|
||||||
|
"sqrt" { return return_impl(SQRT, yytext); }
|
||||||
"dot" { return return_impl(DOT, yytext); }
|
"dot" { return return_impl(DOT, yytext); }
|
||||||
"max" { return return_impl(MAX, yytext); }
|
"max" { return return_impl(MAX, yytext); }
|
||||||
"min" { return return_impl(MIN, yytext); }
|
"min" { return return_impl(MIN, yytext); }
|
||||||
|
@@ -380,6 +380,12 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
|||||||
Value *res = builder.CreateCall(atom_f_add, {ptr, val});
|
Value *res = builder.CreateCall(atom_f_add, {ptr, val});
|
||||||
return (Instruction*)res;
|
return (Instruction*)res;
|
||||||
}
|
}
|
||||||
|
if(ir::sqrt_inst* ii = dynamic_cast<ir::sqrt_inst*>(inst)){
|
||||||
|
Value *val = value(ii->get_operand(0));
|
||||||
|
Value *sqrt = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::sqrt, {val->getType()});
|
||||||
|
Value *res = builder.CreateCall(sqrt, {val});
|
||||||
|
return (Instruction*)res;
|
||||||
|
}
|
||||||
// unknown instruction
|
// unknown instruction
|
||||||
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
|
throw std::runtime_error("unknown conversion from ir::instruction to Instruction");
|
||||||
}
|
}
|
||||||
@@ -797,7 +803,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
|
BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
|
||||||
Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
|
Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
|
||||||
Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
|
Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
|
||||||
|
|
||||||
builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
|
builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
|
||||||
partial_reduce_do, partial_reduce_done);
|
partial_reduce_do, partial_reduce_done);
|
||||||
builder.SetInsertPoint(partial_reduce_do);
|
builder.SetInsertPoint(partial_reduce_do);
|
||||||
|
@@ -30,7 +30,10 @@ namespace dnn{
|
|||||||
* --------------- */
|
* --------------- */
|
||||||
|
|
||||||
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
|
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty)
|
||||||
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty) { }
|
: C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5) {
|
||||||
|
DHWB_ = D_*H_*W_*B_;
|
||||||
|
rcpDHWB_ = (float)1 / DHWB_;
|
||||||
|
}
|
||||||
|
|
||||||
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||||
driver::buffer *y, driver::buffer *m, driver::buffer *v,
|
driver::buffer *y, driver::buffer *m, driver::buffer *v,
|
||||||
@@ -44,7 +47,9 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel,
|
|||||||
kernel->setArg(3, x);
|
kernel->setArg(3, x);
|
||||||
kernel->setArg(4, g);
|
kernel->setArg(4, g);
|
||||||
kernel->setArg(5, b);
|
kernel->setArg(5, b);
|
||||||
kernel->setArg(6, (int32_t)(D_*H_*W_*B_));
|
kernel->setArg(6, DHWB_);
|
||||||
|
kernel->setArg(7, rcpDHWB_);
|
||||||
|
kernel->setArg(8, eps_);
|
||||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +62,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
|||||||
restrict read_only fp32 *X,
|
restrict read_only fp32 *X,
|
||||||
restrict read_only fp32 *G,
|
restrict read_only fp32 *G,
|
||||||
restrict read_only fp32 *B,
|
restrict read_only fp32 *B,
|
||||||
int32 DHWN) {
|
int32 DHWN,
|
||||||
|
fp32 rcpDHWN, fp32 eps) {
|
||||||
int32 rx[TM] = 0 ... TM;
|
int32 rx[TM] = 0 ... TM;
|
||||||
fp32 *px[TM];
|
fp32 *px[TM];
|
||||||
fp32 x[TM];
|
fp32 x[TM];
|
||||||
@@ -72,9 +78,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
|||||||
mean = mean + x;
|
mean = mean + x;
|
||||||
px = px + TM;
|
px = px + TM;
|
||||||
}
|
}
|
||||||
fp32 m = __sum(mean);
|
|
||||||
fp32 *pm = M + c;
|
fp32 *pm = M + c;
|
||||||
*pm = m;
|
*pm = __sum(mean) * rcpDHWN;
|
||||||
|
|
||||||
fp32 var[TM] = 0;
|
fp32 var[TM] = 0;
|
||||||
px = X + rx + c*DHWN;
|
px = X + rx + c*DHWN;
|
||||||
@@ -84,9 +89,21 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
|
|||||||
var = var + x*x;
|
var = var + x*x;
|
||||||
px = px + TM;
|
px = px + TM;
|
||||||
}
|
}
|
||||||
fp32 v = __sum(var);
|
fp32 v = __sum(var) * rcpDHWN;
|
||||||
fp32 *pv = V + c;
|
fp32 *pv = V + c;
|
||||||
*pv = v;
|
*pv = v;
|
||||||
|
|
||||||
|
fp32 rstdg = 1 / sqrt(v + eps) * g;
|
||||||
|
|
||||||
|
px = X + rx + c*DHWN;
|
||||||
|
fp32* py[TM] = Y + rx + c*DHWN;
|
||||||
|
for(int32 i = 0; i < DHWN; i = i + TM){
|
||||||
|
x = *px;
|
||||||
|
fp32 y[TM] = (x - mean)*rstdg + b;
|
||||||
|
*py = y;
|
||||||
|
px = px + TM;
|
||||||
|
py = py + TM;
|
||||||
|
}
|
||||||
})";
|
})";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,16 +165,25 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
|
|||||||
fp32 dy[TM] = *pdy;
|
fp32 dy[TM] = *pdy;
|
||||||
dg = dg + dy*(x - mean)*rstd;
|
dg = dg + dy*(x - mean)*rstd;
|
||||||
db = db + dy;
|
db = db + dy;
|
||||||
|
px = px + TM;
|
||||||
|
pdy = pdy + TM;
|
||||||
}
|
}
|
||||||
|
fp32 sdg = __sum(dg);
|
||||||
|
fp32 sdb = __sum(db);
|
||||||
|
|
||||||
px = X + rx + offset;
|
px = X + rx + offset;
|
||||||
pdy = DY + rx + offset;
|
pdy = DY + rx + offset;
|
||||||
pdx = DX + rx + offset;
|
pdx = DX + rx + offset;
|
||||||
for(int32 i = 0; i < DHWN; i += TM){
|
for(int32 i = 0; i < DHWN; i += TM){
|
||||||
|
fp32 x[TM] = *px;
|
||||||
|
fp32 dy[TM] = *pdy;
|
||||||
fp32 xhat[TM] = (x - mean) * rstd;
|
fp32 xhat[TM] = (x - mean) * rstd;
|
||||||
fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
|
fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
|
||||||
fp32 dx[TM] = (dy - xtmp) * rstd * g;
|
fp32 dx[TM] = (dy - xtmp) * rstd * g;
|
||||||
*pdx = dx;
|
*pdx = dx;
|
||||||
|
px = px + TM;
|
||||||
|
pdy = pdy + TM;
|
||||||
|
pdx = pdx + TM;
|
||||||
}
|
}
|
||||||
})";
|
})";
|
||||||
}
|
}
|
||||||
|
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
|||||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
std::cout << source << std::endl;
|
// std::cout << source << std::endl;
|
||||||
cu_context::context_switcher ctx_switch(*context);
|
cu_context::context_switcher ctx_switch(*context);
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
|
@@ -320,6 +320,10 @@ value *builder::create_trans(value *A, const std::string &name) {
|
|||||||
return insert(trans_inst::create(A, name));
|
return insert(trans_inst::create(A, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
value *builder::create_sqrt(value *A, const std::string &name) {
|
||||||
|
return insert(sqrt_inst::create(A, name));
|
||||||
|
}
|
||||||
|
|
||||||
value *builder::create_reduce(value *A, const std::string &name) {
|
value *builder::create_reduce(value *A, const std::string &name) {
|
||||||
return insert(reduce_inst::create(A, name));
|
return insert(reduce_inst::create(A, name));
|
||||||
}
|
}
|
||||||
|
@@ -566,6 +566,19 @@ instruction* trans_inst::create(value *arg, const std::string &name, instruction
|
|||||||
return new trans_inst(arg, name, next);
|
return new trans_inst(arg, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// sqrt instructions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
|
||||||
|
: builtin_inst(arg->get_type(), 1, 1, name, next){
|
||||||
|
set_operand(0, arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||||
|
return new sqrt_inst(arg, name, next);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// reduce instructions
|
// reduce instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -180,6 +180,12 @@ ir::value* trans_expression::codegen(ir::module *mod) const {
|
|||||||
return mod->get_builder().create_trans(arg_->codegen(mod));
|
return mod->get_builder().create_trans(arg_->codegen(mod));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sqrt
|
||||||
|
ir::value* sqrt_expression::codegen(ir::module *mod) const {
|
||||||
|
return mod->get_builder().create_sqrt(arg_->codegen(mod));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// reduce
|
// reduce
|
||||||
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
ir::value* reduce_expression::codegen(ir::module *mod) const {
|
||||||
return mod->get_builder().create_reduce(arg_->codegen(mod));
|
return mod->get_builder().create_reduce(arg_->codegen(mod));
|
||||||
|
Reference in New Issue
Block a user