From d65a94c76843c53f7722949a493d7f77bfed814a Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 29 Oct 2019 17:29:11 -0400
Subject: [PATCH] [PYTHON][OPS] Added batch normalization op

---
 include/triton/runtime/function.h  |  2 +
 lib/codegen/selection/generator.cc |  5 ++
 lib/lang/code_gen.cc               | 14 ++++--
 lib/runtime/function.cc            |  7 ++-
 python/setup.py                    |  2 +-
 python/src/bindings.cc             | 32 +++----------
 python/triton/ops/__init__.py      |  1 +
 python/triton/ops/batchnorm.py     | 75 ++++++++++++++++++++++++++++++
 python/triton/ops/einsum.py        |  4 +-
 9 files changed, 108 insertions(+), 34 deletions(-)
 create mode 100644 python/triton/ops/batchnorm.py

diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index 539de8684..a6ab851a9 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -99,6 +99,8 @@ private:
   std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options_t &opt);
   caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
 
+public:
+  static std::string preheader();
 
 public:
   function(const std::string& src, const options_space_t& opt = options_space_t());
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 2efa834a8..34bf52b7f 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -289,6 +289,11 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
 
 
 void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
+  if(!x->get_type()->is_tile_ty()){
+    Value *ptr = get_value(x->get_pointer_operand(), {});
+    set_value(x, {}, builder_->CreateLoad(ptr));
+    return;
+  }
   // find vector size
   ir::value *ptr = x->get_pointer_operand();
   size_t ld = layouts_->get(ptr)->order[0];
diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc
index fdc754048..d13f68856 100644
--- a/lib/lang/code_gen.cc
+++ b/lib/lang/code_gen.cc
@@ -229,6 +229,13 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
     ir::value* ret = ret_;
     if(auto axis = dynamic_cast<ir::constant_int*>(ret))
       return set_ret(bld_->create_get_program_id(axis->get_value()));
+    else
+      return should_not_happen();
+  }
+  if(name == "sqrtf"){
+    VisitExpr(funcCall->Args()->at(0));
+    ir::value* ret = ret_;
+    return set_ret(bld_->create_sqrt(ret));
   }
   return error_not_implemented();
 }
@@ -274,10 +281,11 @@ void Generator::VisitDeclaration(Declaration* decl) {
   // initialize declaration
   ir::type::id_t id = ty->get_type_id();
   if(id == ir::type::StructTyID)
-    assert(false);
+    should_not_happen();
   if(inits.size() > 1)
-    assert(false);
-  val = inits[0];
+    should_not_happen();
+  if(inits.size() > 0)
+    val = inits[0];
   assert(val->get_type() == ty);
   // update scope symbols table
   const std::string &name = obj->Name();
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index a7072b757..d955b69f0 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -113,7 +113,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid,
     arg arg_i = args.at(i);
     arg_type ty = arg_i.type();
     if(ty != param_tys_.at(i))
-      throw std::runtime_error("invalid type");
+      throw std::runtime_error("invalid type for argument " + std::to_string(i));
     if(ty == BUFFER_T)
       bin_->setArg(i, *((driver::buffer**)arg_i.data()));
     else
@@ -253,16 +253,14 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
     return std::unique_ptr<driver::module>();
   barriers.run(module);
 //  std::cout << "isel" << std::endl;
-//  ir::print(module, std::cout);
   isel.visit(module, *llvm);
-//  std::cout << "done" << std::endl;
   // return binary
   std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
   // done
   return res;
 }
 
-std::string preheader() {
+std::string function::preheader() {
 return
 R"(
 #define bool _Bool
@@ -277,6 +275,7 @@ R"(
 #define __multipleof(A) __attribute__((multipleof(A)))
 
 extern int get_program_id(int);
+extern float sqrtf(float);
 )";
 }
diff --git a/python/setup.py b/python/setup.py
index 060a1c450..4c8d38259 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
             pass
 
         cfg = 'Debug' if self.debug else 'Release'
-        cfg = 'Release'
+        #cfg = 'Release'
         build_args = ['--config', cfg]
 
         if platform.system() == "Windows":
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index 80e2f8ddc..37aa0a2c8 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -211,27 +211,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
   os << ";\n";
 }
 
-inline std::string preheader() {
-return
-R"(
-#define bool _Bool
-#define true 1
-#define false 0
-#define __bool_true_false_are_defined 1
-
-#define __readonly __attribute__((readonly))
-#define __writeonly __attribute__((writeonly))
-#define __noalias __attribute__((noalias))
-#define __aligned(A) __attribute__((aligned(A)))
-#define __multipleof(A) __attribute__((multipleof(A)))
-
-extern int get_program_id(int);
-)";
-}
-
 void make_module(const std::string& src, ir::module* ir,
                  const runtime::function::options_space_t& opt) {
-  std::string copy = preheader() + src;
+  std::string copy = triton::runtime::function::preheader() + src;
   // pre-process
   TokenSequence tokens;
   Preprocessor cpp(&copy, true);
@@ -341,11 +323,11 @@ inline std::string to_torch_ty(ir::type *ty) {
   if(ty->is_integer_ty())
     return "int64_t";
   if(ty->is_half_ty())
-    return "float16";
+    return "double";
   if(ty->is_float_ty())
-    return "float32";
+    return "double";
   if(ty->is_double_ty())
-    return "float64";
+    return "double";
   if(ty->is_pointer_ty())
     return "torch::Tensor";
   throw std::runtime_error("unknown type");
@@ -363,11 +345,11 @@ inline std::string to_c_ty(ir::type *ty) {
   if(ty->is_integer_ty(64))
     return "int64_t";
   if(ty->is_half_ty())
-    return "float16";
+    return "half";
   if(ty->is_float_ty())
-    return "float32";
+    return "float";
   if(ty->is_double_ty())
-    return "float64";
+    return "double";
   if(ty->is_pointer_ty())
     return "drv::cu_buffer";
   throw std::runtime_error("unknown type");
diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py
index f409fde46..ac0c3293d 100644
--- a/python/triton/ops/__init__.py
+++ b/python/triton/ops/__init__.py
@@ -1,2 +1,3 @@
 from .dot import _dot, dot
 from .einsum import _einsum, einsum
+from .batchnorm import _batchnorm, batchnorm
\ No newline at end of file
diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py
new file mode 100644
index 000000000..fb6e375e2
--- /dev/null
+++ b/python/triton/ops/batchnorm.py
@@ -0,0 +1,75 @@
+import triton
+import math
+
+class _batchnorm(triton.function):
+
+  fwd_src = """
+void batchnormForward(float *Y, float *M, float *V,
+                      float *X, float *G, float *B,
+                      int N, float rcpN, float eps) {
+  int rx[TM] = 0 ... TM;
+  float *px[TM];
+  float x[TM] = 0;
+  int c = get_program_id(1);
+  float g = *(G + c);
+  float b = *(B + c);
+
+  float mean[TM] = 0;
+  px = X + rx + c*N;
+  for(int i = 0; i < N; i = i + TM){
+    x = *px;
+    mean = mean + x;
+    px = px + TM;
+  }
+  float *pm = M + c;
+  float m = mean[+] * rcpN;
+  *pm = m;
+
+  float var[TM] = 0;
+  px = X + rx + c*N;
+  for(int i = 0; i < N; i = i + TM){
+    x = *px;
+    x = x - m;
+    var = var + x*x;
+    px = px + TM;
+  }
+  float v = var[+] * rcpN;
+  float *pv = V + c;
+  *pv = v;
+  float rstdg = 1 / sqrtf(v + eps) * g;
+
+  px = X + rx + c*N;
+  float* py[TM] = Y + rx + c*N;
+  for(int i = 0; i < N; i = i + TM){
+    x = *px;
+    float y[TM] = (x - m)*rstdg + b;
+    *py = y;
+    px = px + TM;
+    py = py + TM;
+  }
+}
+"""
+
+  fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
+
+  @staticmethod
+  def forward(ctx, x, gamma, beta, eps):
+    shape = triton.shape(x)
+    dtype = x.dtype
+    # allocate outputs
+    C, H, W, B = shape[0], shape[1], shape[2], shape[3]
+    y = triton.empty(shape, dtype=dtype)
+    mean = triton.empty([C], dtype=dtype)
+    var = triton.empty([C], dtype=dtype)
+    # execute kernels
+    N = H*W*B
+    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
+                          lambda opt: [1, C],
+                          TM = 128)
+    # save
+    ctx.eps = eps
+    ctx.save_for_backward(x, gamma, beta, mean, var)
+    return y, mean, var
+
+
+batchnorm = _batchnorm.apply
\ No newline at end of file
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index 1467b1173..167b2aacd 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -181,15 +181,16 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
   @staticmethod
   def forward(ctx, subscripts, a, b, bench = 0):
     ctx.save_for_backward(a, b)
+    # parse
     if type(subscripts) is str:
       einsum_a, einsum_bc = subscripts.split(",")
       einsum_b, einsum_c = einsum_bc.split("->")
     else:
       einsum_a, einsum_b, einsum_c = subscripts
-
     shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
                                         einsum_a, einsum_b, einsum_c,
                                         triton.shape(a), triton.shape(b))
+    # save for backward
     ctx.trans_a = ta
     ctx.trans_b = tb
     ctx.einsum_a = einsum_a
@@ -197,6 +198,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
     ctx.einsum_c = einsum_c
     ctx.bench = bench
     ctx.bmnk = bmnk
+    # run
     return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1,
                         einsum_a, einsum_b, einsum_c, bench)
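
Reviewer note (not part of the patch): below is a minimal NumPy sketch of the math that batchnormForward computes, included only as a reference for checking the kernel. It assumes the same C x H x W x B layout that _batchnorm.forward uses, where each channel c reduces over its N = H*W*B elements; batchnorm_forward_ref and every name in it are hypothetical, not part of the Triton API.

import numpy as np

def batchnorm_forward_ref(x, gamma, beta, eps=1e-5):
    # one row per channel, N = H*W*B columns (C x H x W x B layout assumed)
    C = x.shape[0]
    xc = x.reshape(C, -1)
    # per-channel mean: the kernel's mean[+] * rcpN reduction
    mean = xc.mean(axis=1)
    # biased (1/N) variance, matching var[+] * rcpN
    var = ((xc - mean[:, None]) ** 2).mean(axis=1)
    # rstdg = 1 / sqrtf(v + eps) * g, folded into one scale per channel
    rstdg = gamma / np.sqrt(var + eps)
    # y = (x - m) * rstdg + b, broadcast across each channel's row
    y = (xc - mean[:, None]) * rstdg[:, None] + beta[:, None]
    return y.reshape(x.shape), mean, var

# example: 32 channels, 14x14 spatial, batch 64
x = np.random.randn(32, 14, 14, 64).astype(np.float32)
y, mean, var = batchnorm_forward_ref(x, np.ones(32, np.float32),
                                     np.zeros(32, np.float32))

Like the kernel, this uses the biased 1/N estimator for the variance and returns the per-channel mean and variance alongside y, matching the Y/M/V outputs saved for the backward pass.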