[codegen/batchnorm] forward and backward now seemingly working

commit cc41604784
parent f74dcb7e30
Author: Philippe Tillet
Date: 2019-07-09 13:03:16 -07:00

3 changed files with 32 additions and 16 deletions

View File

@@ -100,7 +100,7 @@ def batch_norm_grad(op, dy, mean, var):
 def run_batchnorm():
-    C, H, W, B = 32, 16, 16, 16
+    C, H, W, B = 1, 4, 4, 4
     np.random.seed(0)
     # Placeholders
     x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -112,11 +112,19 @@ def run_batchnorm():
     hb = np.random.rand(C)
     # batchnorm
     y, m, v = module.batchnorm_forward(x, g, b, eps=1e-5)
-    loss = np.sum(y)
     # Run
     sess = tf.InteractiveSession()
     sess.run(tf.global_variables_initializer())
     result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
+    #print(result[0], result[1], result[2])
+    grads = tf.test.compute_gradient([x, g, b], [(C, H, W, B), (C, ), (C, )], y, (C, H, W, B),
+                                     extra_feed_dict = {x: hx, g: hg, b: hb})
+    dx_t, dx_n = grads[0]
+    dg_t, dg_n = grads[1]
+    db_t, db_n = grads[2]
+    print(np.max(np.abs(dx_t - dx_n)))
+    print(np.max(np.abs(dg_t - dg_n)))
+    print(np.max(np.abs(db_t - db_n)))
 run_batchnorm()

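The new check relies on tf.test.compute_gradient, which returns, for each input, a pair of Jacobians (analytic vs. finite-difference), so the three prints report the worst-case mismatch for dx, dg and db. As a point of reference, here is a minimal NumPy sketch of the per-channel forward math that module.batchnorm_forward is expected to implement (batchnorm_forward_ref is a hypothetical helper, not part of this commit):

import numpy as np

def batchnorm_forward_ref(x, g, b, eps=1e-5):
    # x: (C, H, W, B); g, b: (C,). Statistics reduce over all axes but the channel.
    C = x.shape[0]
    xc = x.reshape(C, -1)
    mean = xc.mean(axis=1)                          # per-channel mean
    var = ((xc - mean[:, None]) ** 2).mean(axis=1)  # biased variance, as in the kernel
    y = g[:, None] * (xc - mean[:, None]) / np.sqrt(var[:, None] + eps) + b[:, None]
    return y.reshape(x.shape), mean, var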
View File

@@ -61,7 +61,7 @@ private:
 class batchnorm_backward {
 public:
   // constructor
-  batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32");
+  batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32", float eps = 1e-5);
   // enqueue
   void enqueue(driver::stream *stream, driver::kernel *kernel,
                driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
@@ -78,6 +78,7 @@ private:
   int32_t W_;
   int32_t B_;
   std::string ty_;
+  float eps_;
 };
 }

View File

@@ -79,13 +79,14 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
     px = px + TM;
   }
   fp32 *pm = M + c;
-  *pm = __sum(mean) * rcpDHWN;
+  fp32 m = __sum(mean) * rcpDHWN;
+  *pm = m;
   fp32 var[TM] = 0;
   px = X + rx + c*DHWN;
   for(int32 i = 0; i < DHWN; i = i + TM){
     x = *px;
-    x = x - mean;
+    x = x - m;
     var = var + x*x;
     px = px + TM;
   }
@@ -99,7 +100,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
   fp32* py[TM] = Y + rx + c*DHWN;
   for(int32 i = 0; i < DHWN; i = i + TM){
     x = *px;
-    fp32 y[TM] = (x - mean)*rstdg + b;
+    fp32 y[TM] = (x - m)*rstdg + b;
     *py = y;
     px = px + TM;
     py = py + TM;
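Note what the two hunks above actually fix: `mean` is the [TM] accumulator of per-lane partial sums, so the old `x = x - mean` centered x with unreduced partials; the scalar channel mean only exists after `__sum(mean) * rcpDHWN`. Caching it in `m` also avoids re-reading `*pm`. A NumPy illustration of the difference (TM and the data are illustrative, not from the commit):

import numpy as np

TM = 4
x = np.arange(16.0).reshape(-1, TM)      # one channel, processed TM lanes at a time
partial = x.sum(axis=0)                  # per-lane accumulator: the kernel's `mean` vector
m = partial.sum() / x.size               # __sum(mean) * rcpDHWN: the actual scalar mean
var_wrong = ((x - partial) ** 2).mean()  # what the old `x - mean` effectively computed
var_right = ((x - m) ** 2).mean()        # what the fixed `x - m` computes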
@@ -111,8 +112,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  * Backward
  * --------------- */
-batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty)
-  : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty)
+batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
+  : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps)
 { }

 void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
@@ -120,7 +121,7 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
                                  driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
                                  size_t, size_t nthreads) {
-  std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
+  std::array<size_t, 3> grid = {1, (size_t)C_, 1};
   kernel->setArg(0, dx);
   kernel->setArg(1, dg);
   kernel->setArg(2, db);
@@ -130,6 +131,8 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(6, m);
   kernel->setArg(7, v);
   kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
+  kernel->setArg(9, (float)1/(D_*H_*W_*B_));
+  kernel->setArg(10, eps_);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
 }
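The two extra setArg calls line up with the widened kernel signature below: argument 9 is the reciprocal of the per-channel reduction size, precomputed on the host so the kernel multiplies instead of divides, and argument 10 is the epsilon threaded through from the constructor. In Python terms (shapes borrowed from the test above, for illustration only):

# Illustrative shapes (D=1 since the test is 2-D); the constructor stores these.
D, H, W, B = 1, 4, 4, 4
DHWN = D * H * W * B    # arg 8: elements reduced per channel
rcpDHWN = 1.0 / DHWN    # arg 9: scales per-channel sums into means
eps = 1e-5              # arg 10: keeps sqrt(var + eps) bounded away from zero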
@@ -144,14 +147,14 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
                restrict read_only fp32 *G,
                restrict read_only fp32 *M,
                restrict read_only fp32 *V,
-               int32 DHWN) {
-  int32 rx[TM] = get_global_range[TM](0);
+               int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
+  int32 rx[TM] = 0 ... TM;
   int32 c = get_range_id(0);
   int32 offset = c*DHWN;
   fp32 g = *(G + c);
   fp32 mean = *(M + c);
   fp32 var = *(V + c);
-  fp32 rstd = var;
+  fp32 rstd = 1 / sqrt(var + epsilon);
   fp32* px[TM];
   fp32* pdx[TM];
   fp32* pdy[TM];
@@ -160,7 +163,7 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
   pdy = DY + rx + offset;
   fp32 dg[TM] = 0;
   fp32 db[TM] = 0;
-  for(int32 i = 0; i < DHWN; i += TM){
+  for(int32 i = 0; i < DHWN; i = i + TM){
     fp32 x[TM] = *px;
     fp32 dy[TM] = *pdy;
     dg = dg + dy*(x - mean)*rstd;
@@ -170,15 +173,19 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
   }
   fp32 sdg = __sum(dg);
   fp32 sdb = __sum(db);
+  fp32 *pdg = DG + c;
+  fp32 *pdb = DB + c;
+  *pdg = sdg;
+  *pdb = sdb;
   px = X + rx + offset;
   pdy = DY + rx + offset;
   pdx = DX + rx + offset;
-  for(int32 i = 0; i < DHWN; i += TM){
+  for(int32 i = 0; i < DHWN; i = i + TM){
     fp32 x[TM] = *px;
     fp32 dy[TM] = *pdy;
     fp32 xhat[TM] = (x - mean) * rstd;
-    fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
+    fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN;
     fp32 dx[TM] = (dy - xtmp) * rstd * g;
     *pdx = dx;
     px = px + TM;
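For completeness, the math the backward kernel implements per channel: with xhat = (x - mean) * rstd, the reductions give dg = sum(dy * xhat) and db = sum(dy), and the input gradient folds both into dx = (dy - (xhat * dg + db) * rcpDHWN) * rstd * g. A NumPy sketch for one channel (batchnorm_backward_ref is a hypothetical reference, not commit code), which is what the finite-difference check in the test should agree with:

import numpy as np

def batchnorm_backward_ref(dy, x, g, mean, var, eps=1e-5):
    # dy, x: flattened (DHWN,) slices of one channel; g, mean, var: scalars.
    n = x.size
    rstd = 1.0 / np.sqrt(var + eps)   # matches `rstd = 1 / sqrt(var + epsilon)`
    xhat = (x - mean) * rstd          # normalized activations
    dg = np.sum(dy * xhat)            # gradient w.r.t. the scale g
    db = np.sum(dy)                   # gradient w.r.t. the bias b
    dx = (dy - (xhat * dg + db) / n) * rstd * g
    return dx, dg, db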