more stuff

This commit is contained in:
Philippe Tillet
2019-10-30 13:44:31 -04:00
parent bf3dc63858
commit 9b0f1a0807
4 changed files with 91 additions and 106 deletions

View File

@@ -45,11 +45,11 @@ if mode == MODE.TF:
fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype)
fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype)
# execute
fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4)
#fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta])
fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3])
fw_y = triton.ops.batchnorm(fw_x, fw_mean, fw_var, fw_gamma, fw_beta, 1e-4)
fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy)
# run
sess = tf.InteractiveSession()
feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy}
sess.run(tf.global_variables_initializer())
#print(fw_dx, fw_dgamma, fw_dbeta)
result = sess.run([fw_y], feed_dict=feed_dict)
print(result)
result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict)

View File

@@ -192,18 +192,19 @@ void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
const std::string &opname,
const std::vector<ir::argument*>& args,
const std::vector<std::string>& outputs){
auto tolower = [](char c) { return std::tolower(c);};
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
for(size_t i = 0; i < args.size(); i++){
ir::argument *arg = args[i];
std::string name = arg->get_name();
auto tolower = [](char c) { return std::tolower(c);};
std::transform(name.begin(), name.end(), name.begin(), tolower);
if(!arg->get_type()->is_pointer_ty())
os << ".HostMemory(\"" + name + "\")";
}
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
name[0] = std::tolower(name[0]);
std::transform(name.begin(), name.end(), name.begin(), tolower);
os << ".HostMemory(\"" << name << "_shape\")";
}
os << ", " + opname << ");\n";

View File

@@ -45,6 +45,7 @@ class function(metaclass = function_meta):
def apply_tensorflow(cls, *args, **kwargs):
ctx = OpContext()
result = cls.forward(ctx, *args, **kwargs)
op = result[0].op if isinstance(result, tuple) else result.op
# Find a mapping between ::forward arguments and tensorflow op arguments
remap = dict()
for i, ix in enumerate(result.op.inputs):
@@ -52,13 +53,12 @@ class function(metaclass = function_meta):
if ix is jx:
remap[j] = i
# register backward
ctx_registry[result] = ctx
name = result.op.op_def.name
ctx_registry[op] = ctx
name = op.op_def.name
if not cls.registered:
@fw.tensorflow.RegisterGradient(name)
def gradient(op, dy):
y = op.outputs[0]
grad = cls.backward(ctx_registry[y], dy)
def gradient(op, *dys):
grad = cls.backward(ctx_registry[op], dys if len(dys) > 1 else dys[0])
# Remap gradient in the right order
ret = [None] * len(op.inputs)
for i in range(len(grad)):

View File

