[codegen/batchnorm] forward and backward now seemingly working
@@ -100,7 +100,7 @@ def batch_norm_grad(op, dy, mean, var):
 def run_batchnorm():
-  C, H, W, B = 32, 16, 16, 16
+  C, H, W, B = 1, 4, 4, 4
   np.random.seed(0)
   # Placeholders
   x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -112,11 +112,19 @@ def run_batchnorm():
   hb = np.random.rand(C)
   # batchnorm
   y, m, v = module.batchnorm_forward(x, g, b, eps=1e-5)
+  loss = np.sum(y)
   # Run
   sess = tf.InteractiveSession()
   sess.run(tf.global_variables_initializer())
   result = sess.run([y, m, v], feed_dict = {x: hx, g: hg, b: hb})
 
+  #print(result[0], result[1], result[2])
+  grads = tf.test.compute_gradient([x, g, b], [(C, H, W, B), (C, ), (C, )], y, (C, H, W, B),
+                                   extra_feed_dict = {x: hx, g: hg, b: hb})
+  dx_t, dx_n = grads[0]
+  dg_t, dg_n = grads[1]
+  db_t, db_n = grads[2]
+  print(np.max(np.abs(dx_t - dx_n)))
+  print(np.max(np.abs(dg_t - dg_n)))
+  print(np.max(np.abs(db_t - db_n)))
 
 run_batchnorm()
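
For reference, the quantity the gradient check above validates: batchnorm_forward normalizes each channel by statistics reduced over every non-channel axis. A NumPy sketch of what y, m, v should equal (the (C, H, W, B) layout is taken from the placeholders above; batchnorm_forward_ref is an illustrative name, not part of the module):

  import numpy as np

  def batchnorm_forward_ref(x, g, b, eps=1e-5):
      # per-channel statistics over the (H, W, B) axes
      m = np.mean(x, axis=(1, 2, 3))
      v = np.var(x, axis=(1, 2, 3))      # biased variance, as in the kernel
      rstd = 1.0 / np.sqrt(v + eps)      # matches rstd = 1/sqrt(var + eps)
      c = (-1, 1, 1, 1)                  # broadcast per-channel scalars
      y = (x - m.reshape(c)) * rstd.reshape(c) * g.reshape(c) + b.reshape(c)
      return y, m, v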
@@ -61,7 +61,7 @@ private:
 class batchnorm_backward {
 public:
   // constructor
-  batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32");
+  batchnorm_backward(int C, int D, int H, int W, int B, std::string ty = "fp32", float eps = 1e-5);
   // enqueue
   void enqueue(driver::stream *stream, driver::kernel *kernel,
                driver::buffer *dx, driver::buffer *dg, driver::buffer *db, driver::buffer *dy,
@@ -78,6 +78,7 @@ private:
   int32_t W_;
   int32_t B_;
   std::string ty_;
+  float eps_;
 };
 
 }
@@ -79,13 +79,14 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
     px = px + TM;
   }
   fp32 *pm = M + c;
-  *pm = __sum(mean) * rcpDHWN;
+  fp32 m = __sum(mean) * rcpDHWN;
+  *pm = m;
 
   fp32 var[TM] = 0;
   px = X + rx + c*DHWN;
   for(int32 i = 0; i < DHWN; i = i + TM){
     x = *px;
-    x = x - mean;
+    x = x - m;
     var = var + x*x;
     px = px + TM;
   }
@@ -99,7 +100,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
   fp32* py[TM] = Y + rx + c*DHWN;
   for(int32 i = 0; i < DHWN; i = i + TM){
     x = *px;
-    fp32 y[TM] = (x - mean)*rstdg + b;
+    fp32 y[TM] = (x - m)*rstdg + b;
     *py = y;
     px = px + TM;
     py = py + TM;
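
The fix in the two hunks above: mean is a TM-wide accumulator, so the second pass must subtract its scalar reduction m rather than the vector itself. In NumPy terms, the kernel's per-channel two-pass structure looks like the following sketch (channel_stats and TM are illustrative; assumes a flattened channel whose size is a multiple of TM):

  import numpy as np

  def channel_stats(x_c, TM=8):
      n = x_c.size
      acc = np.zeros(TM)
      for i in range(0, n, TM):     # pass 1: TM-wide partial sums
          acc += x_c[i:i+TM]
      m = acc.sum() / n             # m = __sum(mean) * rcpDHWN
      acc[:] = 0
      for i in range(0, n, TM):     # pass 2: center by the *scalar* m
          d = x_c[i:i+TM] - m       # (was incorrectly x - mean)
          acc += d * d
      v = acc.sum() / n
      return m, v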
@@ -111,8 +112,8 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  * Backward
  * --------------- */
 
-batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty)
-  : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty)
+batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
+  : C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps)
 { }
 
 void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
@@ -120,7 +121,7 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
                                  driver::buffer *x, driver::buffer *g, driver::buffer *m, driver::buffer *v,
                                  size_t, size_t nthreads) {
 
-  std::array<size_t, 3> grid = {(size_t)C_, 1, 1};
+  std::array<size_t, 3> grid = {1, (size_t)C_, 1};
   kernel->setArg(0, dx);
   kernel->setArg(1, dg);
   kernel->setArg(2, db);
@@ -130,6 +131,8 @@ void batchnorm_backward::enqueue(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(6, m);
   kernel->setArg(7, v);
   kernel->setArg(8, (int32_t)(D_*H_*W_*B_));
+  kernel->setArg(9, (float)1/(D_*H_*W_*B_));
+  kernel->setArg(10, eps_);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
 }
 
@@ -144,14 +147,14 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
                restrict read_only fp32 *G,
                restrict read_only fp32 *M,
                restrict read_only fp32 *V,
-               int32 DHWN) {
-  int32 rx[TM] = get_global_range[TM](0);
+               int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
+  int32 rx[TM] = 0 ... TM;
   int32 c = get_range_id(0);
   int32 offset = c*DHWN;
   fp32 g = *(G + c);
   fp32 mean = *(M + c);
   fp32 var = *(V + c);
-  fp32 rstd = var;
+  fp32 rstd = 1 / sqrt(var + epsilon);
   fp32* px[TM];
   fp32* pdx[TM];
   fp32* pdy[TM];
@@ -160,7 +163,7 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
   pdy = DY + rx + offset;
   fp32 dg[TM] = 0;
   fp32 db[TM] = 0;
-  for(int32 i = 0; i < DHWN; i += TM){
+  for(int32 i = 0; i < DHWN; i = i + TM){
     fp32 x[TM] = *px;
     fp32 dy[TM] = *pdy;
     dg = dg + dy*(x - mean)*rstd;
@@ -170,15 +173,19 @@ void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
   }
   fp32 sdg = __sum(dg);
   fp32 sdb = __sum(db);
   fp32 *pdg = DG + c;
   fp32 *pdb = DB + c;
   *pdg = sdg;
   *pdb = sdb;
 
   px = X + rx + offset;
   pdy = DY + rx + offset;
   pdx = DX + rx + offset;
-  for(int32 i = 0; i < DHWN; i += TM){
+  for(int32 i = 0; i < DHWN; i = i + TM){
     fp32 x[TM] = *px;
     fp32 dy[TM] = *pdy;
     fp32 xhat[TM] = (x - mean) * rstd;
-    fp32 xtmp[TM] = (xhat * dg + db) * NDHW;
+    fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN;
     fp32 dx[TM] = (dy - xtmp) * rstd * g;
     *pdx = dx;
     px = px + TM;
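
For the record, the gradient this backward kernel computes, written out as a NumPy sketch (batchnorm_backward_ref is an illustrative name; layout (C, H, W, B) as in the test above, with reductions over all non-channel axes):

  import numpy as np

  def batchnorm_backward_ref(dy, x, g, mean, var, eps=1e-5):
      n = x[0].size                           # elements per channel (DHWN)
      c = (-1, 1, 1, 1)                       # broadcast per-channel scalars
      rstd = 1.0 / np.sqrt(var + eps)         # rstd = 1/sqrt(var + epsilon)
      xhat = (x - mean.reshape(c)) * rstd.reshape(c)
      dg = np.sum(dy * xhat, axis=(1, 2, 3))  # the sdg accumulation
      db = np.sum(dy, axis=(1, 2, 3))         # the sdb accumulation
      xtmp = (xhat * dg.reshape(c) + db.reshape(c)) / n   # * rcpDHWN
      dx = (dy - xtmp) * (rstd * g).reshape(c)
      return dx, dg, db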