diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 2de35d7d6..df37c830c 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -84,6 +84,21 @@ def run_shift(): b: hb})[0] #print(result) + +def batch_norm(x, g, b, epsilon=1e-6): + shape = x.shape + C = int(shape[1]) + assert g.get_shape().num_elements() == C + assert b.get_shape().num_elements() == C + return module.batchnorm_forward(x, g, b, eps=epsilon) + +@ops.RegisterGradient("BatchnormForward") +def batch_norm_grad(op, dy, mean, var): + eps = op.get_attr("eps") + return module.batchnorm_backward(dy, op.inputs[0], op.inputs[1], + op.outputs[1], op.outputs[2], eps=eps) + + def run_batchnorm(): C, H, W, B = 32, 16, 16, 16 np.random.seed(0) @@ -101,7 +116,7 @@ def run_batchnorm(): sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb}) - print(hx.sum(axis=(1,2,3))) - print(result[1]) + + run_batchnorm() diff --git a/include/triton/dnn/batchnorm.h b/include/triton/dnn/batchnorm.h index a61178500..7a97e83af 100644 --- a/include/triton/dnn/batchnorm.h +++ b/include/triton/dnn/batchnorm.h @@ -53,6 +53,9 @@ private: int32_t W_; int32_t B_; std::string ty_; + float eps_; + int32_t DHWB_; + float rcpDHWB_; }; class batchnorm_backward { diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index ea56a8d9c..8dd60fdff 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -131,6 +131,7 @@ public: value *create_atomic_add(value *ptr, value *val, const std::string &name = ""); value *create_dot(value *A, value *B, value *C, const std::string &name = ""); value *create_trans(value *A, const std::string &name = ""); + value *create_sqrt(value *A, const std::string &name = ""); value *create_reduce(value *A, const std::string &name = ""); value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); // Intrinsics diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 9a8aa2f0b..9828406af 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -581,6 +581,14 @@ public: static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); }; +class sqrt_inst: public builtin_inst { +private: + sqrt_inst(value *arg, const std::string& name, instruction* next); + std::string repr_impl() const { return "sqrt"; } +public: + static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); +}; + class reduce_inst: public builtin_inst { private: reduce_inst(value* arg, const std::string& name, instruction* next); diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h index 40e03f84d..3d894c802 100644 --- a/include/triton/lang/expression.h +++ b/include/triton/lang/expression.h @@ -167,6 +167,15 @@ private: node* arg_; }; +class sqrt_expression: public builtin_expression{ +public: + sqrt_expression(node *arg): arg_(arg) {} + ir::value* codegen(ir::module *) const; + +private: + node* arg_; +}; + class reduce_expression: public builtin_expression{ public: reduce_expression(node *arg): arg_(arg) {} diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y index 9296cd52f..32b3c5ed4 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64 %token IF ELSE FOR CONTINUE WHILE %token NEWAXIS ELLIPSIS AT -%token GET_GLOBAL_RANGE GET_RANGE_ID DOT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST +%token GET_GLOBAL_RANGE GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST %start translation_unit %% @@ -123,6 +123,7 @@ builtin_expression : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); } | GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } + | SQRT '(' expression ')' { $$ = new sqrt_expression($3); } | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); } | TRANS '(' expression ')' { $$ = new trans_expression($3); } | REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);} diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l index 68c38d2f0..0fbaa52d2 100644 --- a/include/triton/lang/scanner.l +++ b/include/triton/lang/scanner.l @@ -50,6 +50,7 @@ using triton::lang::return_void; "__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); } "__atomic_add" { return return_impl(ATOMIC_ADD, yytext); } "__sum" { return return_impl(REDUCE_SUM, yytext); } +"sqrt" { return return_impl(SQRT, yytext); } "dot" { return return_impl(DOT, yytext); } "max" { return return_impl(MAX, yytext); } "min" { return return_impl(MIN, yytext); } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index da31ea9d8..3bd010ebc 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -380,6 +380,12 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function(inst)){ + Value *val = value(ii->get_operand(0)); + Value *sqrt = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::sqrt, {val->getType()}); + Value *res = builder.CreateCall(sqrt, {val}); + return (Instruction*)res; + } // unknown instruction throw std::runtime_error("unknown conversion from ir::instruction to Instruction"); } @@ -797,7 +803,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn); Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32)); Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32)); - builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)), partial_reduce_do, partial_reduce_done); builder.SetInsertPoint(partial_reduce_do); diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp index 4bb29db5d..e3b1a630c 100644 --- a/lib/dnn/batchnorm.cpp +++ b/lib/dnn/batchnorm.cpp @@ -30,7 +30,10 @@ namespace dnn{ * --------------- */ batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty) - : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty) { } + : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(1e-5) { + DHWB_ = D_*H_*W_*B_; + rcpDHWB_ = (float)1 / DHWB_; +} void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel, driver::buffer *y, driver::buffer *m, driver::buffer *v, @@ -44,7 +47,9 @@ void batchnorm_forward::enqueue(driver::stream *stream, driver::kernel *kernel, kernel->setArg(3, x); kernel->setArg(4, g); kernel->setArg(5, b); - kernel->setArg(6, (int32_t)(D_*H_*W_*B_)); + kernel->setArg(6, DHWB_); + kernel->setArg(7, rcpDHWB_); + kernel->setArg(8, eps_); stream->enqueue(kernel, grid, {nthreads, 1, 1}); } @@ -57,7 +62,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V, restrict read_only fp32 *X, restrict read_only fp32 *G, restrict read_only fp32 *B, - int32 DHWN) { + int32 DHWN, + fp32 rcpDHWN, fp32 eps) { int32 rx[TM] = 0 ... TM; fp32 *px[TM]; fp32 x[TM]; @@ -72,9 +78,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V, mean = mean + x; px = px + TM; } - fp32 m = __sum(mean); fp32 *pm = M + c; - *pm = m; + *pm = __sum(mean) * rcpDHWN; fp32 var[TM] = 0; px = X + rx + c*DHWN; @@ -84,9 +89,21 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V, var = var + x*x; px = px + TM; } - fp32 v = __sum(var); + fp32 v = __sum(var) * rcpDHWN; fp32 *pv = V + c; *pv = v; + + fp32 rstdg = 1 / sqrt(v + eps) * g; + + px = X + rx + c*DHWN; + fp32* py[TM] = Y + rx + c*DHWN; + for(int32 i = 0; i < DHWN; i = i + TM){ + x = *px; + fp32 y[TM] = (x - mean)*rstdg + b; + *py = y; + px = px + TM; + py = py + TM; + } })"; } @@ -148,16 +165,25 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB, fp32 dy[TM] = *pdy; dg = dg + dy*(x - mean)*rstd; db = db + dy; + px = px + TM; + pdy = pdy + TM; } + fp32 sdg = __sum(dg); + fp32 sdb = __sum(db); px = X + rx + offset; pdy = DY + rx + offset; pdx = DX + rx + offset; for(int32 i = 0; i < DHWN; i += TM){ + fp32 x[TM] = *px; + fp32 dy[TM] = *pdy; fp32 xhat[TM] = (x - mean) * rstd; fp32 xtmp[TM] = (xhat * dg + db) * NDHW; fp32 dx[TM] = (dy - xtmp) * rstd * g; *pdx = dx; + px = px + TM; + pdy = pdy + TM; + pdx = pdx + TM; } })"; } diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index c3139ece6..4ff863666 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ - std::cout << source << std::endl; +// std::cout << source << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index cf6832958..54321bd81 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -320,6 +320,10 @@ value *builder::create_trans(value *A, const std::string &name) { return insert(trans_inst::create(A, name)); } +value *builder::create_sqrt(value *A, const std::string &name) { + return insert(sqrt_inst::create(A, name)); +} + value *builder::create_reduce(value *A, const std::string &name) { return insert(reduce_inst::create(A, name)); } diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index 6607990c4..27efc0838 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -566,6 +566,19 @@ instruction* trans_inst::create(value *arg, const std::string &name, instruction return new trans_inst(arg, name, next); } +//===----------------------------------------------------------------------===// +// sqrt instructions +//===----------------------------------------------------------------------===// + +sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next) + : builtin_inst(arg->get_type(), 1, 1, name, next){ + set_operand(0, arg); +} + +instruction* sqrt_inst::create(value *arg, const std::string &name, instruction *next) { + return new sqrt_inst(arg, name, next); +} + //===----------------------------------------------------------------------===// // reduce instructions //===----------------------------------------------------------------------===// diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp index 8f51ec47a..85e98a771 100644 --- a/lib/lang/expression.cpp +++ b/lib/lang/expression.cpp @@ -180,6 +180,12 @@ ir::value* trans_expression::codegen(ir::module *mod) const { return mod->get_builder().create_trans(arg_->codegen(mod)); } +// sqrt +ir::value* sqrt_expression::codegen(ir::module *mod) const { + return mod->get_builder().create_sqrt(arg_->codegen(mod)); +} + + // reduce ir::value* reduce_expression::codegen(ir::module *mod) const { return mod->get_builder().create_reduce(arg_->codegen(mod));