From cc4160478443e2a40045fae27546d610fa584384 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 9 Jul 2019 13:03:16 -0700
Subject: [PATCH] [codegen/batchnorm] forward and backward now seemingly working

---
 examples/python/tensorflow/run.py | 14 +++++++++++---
 include/triton/dnn/batchnorm.h    |  3 ++-
 lib/dnn/batchnorm.cpp             | 31 +++++++++++++++++++------------
 3 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index df37c830c..a0f107ea4 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -100,7 +100,7 @@ def batch_norm_grad(op, dy, mean, var):
 
 
 def run_batchnorm():
-  C, H, W, B = 32, 16, 16, 16
+  C, H, W, B = 1, 4, 4, 4
   np.random.seed(0)
   # Placeholders
   x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -112,11 +112,19 @@ def run_batchnorm():
   hb = np.random.rand(C)
   # batchnorm
   y, m, v = module.batchnorm_forward(x, g, b, eps=1e-5)
+  loss = np.sum(y)
   # Run
   sess = tf.InteractiveSession()
   sess.run(tf.global_variables_initializer())
   result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
-
-
+  #print(result[0], result[1], result[2])
+  grads = tf.test.compute_gradient([x, g, b], [(C, H, W, B), (C, ), (C, )], y, (C, H, W, B),
+                                   extra_feed_dict = {x: hx, g: hg, b: hb})
+  dx_t, dx_n = grads[0]
+  dg_t, dg_n = grads[1]
+  db_t, db_n = grads[2]
+  print(np.max(np.abs(dx_t - dx_n)))
+  print(np.max(np.abs(dg_t - dg_n)))
+  print(np.max(np.abs(db_t - db_n)))
 
 run_batchnorm()
diff --git a/include/triton/dnn/batchnorm.h b/include/triton/dnn/batchnorm.h
index 7a97e83af..65f71ce58 100644
--- a/include/triton/dnn/batchnorm.h
+++ b/include/triton/dnn/batchnorm.h
@@ -61,7 +61,7 @@ private:
 class batchnorm_backward {
 public:
   // constructor
-  batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32");
+  batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32", float eps = 1e-5);
   // enqueue
   void enqueue(driver::stream *stream, driver::kernel *kernel,
                driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
@@ -78,6 +78,7 @@ private:
   int32_t W_;
   int32_t B_;
   std::string ty_;
+  float eps_;
 };
 
 }
diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp
index e3b1a630c..a8e91bf8e 100644
--- a/lib/dnn/batchnorm.cpp
+++ b/lib/dnn/batchnorm.cpp
@@ -79,13 +79,14 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
     px = px + TM;
   }
   fp32 *pm = M + c;
-  *pm = __sum(mean) * rcpDHWN;
+  fp32 m = __sum(mean) * rcpDHWN;
+  *pm = m;
   fp32 var[TM] = 0;
   px = X + rx + c*DHWN;
   for(int32 i = 0; i < DHWN; i = i + TM){
     x = *px;
-    x = x - mean;
+    x = x - m;
     var = var + x*x;
     px = px + TM;
   }
@@ -99,7 +100,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
   fp32* py[TM] = Y + rx + c*DHWN;
   for(int32 i = 0; i < DHWN; i = i + TM){
     x = *px;
-    fp32 y[TM] = (x - mean)*rstdg + b;
+    fp32 y[TM] = (x - m)*rstdg + b;
     *py = y;
     px = px + TM;
     py = py + TM;
@@ -111,8 +112,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  * Backward
  * --------------- */
 
-batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty)
-  : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty)
+batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
+  : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps)
 { }
 
 void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
@@ -120,7 +121,7 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
                                  driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
                                  driver::buffer *x, driver::buffer *g, driver::buffer *m,
                                  driver::buffer *v, size_t, size_t nthreads) {
-  std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
+  std::array<size_t, 3> grid = {1, (size_t)C_, 1};
   kernel->setArg(0, dx);
   kernel->setArg(1, dg);
   kernel->setArg(2, db);
@@ -130,6 +131,8 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(6, m);
   kernel->setArg(7, v);
   kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
+  kernel->setArg(9, (float)1/(D_*H_*W_*B_));
+  kernel->setArg(10, eps_);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
 }
 
@@ -144,14 +147,14 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
                restrict read_only fp32 *G,
                restrict read_only fp32 *M,
                restrict read_only fp32 *V,
-               int32 DHWN) {
-  int32 rx[TM] = get_global_range[TM](0);
+               int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
+  int32 rx[TM] = 0 ... TM;
   int32 c = get_range_id(0);
   int32 offset = c*DHWN;
   fp32 g = *(G + c);
   fp32 mean = *(M + c);
   fp32 var = *(V + c);
-  fp32 rstd = var;
+  fp32 rstd = 1 / sqrt(var + epsilon);
   fp32* px[TM];
   fp32* pdx[TM];
   fp32* pdy[TM];
@@ -160,7 +163,7 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
   pdy = DY + rx + offset;
   fp32 dg[TM] = 0;
   fp32 db[TM] = 0;
-  for(int32 i = 0; i < DHWN; i += TM){
+  for(int32 i = 0; i < DHWN; i = i + TM){
     fp32 x[TM] = *px;
     fp32 dy[TM] = *pdy;
     dg = dg + dy*(x - mean)*rstd;
@@ -170,15 +173,19 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
   }
   fp32 sdg = __sum(dg);
   fp32 sdb = __sum(db);
+  fp32 *pdg = DG + c;
+  fp32 *pdb = DB + c;
+  *pdg = sdg;
+  *pdb = sdb;
   px = X + rx + offset;
   pdy = DY + rx + offset;
   pdx = DX + rx + offset;
-  for(int32 i = 0; i < DHWN; i += TM){
+  for(int32 i = 0; i < DHWN; i = i + TM){
     fp32 x[TM] = *px;
     fp32 dy[TM] = *pdy;
     fp32 xhat[TM] = (x - mean) * rstd;
-    fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
+    fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN;
     fp32 dx[TM] = (dy - xtmp) * rstd * g;
     *pdx = dx;
     px = px + TM;
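
Note (not part of the patch): as a sketch of the per-channel math that the Triton-C forward and backward kernels above implement, and of what the tf.test.compute_gradient check in run.py is verifying, a plain-NumPy reference could look like the following. The function batchnorm_reference and its layout are hypothetical; it reduces over the D*H*W*B elements of each channel with biased variance, matching the kernels' use of rcpDHWN and rstd = 1/sqrt(var + epsilon).

import numpy as np

def batchnorm_reference(x, g, b, dy, eps=1e-5):
  # x, dy: (C, H, W, B); g, b: (C,) per-channel scale and shift
  C = x.shape[0]
  N = x[0].size                              # D*H*W*B elements per channel
  xf = x.reshape(C, N)
  dyf = dy.reshape(C, N)
  m = xf.mean(axis=1)                        # per-channel mean
  v = ((xf - m[:, None])**2).mean(axis=1)    # biased per-channel variance
  rstd = 1.0 / np.sqrt(v + eps)              # matches rstd = 1 / sqrt(var + epsilon)
  xhat = (xf - m[:, None]) * rstd[:, None]
  y = xhat * g[:, None] + b[:, None]
  # backward, mirroring the kernel: dg = sum(dy*xhat), db = sum(dy),
  # dx = (dy - (xhat*dg + db) * rcpDHWN) * rstd * g
  dg = (dyf * xhat).sum(axis=1)
  db = dyf.sum(axis=1)
  dx = (dyf - (xhat * dg[:, None] + db[:, None]) / N) * rstd[:, None] * g[:, None]
  return y.reshape(x.shape), m, v, dx.reshape(x.shape), dg, db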