more stuff
@@ -45,11 +45,11 @@ if mode == MODE.TF:
     fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
     fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
     # execute
-    fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
-    #fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta])
+    fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
+    fw_y = triton.ops.batchnorm(fw_x, fw_mean, fw_var, fw_gamma, fw_beta, 1e-4)
+    fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
     # run
     sess = tf.InteractiveSession()
     feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
     sess.run(tf.global_variables_initializer())
-    #print(fw_dx, fw_dgamma, fw_dbeta)
     result = sess.run([fw_y], feed_dict=feed_dict)
     print(result)
+    result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)
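Note: the op now takes precomputed statistics from tf.nn.moments instead of computing them internally. Below is a minimal TF 1.x sketch of the same call pattern, with tf.nn.batch_normalization standing in for triton.ops.batchnorm as a CPU-checkable reference; the stand-in op, shapes, and variable names are illustrative, not from this commit.

# Reference sketch (assumes TensorFlow 1.x). Data is channel-first (C, H, W, B)
# as in the Triton test; tf.nn.moments reduces over axes [1, 2, 3].
import numpy as np
import tensorflow as tf

C, H, W, B = 32, 4, 4, 16
x = np.random.rand(C, H, W, B).astype(np.float32)
gamma = np.random.rand(C).astype(np.float32)
beta = np.random.rand(C).astype(np.float32)
dy = np.random.rand(C, H, W, B).astype(np.float32)

fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype)
fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype)
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
# statistics computed outside the op, as in the updated test
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
fw_y = tf.nn.batch_normalization(fw_x,
                                 fw_mean[:, None, None, None],
                                 fw_var[:, None, None, None],
                                 fw_beta[:, None, None, None],
                                 fw_gamma[:, None, None, None], 1e-4)
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
sess = tf.InteractiveSession()
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
print(sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict))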
@@ -192,18 +192,19 @@ void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
                                     const std::string &opname,
                                     const std::vector<ir::argument*>& args,
                                     const std::vector<std::string>& outputs){
+
+  auto tolower = [](char c) { return std::tolower(c);};
   os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
   for(size_t i = 0; i < args.size(); i++){
     ir::argument *arg = args[i];
     std::string name = arg->get_name();
-    auto tolower = [](char c) { return std::tolower(c);};
     std::transform(name.begin(), name.end(), name.begin(), tolower);
     if(!arg->get_type()->is_pointer_ty())
       os << ".HostMemory(\"" + name + "\")";
   }
   for(size_t i = 0; i < outputs.size(); i++){
     std::string name = outputs[i];
-    name[0] = std::tolower(name[0]);
+    std::transform(name.begin(), name.end(), name.begin(), tolower);
    os << ".HostMemory(\"" << name << "_shape\")";
   }
   os << ", " + opname << ");\n";
|
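Note: the generator hoists the tolower lambda and now lowercases output names in full rather than only their first character. A small Python sketch of the registration string this function would emit, assuming a hypothetical op "BatchnormForward" with one non-pointer argument N and one output Y (all names illustrative):

# Emulates gen_tf_register_kernel_builder's output for illustration only.
args = [("N", False), ("X", True)]   # (name, is_pointer); scalars live on the host
outputs = ["Y"]
macro = 'REGISTER_KERNEL_BUILDER(Name("BatchnormForward").Device(DEVICE_GPU)'
for name, is_pointer in args:
    if not is_pointer:
        macro += '.HostMemory("%s")' % name.lower()
for name in outputs:
    macro += '.HostMemory("%s_shape")' % name.lower()
macro += ', BatchnormForwardOp);'
print(macro)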
@@ -45,6 +45,7 @@ class function(metaclass = function_meta):
     def apply_tensorflow(cls, *args, **kwargs):
         ctx = OpContext()
         result = cls.forward(ctx, *args, **kwargs)
+        op = result[0].op if isinstance(result, tuple) else result.op
         # Find a mapping between ::forward arguments and tensorflow op arguments
         remap = dict()
         for i, ix in enumerate(result.op.inputs):
@@ -52,13 +53,12 @@ class function(metaclass = function_meta):
                 if ix is jx:
                     remap[j] = i
         # register backward
-        ctx_registry[result] = ctx
-        name = result.op.op_def.name
+        ctx_registry[op] = ctx
+        name = op.op_def.name
         if not cls.registered:
             @fw.tensorflow.RegisterGradient(name)
-            def gradient(op, dy):
-                y = op.outputs[0]
-                grad = cls.backward(ctx_registry[y], dy)
+            def gradient(op, *dys):
+                grad = cls.backward(ctx_registry[op], dys if len(dys) > 1 else dys[0])
             # Remap gradient in the right order
             ret = [None] * len(op.inputs)
             for i in range(len(grad)):
|
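Note: the gradient closure now accepts *dys because TensorFlow invokes a registered gradient with one upstream gradient per op output, and the batchnorm op now has several outputs. A minimal TF 1.x sketch of the registration pattern; the op type "MyTritonOp" is a placeholder:

# Sketch only (assumes TensorFlow 1.x); registering a gradient for an op type
# that is never instantiated is harmless.
import tensorflow as tf

@tf.RegisterGradient("MyTritonOp")
def _my_triton_op_grad(op, *dys):
    # single-output ops receive dys == (dy,); multi-output ops receive more
    dy = dys[0] if len(dys) == 1 else dys
    # ...call the user-defined backward here and remap to op.inputs order...
    return [None] * len(op.inputs)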
@@ -6,138 +6,122 @@ class _batchnorm(triton.function):
     fwd_src = """
 void fwdbatchnorm(float *Y, float *M, float *V,
                   float *X, float *G, float *B,
-                  int N, float rcpN, float eps) {
-  int rx[TM] = 0 ... TM;
-  float *px[TM];
-  float x[TM] = 0;
-  int c = get_program_id(1);
-  float g = *(G + c);
-  float b = *(B + c);
-
-  float mean[TM] = 0;
-  px = X + rx + c*N;
-  for(int i = 0; i < N; i = i + TM){
-    x = *px;
-    mean = mean + x;
-    px = px + TM;
-  }
-  float *pm = M + c;
-  float m = mean[+] * rcpN;
-  *pm = m;
-
-  float var[TM] = 0;
-  px = X + rx + c*N;
-  for(int i = 0; i < N; i = i + TM){
-    x = *px;
-    x = x - m;
-    var = var + x*x;
-    px = px + TM;
-  }
-  float v = var[+] * rcpN;
-  float *pv = V + c;
-  *pv = v;
-  float rstdg = 1 / sqrtf(v + eps) * g;
-
-  px = X + rx + c*N;
-  float* py[TM] = Y + rx + c*N;
-  for(int i = 0; i < N; i = i + TM){
-    x = *px;
-    float y[TM] = (x - m)*rstdg + b;
-    *py = y;
-    px = px + TM;
-    py = py + TM;
-  }
+                  int N, float eps) {
+  // pointers
+  int c = get_program_id(1);
+  int rm[TM] = 0 ... TM;
+  float *px[TM] = X + rm + c*N;
+  float* py[TM] = Y + rm + c*N;
+
+  // compute mean
+  float accm[TM] = 0;
+  for(int i = 0; i < N; i = i + TM)
+    accm = accm + *(px + i);
+  float mean = (float)accm[+] / N;
+  *(M + c) = mean;
+
+  // compute variance
+  float accv[TM] = 0;
+  for(int i = 0; i < N; i = i + TM){
+    float x[TM] = *(px + i);
+    x = x - mean;
+    accv = accv + x*x;
+  }
+  float var = (float)accv[+] / N;
+  *(V + c) = var;
+
+  // Normalize batch
+  float gamma = *(G + c);
+  float beta = *(B + c);
+  float rstdg = 1 / sqrtf(var + eps) * gamma;
+  for(int i = 0; i < N; i = i + TM){
+    float x[TM] = *(px + i);
+    float y[TM] = (x - mean)*rstdg + beta;
+    *(py + i) = y;
+  }
 }
 """
 
-    fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
+    fwd_kernel = triton.kernel(fwd_src, ['Y'])
 
     bwd_src = """
-void batchnormBackward(float *DX, float *DG, float *DB,
-                       float *DY, float *X, float *G,
-                       float *M, float *V,
-                       int DHWN, float rcpDHWN, float epsilon) {
-  int rx[TM] = 0 ... TM;
-  int c = get_program_id(1);
-  int offset = c*DHWN;
-  float g = *(G + c);
-  float mean = *(M + c);
-  float var = *(V + c);
-  float rstd = 1 / sqrtf(var + epsilon);
-  float* px[TM];
-  float* pdx[TM];
-  float* pdy[TM];
-
-  px = X + rx + offset;
-  pdy = DY + rx + offset;
-  float dg[TM] = 0;
-  float db[TM] = 0;
-  for(int i = 0; i < DHWN; i = i + TM){
-    float x[TM] = *px;
-    float dy[TM] = *pdy;
-    dg = dg + dy*(x - mean)*rstd;
-    db = db + dy;
-    px = px + TM;
-    pdy = pdy + TM;
-  }
-  float sdg = dg[+];
-  float sdb = db[+];
-  float *pdg = DG + c;
-  float *pdb = DB + c;
-  *pdg = sdg;
-  *pdb = sdb;
-
-  px = X + rx + offset;
-  pdy = DY + rx + offset;
-  pdx = DX + rx + offset;
-  for(int i = 0; i < DHWN; i = i + TM){
-    float x[TM] = *px;
-    float dy[TM] = *pdy;
-    float xhat[TM] = (x - mean) * rstd;
-    float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
-    float dx[TM] = (dy - xtmp) * rstd * g;
-    *pdx = dx;
-    px = px + TM;
-    pdy = pdy + TM;
-    pdx = pdx + TM;
-  }
+void bwdbatchnorm(float *DX, float *DG, float *DB,
+                  float *DY, float *X, float *G,
+                  float *M, float *V,
+                  int N, float epsilon) {
+  // pointers
+  int c = get_program_id(1);
+  int rx[TM] = 0 ... TM;
+  int offset = c*N;
+  float* px[TM] = X + rx + offset;
+  float* pdy[TM] = DY + rx + offset;
+  float* pdx[TM] = DX + rx + offset;
+
+  // fetch statistics
+  float gamma = *(G + c);
+  float mean = *(M + c);
+  float var = *(V + c);
+  float rstd = 1 / sqrtf(var + epsilon);
+
+  // compute dgamma and dbeta
+  float acc_dg[TM] = 0;
+  float acc_db[TM] = 0;
+  for(int i = 0; i < N; i = i + TM){
+    float x[TM] = *(px + i);
+    float dy[TM] = *(pdy + i);
+    acc_dg += dy*(x - mean)*rstd;
+    acc_db += dy;
+  }
+  float dg = acc_dg[+];
+  float db = acc_db[+];
+  *(DG + c) = dg;
+  *(DB + c) = db;
+
+  // compute dx
+  for(int i = 0; i < N; i = i + TM){
+    float x[TM] = *(px + i);
+    float dy[TM] = *(pdy + i);
+    float xhat[TM] = (x - mean) * rstd;
+    float xtmp[TM] = (xhat * dg + db) / N;
+    float dx[TM] = (dy - xtmp) * rstd * gamma;
+    *(pdx + i) = dx;
+  }
 }
 """
 
     bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
 
     @staticmethod
-    def forward(ctx, x, gamma, beta, eps):
+    def forward(ctx, x, mean, var, gamma, beta, eps):
         shape = triton.shape(x)
         dtype = x.dtype
         # allocate outputs
         C, H, W, B = shape[0], shape[1], shape[2], shape[3]
         y = triton.empty(shape, dtype=dtype)
-        mean = triton.empty([C], dtype=dtype)
-        var = triton.empty([C], dtype=dtype)
         # execute kernels
-        N = H*W*B
-        y, mean, var = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
-                                             lambda opt: [1, C],
-                                             TM = 128)
+        y = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
+                                  lambda opt: [1, C],
+                                  TM = 128)
         # save
-        ctx.eps = eps
         ctx.save_for_backward(x, gamma, beta, mean, var)
+        ctx.eps = eps
         return y
 
     @staticmethod
     def backward(ctx, dy):
-        eps = ctx.eps
         # retrieve info
         x, gamma, beta, mean, var = ctx.saved_tensors
-        dx = triton.empty(x.shape, dtype=x.dtype)
-        dgamma = triton.empty(gamma.shape, dtype=gamma.dtype)
-        dbeta = triton.empty(beta.shape, dtype=beta.dtype)
-        # launch
-        C, H, W, B = x.shape
-        N = H*W*B
-        _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
-                              x, gamma, mean, var,
-                              N, 1./N, eps,
-                              lambda opt: [1, C],
-                              TM = 128)
-        return dx, dgamma, dbeta, None
+        eps = ctx.eps
+        # allocate result
+        dx = triton.empty(triton.shape(x), dtype=x.dtype)
+        dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
+        dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
+        # execute
+        C, H, W, B = triton.shape(x)
+        dx, dgamma, dbeta = _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
+                                                  x, gamma, mean, var,
+                                                  H*W*B, eps,
+                                                  lambda opt: [1, C],
+                                                  TM = 128)
+        return dx, None, None, dgamma, dbeta, None
 
 batchnorm = _batchnorm.apply
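Note: a NumPy cross-check of the math the two kernels implement, with x laid out channel-major as (C, N) and N = H*W*B, matching the X + rm + c*N indexing. This is an illustrative reference, not part of the commit.

# Forward: per-channel mean/variance, then y = (x - mean) * rstd * gamma + beta.
# Backward: dgamma = sum(dy * xhat), dbeta = sum(dy),
#           dx = (dy - (xhat * dgamma + dbeta) / N) * rstd * gamma.
import numpy as np

def batchnorm_fwd(x, gamma, beta, eps):
    mean = x.mean(axis=1)                          # accm[+] / N
    var = ((x - mean[:, None]) ** 2).mean(axis=1)  # accv[+] / N
    rstd = 1.0 / np.sqrt(var + eps)
    y = (x - mean[:, None]) * (rstd * gamma)[:, None] + beta[:, None]
    return y, mean, var

def batchnorm_bwd(dy, x, gamma, mean, var, eps):
    N = x.shape[1]
    rstd = 1.0 / np.sqrt(var + eps)
    xhat = (x - mean[:, None]) * rstd[:, None]
    dgamma = (dy * xhat).sum(axis=1)               # acc_dg[+]
    dbeta = dy.sum(axis=1)                         # acc_db[+]
    xtmp = (xhat * dgamma[:, None] + dbeta[:, None]) / N
    dx = (dy - xtmp) * (rstd * gamma)[:, None]
    return dx, dgamma, dbeta

C, N = 32, 256
x = np.random.rand(C, N).astype(np.float32)
gamma = np.random.rand(C).astype(np.float32)
beta = np.random.rand(C).astype(np.float32)
dy = np.random.rand(C, N).astype(np.float32)
y, mean, var = batchnorm_fwd(x, gamma, beta, 1e-4)
dx, dgamma, dbeta = batchnorm_bwd(dy, x, gamma, mean, var, 1e-4)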