From 2b9355c9e4e193bf8dbebcd8e4f1be84367096a8 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 30 Oct 2019 01:38:30 -0400 Subject: [PATCH] [PYTHON][TENSORFLOW] Got rid of alloc_empty entirely; now doing generating allocation code inside the tensorflow op --- python/src/bindings.cc | 99 +++++++++++++++++++++------- python/src/tensorflow/alloc_empty.cc | 1 + python/triton/function.py | 7 +- python/triton/kernel.py | 12 +++- python/triton/ops/batchnorm.py | 74 ++++++++++++++++++++- python/triton/utils.py | 15 ++--- 6 files changed, 164 insertions(+), 44 deletions(-) diff --git a/python/src/bindings.cc b/python/src/bindings.cc index 37aa0a2c8..b3b74b37b 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -89,9 +89,9 @@ inline std::string to_tf_ty(ir::type *ty) { if(ty->is_half_ty()) return "float16"; if(ty->is_float_ty()) - return "float32"; + return "float"; if(ty->is_double_ty()) - return "float64"; + return "double"; if(ty->is_pointer_ty()) return "Tensor"; throw std::runtime_error("unknown type"); @@ -113,21 +113,50 @@ inline std::string ref_to_tf_ty(ir::type *ty) { } -void gen_extract_inputs(std::ostream &os, const std::vector& args) { +void gen_extract_inputs(std::ostream &os, const std::vector& args, const std::vector& outputs) { for(unsigned i = 0; i < args.size(); i++){ ir::value *arg = args[i]; - std::string suffix = ""; - ir::type *tr_ty = arg->get_type(); - std::string tf_ty = ref_to_tf_ty(tr_ty); - if(!tr_ty->is_pointer_ty()) - suffix = ".scalar<" + tf_ty + ">()()"; - os << " " << tf_ty << " " << arg->get_name() << " = context->input(" << i << ")" << suffix << ";\n "; + const std::string& name = arg->get_name(); + std::string ty = to_tf_ty(arg->get_type()); + if(!arg->get_type()->is_pointer_ty()) + os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n "; + else if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end()) + os << " const Tensor* " << name << " = &context->input(" << i << ");\n "; + else + os << " Tensor* " << name << " = nullptr;\n "; } } -void gen_set_outputs(std::ostream &os, const std::vector& outputs) { +void gen_set_outputs(std::ostream &os, const std::vector& args, const std::vector& outputs) { for(unsigned i = 0; i < outputs.size(); i++) - os << " context->set_output(" << i << ", " << outputs[i] << ");\n "; + os << " TensorShape shape" << i << ";\n "; + // initialize shapes + + std::vector out_idx; + for(size_t i = 0; i < outputs.size(); i++){ + std::string name = outputs[i]; + size_t idx; + for(idx = 0; idx < args.size(); idx++) + if(args[idx]->get_name() == name) + break; + if(idx == args.size()) + throw std::runtime_error("unknown output"); + out_idx.push_back(idx); + } + + for(unsigned i = 0; i < outputs.size(); i++) + os << " const Tensor& " << outputs[i] << "_shape = context->input(" << out_idx[i] << ");\n "; + for(unsigned i = 0; i < outputs.size(); i++) + os << " const int32* " << outputs[i] << "_shape_data = (const int32*)" << outputs[i] << "_shape.tensor_data().data();\n "; + for(unsigned i = 0; i < outputs.size(); i++) + os << " size_t " << outputs[i] << "_rank = " << outputs[i] << "_shape.dim_size(0);\n "; + for(unsigned i = 0; i < outputs.size(); i++) + os << " for(size_t d = 0; d < " << outputs[i] << "_rank ; d++) " + << "shape" << i << ".AddDim(" << outputs[i] << "_shape_data[d]);\n "; + + // allocate + for(unsigned i = 0; i < outputs.size(); i++) + os << " OP_REQUIRES_OK(context, context->allocate_output(" << i << ", shape" << i << ", &" << outputs[i] << "));\n "; } void gen_make_handles(std::ostream &os, const std::vector& args) { @@ -136,7 +165,7 @@ void gen_make_handles(std::ostream &os, const std::vector& args) if(!arg->get_type()->is_pointer_ty()) continue; const std::string& name = arg->get_name(); - os << " drv::cu_buffer cu_" + name + "(ctx, " + name + ".tensor_data().size(), (CUdeviceptr)" + name + ".tensor_data().data(), false);\n "; + os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->tensor_data().size(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n "; } } @@ -161,7 +190,8 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name, const std::string &opname, - const std::vector& args){ + const std::vector& args, + const std::vector& outputs){ os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)"; for(size_t i = 0; i < args.size(); i++){ ir::argument *arg = args[i]; @@ -171,20 +201,31 @@ void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name, if(!arg->get_type()->is_pointer_ty()) os << ".HostMemory(\"" + name + "\")"; } + for(size_t i = 0; i < outputs.size(); i++){ + std::string name = outputs[i]; + name[0] = std::tolower(name[0]); + os << ".HostMemory(\"" << name << "_shape\")"; + } os << ", " + opname << ");\n"; } void gen_tf_register_op(std::ostream &os, const std::string &name, const std::vector& args, const std::vector& outputs){ + + auto tolower = [](char c) { return std::tolower(c);}; + os << "REGISTER_OP(\"" << name << "\")\n"; + for(size_t i = 0; i < args.size(); i++) + os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl; for(size_t i = 0; i < args.size(); i++){ ir::argument *arg = args[i]; std::string name = arg->get_name(); - auto tolower = [](char c) { return std::tolower(c);}; std::transform(name.begin(), name.end(), name.begin(), tolower); - os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl; - os << " .Input(\"" << name << ": T" << i << "\")\n"; + if(std::find(outputs.begin(), outputs.end(), arg->get_name()) == outputs.end()) + os << " .Input(\"" << name << ": T" << i << "\")\n"; + else + os << " .Input(\"" << name << "_shape: int32\")\n"; } std::vector out_idx; for(size_t i = 0; i < outputs.size(); i++){ @@ -197,15 +238,22 @@ void gen_tf_register_op(std::ostream &os, const std::string &name, throw std::runtime_error("unknown output"); out_idx.push_back(idx); } - for(size_t i = 0; i < out_idx.size(); i++) - os << " .Output(\"out" << i << ": T" << out_idx[i] << "\")\n"; + for(size_t i = 0; i < out_idx.size(); i++){ + std::string name = outputs[i]; + std::transform(name.begin(), name.end(), name.begin(), tolower); + os << " .Output(\"" << name << ": T" << out_idx[i] << "\")\n"; + } os << " .Attr(\"id: int\")\n"; os << " .Attr(\"bench: int\")\n"; os << " .Attr(\"bench_id: int\")\n"; - os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n"; + os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n"; for(size_t i = 0; i < out_idx.size(); i++) - os << " c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n"; - os << " return Status::OK();\n"; + os << " shape_inference::ShapeHandle handle" << i << ";\n"; + for(size_t i = 0; i < out_idx.size(); i++) + os << " ctx->MakeShapeFromShapeTensor(" << out_idx[i] << ", &handle" << i << ");\n"; + for(size_t i = 0; i < out_idx.size(); i++) + os << " ctx->set_output(" << i << ", handle" << i << ");\n"; + os << " return Status::OK();\n"; os << " })\n"; os << ";\n"; @@ -237,6 +285,7 @@ std::tuple(new ir::module("", ctx)); make_module(src, &*ir, opt); + // function ir::function* fn = ir->get_function_list().front(); std::string name = fn->get_name(); @@ -276,18 +325,20 @@ class )" << opname << R"(: public OpKernel { } void Compute(OpKernelContext* context){ + // get device/stream GPUDevice device = context->eigen_device(); drv::cu_stream sstream(device.stream(), false); drv::context* ctx = sstream.context(); drv::stream* stream = &sstream; + // extract inputs )"; -gen_extract_inputs(oss, fn->args()); +gen_extract_inputs(oss, fn->args(), outputs); oss << R"( // set outputs )"; -gen_set_outputs(oss, outputs); +gen_set_outputs(oss, fn->args(), outputs); oss << R"( // wrap tensors )"; @@ -309,7 +360,7 @@ private: // register kernel builder )"; -gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args()); +gen_tf_register_kernel_builder(oss, cc_name, opname, fn->args(), outputs); oss << R"( // register op )"; diff --git a/python/src/tensorflow/alloc_empty.cc b/python/src/tensorflow/alloc_empty.cc index 43f82cbfa..a9c97b1d5 100644 --- a/python/src/tensorflow/alloc_empty.cc +++ b/python/src/tensorflow/alloc_empty.cc @@ -8,6 +8,7 @@ class AllocEmptyOp : public OpKernel { explicit AllocEmptyOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { + std::cout << "executing allocempty" << std::endl; // fetch input const Tensor& x = context->input(0); const int32* x_data = (const int32*)x.tensor_data().data(); diff --git a/python/triton/function.py b/python/triton/function.py index f40605ea9..e75512b1b 100644 --- a/python/triton/function.py +++ b/python/triton/function.py @@ -44,12 +44,7 @@ class function(metaclass = function_meta): @classmethod def apply_tensorflow(cls, *args, **kwargs): ctx = OpContext() - # Acquire a mutex here to ensure that calls to alloc_empty() - # are handled properly - mutex = fw.gen_resource_variable_ops.mutex_v2() - lock = fw.gen_resource_variable_ops.mutex_lock(mutex) - with fw.tensorflow.control_dependencies([lock]): - result = cls.forward(ctx, *args, **kwargs) + result = cls.forward(ctx, *args, **kwargs) # Find a mapping between ::forward arguments and tensorflow op arguments remap = dict() for i, ix in enumerate(result.op.inputs): diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 769e47a29..60964abdc 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -222,11 +222,17 @@ class kernel: libtriton.register_grid(op_id, _make_grid(args)) # id for the benchmark result bench_id = libtriton.make_scalar_id() if bench > 0 else -1 - # create operands # call framework function if fw.has_tensorflow(): - args = [x for x in args[:-1]] - ret = self.fw_op(*args, id=op_id, bench=bench, bench_id=bench_id) + # operands + operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]] + # output data types + kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id} + for i, x in enumerate(args[:-1]): + if isinstance(x, triton.utils.tf_empty_proxy): + kwargs['T' + str(i)] = x.dtype + # launch + ret = self.fw_op(*operands, **kwargs) if bench > 0: bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id) elif fw.has_torch(): diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py index fb6e375e2..5e352d93a 100644 --- a/python/triton/ops/batchnorm.py +++ b/python/triton/ops/batchnorm.py @@ -4,7 +4,7 @@ import math class _batchnorm(triton.function): fwd_src = """ -void batchnormForward(float *Y, float *M, float *V, +void fwdbatchnorm(float *Y, float *M, float *V, float *X, float *G, float *B, int N, float rcpN, float eps) { int rx[TM] = 0 ... TM; @@ -52,6 +52,58 @@ void batchnormForward(float *Y, float *M, float *V, fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V']) + bwd_src = """ +void batchnormBackward(float *DX, float *DG, float *DB, + float *DY, float *X, float *G, + float *M, float *V, + int DHWN, float rcpDHWN, float epsilon) { + int rx[TM] = 0 ... TM; + int c = get_program_id(1); + int offset = c*DHWN; + float g = *(G + c); + float mean = *(M + c); + float var = *(V + c); + float rstd = 1 / sqrtf(var + epsilon); + float* px[TM]; + float* pdx[TM]; + float* pdy[TM]; + px = X + rx + offset; + pdy = DY + rx + offset; + float dg[TM] = 0; + float db[TM] = 0; + for(int i = 0; i < DHWN; i = i + TM){ + float x[TM] = *px; + float dy[TM] = *pdy; + dg = dg + dy*(x - mean)*rstd; + db = db + dy; + px = px + TM; + pdy = pdy + TM; + } + float sdg = dg[+]; + float sdb = db[+]; + float *pdg = DG + c; + float *pdb = DB + c; + *pdg = sdg; + *pdb = sdb; + px = X + rx + offset; + pdy = DY + rx + offset; + pdx = DX + rx + offset; + for(int i = 0; i < DHWN; i = i + TM){ + float x[TM] = *px; + float dy[TM] = *pdy; + float xhat[TM] = (x - mean) * rstd; + float xtmp[TM] = (xhat * dg + db) * rcpDHWN; + float dx[TM] = (dy - xtmp) * rstd * g; + *pdx = dx; + px = px + TM; + pdy = pdy + TM; + pdx = pdx + TM; + } +} +""" + + bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB']) + @staticmethod def forward(ctx, x, gamma, beta, eps): shape = triton.shape(x) @@ -63,13 +115,29 @@ void batchnormForward(float *Y, float *M, float *V, var = triton.empty([C], dtype=dtype) # execute kernels N = H*W*B - _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps, + y, mean, var = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps, lambda opt: [1, C], TM = 128) # save ctx.eps = eps ctx.save_for_backward(x, gamma, beta, mean, var) - return y, mean, var + return y + @staticmethod + def backward(ctx, dy): + eps = ctx.eps + x, gamma, beta, mean, var = ctx.saved_tensors + dx = triton.empty(x.shape, dtype=x.dtype) + dgamma = triton.empty(gamma.shape, dtype=gamma.dtype) + dbeta = triton.empty(beta.shape, dtype=beta.dtype) + # launch + C, H, W, B = x.shape + N = H*W*B + _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, + x, gamma, mean, var, + N, 1./N, eps, + lambda opt: [1, C], + TM = 128) + return dx, dgamma, dbeta, None batchnorm = _batchnorm.apply \ No newline at end of file diff --git a/python/triton/utils.py b/python/triton/utils.py index eca9f665e..e55afd602 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -7,17 +7,16 @@ def cdiv(a, b): class tf_empty_proxy: - def __init__(self, args, dtype): - self.args = args + def __init__(self, shape, dtype): + self.shape = shape self.dtype = dtype -def empty(shapes, dtype): +def empty(shape, dtype): if fw.has_tensorflow(): - #return fw.tensorflow.Variable(np.empty(shapes),shape=shapes, dtype=dtype) - args = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shapes] - args = fw.tensorflow.stack(args) - #return tf_empty_proxy(args, dtype) - return fw.tf_extra_ops.alloc_empty(args, T = dtype) + shape = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shape] + shape = fw.tensorflow.stack(shape) + return tf_empty_proxy(shape, dtype) + #return fw.tf_extra_ops.alloc_empty(args, T = dtype) elif fw.has_torch(): return fw.torch.empty(*shapes).cuda()