diff --git a/python/examples/batchnorm.py b/python/examples/batchnorm.py
index 673c8ec2c..e28fd5039 100644
--- a/python/examples/batchnorm.py
+++ b/python/examples/batchnorm.py
@@ -45,11 +45,11 @@ if mode == MODE.TF:
   fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
   fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
   # execute
-  fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
-  #fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta])
+  fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
+  fw_y = triton.ops.batchnorm(fw_x, fw_mean, fw_var, fw_gamma, fw_beta, 1e-4)
+  fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
+  # run
   sess = tf.InteractiveSession()
   feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
   sess.run(tf.global_variables_initializer())
-  #print(fw_dx, fw_dgamma, fw_dbeta)
-  result = sess.run([fw_y], feed_dict=feed_dict)
-  print(result)
+  result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
\ No newline at end of file
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index b3b74b37b..15386ecc1 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -192,18 +192,19 @@ void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
                                     const std::string &opname,
                                     const std::vector<ir::argument*>& args,
                                     const std::vector<std::string>& outputs){
+
+  auto tolower = [](char c) { return std::tolower(c);};
   os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
   for(size_t i = 0; i < args.size(); i++){
     ir::argument *arg = args[i];
     std::string name = arg->get_name();
-    auto tolower = [](char c) { return std::tolower(c);};
     std::transform(name.begin(), name.end(), name.begin(), tolower);
     if(!arg->get_type()->is_pointer_ty())
       os << ".HostMemory(\"" + name + "\")";
   }
   for(size_t i = 0; i < outputs.size(); i++){
     std::string name = outputs[i];
-    name[0] = std::tolower(name[0]);
+    std::transform(name.begin(), name.end(), name.begin(), tolower);
     os << ".HostMemory(\"" << name << "_shape\")";
   }
   os << ", " + opname << ");\n";
diff --git a/python/triton/function.py b/python/triton/function.py
index e75512b1b..7eba7f9a7 100644
--- a/python/triton/function.py
+++ b/python/triton/function.py
@@ -45,6 +45,7 @@ class function(metaclass = function_meta):
   def apply_tensorflow(cls, *args, **kwargs):
     ctx = OpContext()
     result = cls.forward(ctx, *args, **kwargs)
+    op = result[0].op if isinstance(result, tuple) else result.op
     # Find a mapping between ::forward arguments and tensorflow op arguments
     remap = dict()
-    for i, ix in enumerate(result.op.inputs):
+    for i, ix in enumerate(op.inputs):
@@ -52,13 +53,12 @@
         if ix is jx:
           remap[j] = i
     # register backward
-    ctx_registry[result] = ctx
-    name = result.op.op_def.name
+    ctx_registry[op] = ctx
+    name = op.op_def.name
     if not cls.registered:
       @fw.tensorflow.RegisterGradient(name)
-      def gradient(op, dy):
-        y = op.outputs[0]
-        grad = cls.backward(ctx_registry[y], dy)
+      def gradient(op, *dys):
+        grad = cls.backward(ctx_registry[op], dys if len(dys) > 1 else dys[0])
         # Remap gradient in the right order
         ret = [None] * len(op.inputs)
         for i in range(len(grad)):
diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py
index 5e352d93a..fb6d94017 100644
--- a/python/triton/ops/batchnorm.py
+++ b/python/triton/ops/batchnorm.py
@@ -6,138 +6,122 @@ class _batchnorm(triton.function):
   fwd_src = """
 void fwdbatchnorm(float *Y, float *M, float *V,
                   float *X, float *G, float *B,
-                  int N, float rcpN, float eps) {
-  int rx[TM] = 0 ... TM;
-  float *px[TM];
-  float x[TM] = 0;
+                  int N, float eps) {
+  // pointers
   int c = get_program_id(1);
-  float g = *(G + c);
-  float b = *(B + c);
+  int rm[TM] = 0 ... TM;
+  float *px[TM] = X + rm + c*N;
+  float* py[TM] = Y + rm + c*N;
 
-  float mean[TM] = 0;
-  px = X + rx + c*N;
+  // compute mean
+  float accm[TM] = 0;
+  for(int i = 0; i < N; i = i + TM)
+    accm = accm + *(px + i);
+  float mean = (float)accm[+] / N;
+  *(M + c) = mean;
+
+  // compute variance
+  float accv[TM] = 0;
   for(int i = 0; i < N; i = i + TM){
-    x = *px;
-    mean = mean + x;
-    px = px + TM;
+    float x[TM] = *(px + i);
+    x = x - mean;
+    accv = accv + x*x;
   }
-  float *pm = M + c;
-  float m = mean[+] * rcpN;
-  *pm = m;
+  float var = (float)accv[+] / N;
+  *(V + c) = var;
 
-  float var[TM] = 0;
-  px = X + rx + c*N;
+  // Normalize batch
+  float gamma = *(G + c);
+  float beta = *(B + c);
+  float rstdg = 1 / sqrtf(var + eps) * gamma;
   for(int i = 0; i < N; i = i + TM){
-    x = *px;
-    x = x - m;
-    var = var + x*x;
-    px = px + TM;
-  }
-  float v = var[+] * rcpN;
-  float *pv = V + c;
-  *pv = v;
-  float rstdg = 1 / sqrtf(v + eps) * g;
-
-  px = X + rx + c*N;
-  float* py[TM] = Y + rx + c*N;
-  for(int i = 0; i < N; i = i + TM){
-    x = *px;
-    float y[TM] = (x - m)*rstdg + b;
-    *py = y;
-    px = px + TM;
-    py = py + TM;
+    float x[TM] = *(px + i);
+    float y[TM] = (x - mean)*rstdg + beta;
+    *(py + i) = y;
   }
 }
 """
-
-  fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
+  fwd_kernel = triton.kernel(fwd_src, ['Y'])
 
   bwd_src = """
-void batchnormBackward(float *DX, float *DG, float *DB,
-                       float *DY, float *X, float *G,
-                       float *M, float *V,
-                       int DHWN, float rcpDHWN, float epsilon) {
-  int rx[TM] = 0 ... TM;
+void bwdbatchnorm(float *DX, float *DG, float *DB,
+                  float *DY, float *X, float *G,
+                  float *M, float *V,
+                  int N, float epsilon) {
+
+  // pointers
   int c = get_program_id(1);
-  int offset = c*DHWN;
-  float g = *(G + c);
+  int rx[TM] = 0 ... TM;
+  int offset = c*N;
+  float* px[TM] = X + rx + offset;
+  float* pdy[TM] = DY + rx + offset;
+  float* pdx[TM] = DX + rx + offset;
+
+  // fetch statistics
+  float gamma = *(G + c);
   float mean = *(M + c);
   float var = *(V + c);
   float rstd = 1 / sqrtf(var + epsilon);
-  float* px[TM];
-  float* pdx[TM];
-  float* pdy[TM];
-  px = X + rx + offset;
-  pdy = DY + rx + offset;
-  float dg[TM] = 0;
-  float db[TM] = 0;
-  for(int i = 0; i < DHWN; i = i + TM){
-    float x[TM] = *px;
-    float dy[TM] = *pdy;
-    dg = dg + dy*(x - mean)*rstd;
-    db = db + dy;
-    px = px + TM;
-    pdy = pdy + TM;
+
+  // compute dgamma and dbeta
+  float acc_dg[TM] = 0;
+  float acc_db[TM] = 0;
+  for(int i = 0; i < N; i = i + TM){
+    float x[TM] = *(px + i);
+    float dy[TM] = *(pdy + i);
+    acc_dg += dy*(x - mean)*rstd;
+    acc_db += dy;
   }
-  float sdg = dg[+];
-  float sdb = db[+];
-  float *pdg = DG + c;
-  float *pdb = DB + c;
-  *pdg = sdg;
-  *pdb = sdb;
-  px = X + rx + offset;
-  pdy = DY + rx + offset;
-  pdx = DX + rx + offset;
-  for(int i = 0; i < DHWN; i = i + TM){
-    float x[TM] = *px;
-    float dy[TM] = *pdy;
+  float dg = acc_dg[+];
+  float db = acc_db[+];
+  *(DG + c) = dg;
+  *(DB + c) = db;
+
+  // compute dx
+  for(int i = 0; i < N; i = i + TM){
+    float x[TM] = *(px + i);
+    float dy[TM] = *(pdy + i);
     float xhat[TM] = (x - mean) * rstd;
-    float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
-    float dx[TM] = (dy - xtmp) * rstd * g;
-    *pdx = dx;
-    px = px + TM;
-    pdy = pdy + TM;
-    pdx = pdx + TM;
+    float xtmp[TM] = (xhat * dg + db) / N;
+    float dx[TM] = (dy - xtmp) * rstd * gamma;
+    *(pdx + i) = dx;
   }
 }
 """
-
   bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
 
   @staticmethod
-  def forward(ctx, x, gamma, beta, eps):
+  def forward(ctx, x, mean, var, gamma, beta, eps):
     shape = triton.shape(x)
     dtype = x.dtype
     # allocate outputs
     C, H, W, B = shape[0], shape[1], shape[2], shape[3]
     y = triton.empty(shape, dtype=dtype)
-    mean = triton.empty([C], dtype=dtype)
-    var = triton.empty([C], dtype=dtype)
     # execute kernels
-    N = H*W*B
-    y, mean, var = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
-                                         lambda opt: [1, C],
-                                         TM = 128)
+    y = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
+                              lambda opt: [1, C],
+                              TM = 128)
     # save
-    ctx.eps = eps
     ctx.save_for_backward(x, gamma, beta, mean, var)
+    ctx.eps = eps
     return y
 
   @staticmethod
   def backward(ctx, dy):
-    eps = ctx.eps
+    # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
-    dx = triton.empty(x.shape, dtype=x.dtype)
-    dgamma = triton.empty(gamma.shape, dtype=gamma.dtype)
-    dbeta = triton.empty(beta.shape, dtype=beta.dtype)
-    # launch
-    C, H, W, B = x.shape
-    N = H*W*B
-    _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
+    eps = ctx.eps
+    # allocate result
+    dx = triton.empty(triton.shape(x), dtype=x.dtype)
+    dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
+    dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
+    # execute
+    C, H, W, B = triton.shape(x)
+    dx, dgamma, dbeta = _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
                                               x, gamma, mean, var,
-                          N, 1./N, eps,
+                                              H*W*B, eps,
                                               lambda opt: [1, C],
                                               TM = 128)
-    return dx, dgamma, dbeta, None
+    return dx, None, None, dgamma, dbeta, None
 
 batchnorm = _batchnorm.apply
\ No newline at end of file
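
Note on the python/triton/function.py change: TensorFlow calls a registered gradient function with one upstream gradient per op output, so the previous def gradient(op, dy) signature only worked for single-output ops; *dys lifts that restriction, and keying ctx_registry by the op (rather than by one output tensor) keeps the lookup valid for every output. A minimal standalone sketch of the pattern (TensorFlow 1.x; the op name and the trivial gradient body are hypothetical, for illustration only):

import tensorflow as tf

@tf.RegisterGradient("MyMultiOutputOp")
def _my_multi_output_grad(op, *dys):
    # TensorFlow passes len(op.outputs) upstream gradients; collapse them to
    # a single tensor when the op has one output, as apply_tensorflow does.
    dy = dys if len(dys) > 1 else dys[0]
    # A real implementation would forward dy to the user's backward() and
    # remap the results to op.inputs order; None marks an input with no gradient.
    return [None] * len(op.inputs)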
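
For reference, the math implemented by the fwdbatchnorm/bwdbatchnorm kernels above as a minimal NumPy sketch; it assumes the (C, H, W, B) layout used by _batchnorm.forward and is for validation only (the helper names are not part of the patch):

import numpy as np

def batchnorm_ref(x, gamma, beta, eps):
    # Per-channel biased mean/variance over the H*W*B elements, as in the kernel.
    C = x.shape[0]
    xf = x.reshape(C, -1)
    mean = xf.mean(axis=1)
    var = xf.var(axis=1)
    rstd = 1.0 / np.sqrt(var + eps)
    y = (xf - mean[:, None]) * (rstd * gamma)[:, None] + beta[:, None]
    return y.reshape(x.shape), mean, var

def batchnorm_grad_ref(dy, x, gamma, mean, var, eps):
    # Mirrors bwdbatchnorm: dx = (dy - (xhat*dg + db)/N) * rstd * gamma.
    C = x.shape[0]
    xf, dyf = x.reshape(C, -1), dy.reshape(C, -1)
    N = xf.shape[1]
    rstd = 1.0 / np.sqrt(var + eps)
    xhat = (xf - mean[:, None]) * rstd[:, None]
    dgamma = (dyf * xhat).sum(axis=1)
    dbeta = dyf.sum(axis=1)
    dx = (dyf - (xhat * dgamma[:, None] + dbeta[:, None]) / N) * (rstd * gamma)[:, None]
    return dx.reshape(x.shape), dgamma, dbeta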
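
With x, gamma, beta and dy defined as in the example script, the kernel outputs can then be checked against this reference, e.g.:

y_ref, mean_ref, var_ref = batchnorm_ref(x, gamma, beta, 1e-4)
dx_ref, dgamma_ref, dbeta_ref = batchnorm_grad_ref(dy, x, gamma, mean_ref, var_ref, 1e-4)
np.testing.assert_allclose(result[0], dx_ref, rtol=1e-4, atol=1e-4)  # dx from sess.run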