@@ -6,138 +6,122 @@ class _batchnorm(triton.function):
fwd_src = """
void fwdbatchnorm(float *Y, float *M, float *V,
float *X, float *G, float *B,
int N, float rcpN, float eps) {
int rx[TM] = 0 ... TM;
float *px[TM];
float x[TM] = 0;
int N, float eps) {
// pointers
int c = get_program_id(1);
float g = *(G + c);
float b = *(B + c);
int rm[TM] = 0 ... TM;
float *px[TM] = X + rm + c*N;
float* py[TM] = Y + rm + c*N;
float mean[TM] = 0;
px = X + rx + c*N;
// compute mean
float accm[TM] = 0;
for(int i = 0; i < N; i = i + TM)
accm = accm + *(px + i);
float mean = (float)accm[+] / N;
*(M + c) = mean;
// compute variance
float accv[TM] = 0;
for(int i = 0; i < N; i = i + TM){
x = *px;
mean = mean + x;
px = px + TM;
float x[TM] = *(px + i);
x = x - mean;
accv = accv + x*x;
}
float *pm = M + c;
float m = mean[+] * rcpN;
*pm = m;
float var = (float)accv[+] / N;
*(V + c) = var;
float var[TM] = 0;
px = X + rx + c*N;
// Normalize batch
float gamma = *(G + c);
float beta = *(B + c);
float rstdg = 1 / sqrtf(var + eps) * gamma;
for(int i = 0; i < N; i = i + TM){
x = *px;
x = x - m;
var = var + x*x;
px = px + TM;
}
float v = var[+] * rcpN;
float *pv = V + c;
*pv = v;
float rstdg = 1 / sqrtf(v + eps) * g;
px = X + rx + c*N;
float* py[TM] = Y + rx + c*N;
for(int i = 0; i < N; i = i + TM){
x = *px;
float y[TM] = (x - m)*rstdg + b;
*py = y;
px = px + TM;
py = py + TM;
float x[TM] = *(px + i);
float y[TM] = (x - mean)*rstdg + beta;
*(py + i) = y;
}
}
"""
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
fwd_kernel = triton.kernel(fwd_src, ['Y'])
bwd_src = """
void batchnormBackward(float *DX, float *DG, float *DB,
float *DY, float *X, float *G,
float *M, float *V,
int DHWN, float rcpDHWN, float epsilon) {
int rx[TM] = 0 ... TM;
void bwdbatchnorm(float *DX, float *DG, float *DB,
float *DY, float *X, float *G,
float *M, float *V,
int N, float epsilon) {
// pointers
int c = get_program_id(1);
int offset = c*DHWN;
float g = *(G + c);
int rx[TM] = 0 ... TM;
int offset = c*N;
float* px[TM] = X + rx + offset;
float* pdy[TM] = DY + rx + offset;
float* pdx[TM] = DX + rx + offset;
// fetch statistics
float gamma = *(G + c);
float mean = *(M + c);
float var = *(V + c);
float rstd = 1 / sqrtf(var + epsilon);
float* px[TM];
float* pdx[TM];
float* pdy[TM];
px = X + rx + offset;
pdy = DY + rx + offset;
float dg[TM] = 0;
float db[TM] = 0;
for(int i = 0; i < DHWN; i = i + TM){
float x[TM] = *px;
float dy[TM] = *pdy;
dg = dg + dy*(x - mean)*rstd;
db = db + dy;
px = px + TM;
pdy = pdy + TM;
// compute dgamma and dbeta
float acc_dg[TM] = 0;
float acc_db[TM] = 0;
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
acc_dg += dy*(x - mean)*rstd;
acc_db += dy;
}
float sdg = dg[+];
float sdb = db[+];
float *pdg = DG + c;
float *pdb = DB + c;
*pdg = sdg;
*pdb = sdb;
px = X + rx + offset;
pdy = DY + rx + offset;
pdx = DX + rx + offset;
for(int i = 0; i < DHWN; i = i + TM){
float x[TM] = *px;
float dy[TM] = *pdy;
float dg = acc_dg[+];
float db = acc_db[+];
*(DG + c) = dg;
*(DB + c) = db;
// compute dx
for(int i = 0; i < N; i = i + TM){
float x[TM] = *(px + i);
float dy[TM] = *(pdy + i);
float xhat[TM] = (x - mean) * rstd;
float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
float dx[TM] = (dy - xtmp) * rstd * g;
*pdx = dx;
px = px + TM;
pdy = pdy + TM;
pdx = pdx + TM;
float xtmp[TM] = (xhat * dg + db) / N;
float dx[TM] = (dy - xtmp) * rstd * gamma;
*(pdx + i) = dx;
}
}
"""
bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
@staticmethod
def forward(ctx, x, gamma, beta, eps):
def forward(ctx, x, mean, var, gamma, beta, eps):
shape = triton.shape(x)
dtype = x.dtype
# allocate outputs
C, H, W, B = shape[0], shape[1], shape[2], shape[3]
y = triton.empty(shape, dtype=dtype)
mean = triton.empty([C], dtype=dtype)
var = triton.empty([C], dtype=dtype)
# execute kernels
N = H*W*B
y, mean, var = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
lambda opt: [1, C],
TM = 128)
y = _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
lambda opt: [1, C],
TM = 128)
# save
ctx.eps = eps
ctx.save_for_backward(x, gamma, beta, mean, var)
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
eps = ctx.eps
# retrieve info
x, gamma, beta, mean, var = ctx.saved_tensors
dx = triton.empty(x.shape, dtype=x.dtype)
dgamma = triton.empty(gamma.shape, dtype=gamma.dtype)
dbeta = triton.empty(beta.shape, dtype=beta.dtype)
# launch
C, H, W, B = x.shape
N = H*W*B
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
eps = ctx.eps
# allocate result
dx = triton.empty(triton.shape(x), dtype=x.dtype)
dgamma = triton.empty(triton.shape(gamma), dtype=gamma.dtype)
dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype)
# execute
C, H, W, B = triton.shape(x)
dx, dgamma, dbeta = _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
x, gamma, mean, var,
N, 1./N, eps,
H*W*B, eps,
lambda opt: [1, C],
TM = 128)
return dx, dgamma, dbeta, None
return dx, None, None, dgamma, dbeta, None
batchnorm = _batchnorm.apply