[PYTHON][OPS] Added batch normalization op

Philippe Tillet
2019-10-29 17:29:11 -04:00
parent d9eacf937c
commit d65a94c768
9 changed files with 108 additions and 34 deletions

View File

@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
             pass
         cfg = 'Debug' if self.debug else 'Release'
-        cfg = 'Release'
+        #cfg = 'Release'
         build_args = ['--config', cfg]
         if platform.system() == "Windows":

View File

@@ -211,27 +211,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
os << ";\n";
}
inline std::string preheader() {
return
R"(
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
#define __readonly __attribute__((readonly))
#define __writeonly __attribute__((writeonly))
#define __noalias __attribute__((noalias))
#define __aligned(A) __attribute__((aligned(A)))
#define __multipleof(A) __attribute__((multipleof(A)))
extern int get_program_id(int);
)";
}
void make_module(const std::string& src, ir::module* ir,
const runtime::function::options_space_t& opt) {
std::string copy = preheader() + src;
std::string copy = triton::runtime::function::preheader() + src;
// pre-process
TokenSequence tokens;
Preprocessor cpp(&copy, true);
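
For reference, the preheader removed here (and now supplied by triton::runtime::function::preheader) is what makes names like get_program_id and __readonly visible inside user kernels. A minimal, purely illustrative Triton-C source that relies on it being prepended automatically:

# hypothetical kernel string; the builtins it uses come from the preheader
src = """
void copy(float* __readonly X, float* __writeonly Y) {
  int rx[TM] = get_program_id(0) * TM + (0 ... TM);
  *(Y + rx) = *(X + rx);
}
"""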
@@ -341,11 +323,11 @@ inline std::string to_torch_ty(ir::type *ty) {
   if(ty->is_integer_ty())
     return "int64_t";
   if(ty->is_half_ty())
-    return "float16";
+    return "double";
   if(ty->is_float_ty())
-    return "float32";
+    return "double";
   if(ty->is_double_ty())
-    return "float64";
+    return "double";
   if(ty->is_pointer_ty())
     return "torch::Tensor";
   throw std::runtime_error("unknown type");
@@ -363,11 +345,11 @@ inline std::string to_c_ty(ir::type *ty) {
   if(ty->is_integer_ty(64))
     return "int64_t";
   if(ty->is_half_ty())
-    return "float16";
+    return "half";
   if(ty->is_float_ty())
-    return "float32";
+    return "float";
   if(ty->is_double_ty())
-    return "float64";
+    return "double";
   if(ty->is_pointer_ty())
     return "drv::cu_buffer";
   throw std::runtime_error("unknown type");
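
Restating the net effect of these two hunks (an illustrative summary, not code from this commit): to_torch_ty now maps every scalar floating type to double on the Torch-binding side, presumably because scalar kernel arguments are forwarded as C++ doubles, while to_c_ty maps them to the native C names used by the generated kernel wrapper:

# illustrative restatement only
TO_TORCH_TY = {'half': 'double', 'float': 'double', 'double': 'double',
               'pointer': 'torch::Tensor'}   # integers stay int64_t
TO_C_TY     = {'half': 'half', 'float': 'float', 'double': 'double',
               'pointer': 'drv::cu_buffer'}  # 64-bit integers stay int64_t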

View File

@@ -1,2 +1,3 @@
 from .dot import _dot, dot
 from .einsum import _einsum, einsum
+from .batchnorm import _batchnorm, batchnorm

View File

@@ -0,0 +1,75 @@
+import triton
+import math
+
+class _batchnorm(triton.function):
+
+    fwd_src = """
+void batchnormForward(float *Y, float *M, float *V,
+                      float *X, float *G, float *B,
+                      int N, float rcpN, float eps) {
+  // one program instance per channel: axis 1 of the launch grid
+  int rx[TM] = 0 ... TM;
+  float *px[TM];
+  float x[TM] = 0;
+  int c = get_program_id(1);
+  float g = *(G + c);
+  float b = *(B + c);
+  // first pass: per-channel mean over the N = H*W*B elements
+  float mean[TM] = 0;
+  px = X + rx + c*N;
+  for(int i = 0; i < N; i = i + TM){
+    x = *px;
+    mean = mean + x;
+    px = px + TM;
+  }
+  float *pm = M + c;
+  float m = mean[+] * rcpN; // [+] sums the TM-wide accumulator
+  *pm = m;
+  // second pass: per-channel variance
+  float var[TM] = 0;
+  px = X + rx + c*N;
+  for(int i = 0; i < N; i = i + TM){
+    x = *px;
+    x = x - m;
+    var = var + x*x;
+    px = px + TM;
+  }
+  float v = var[+] * rcpN;
+  float *pv = V + c;
+  *pv = v;
+  float rstdg = 1 / sqrtf(v + eps) * g;
+  // third pass: normalize, scale by gamma, shift by beta
+  px = X + rx + c*N;
+  float* py[TM] = Y + rx + c*N;
+  for(int i = 0; i < N; i = i + TM){
+    x = *px;
+    float y[TM] = (x - m)*rstdg + b;
+    *py = y;
+    px = px + TM;
+    py = py + TM;
+  }
+}
+"""
+    fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
+
+    @staticmethod
+    def forward(ctx, x, gamma, beta, eps):
+        shape = triton.shape(x)
+        dtype = x.dtype
+        # allocate outputs
+        C, H, W, B = shape[0], shape[1], shape[2], shape[3]
+        y = triton.empty(shape, dtype=dtype)
+        mean = triton.empty([C], dtype=dtype)
+        var = triton.empty([C], dtype=dtype)
+        # execute kernel on a 1 x C grid: one program per channel
+        N = H*W*B
+        _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, N, 1./N, eps,
+                              lambda opt: [1, C],
+                              TM = 128)
+        # save for backward
+        ctx.eps = eps
+        ctx.save_for_backward(x, gamma, beta, mean, var)
+        return y, mean, var
+
+batchnorm = _batchnorm.apply
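
A quick usage sketch (not part of this commit): assuming the op is exposed as triton.ops.batchnorm via the __init__ above, driven by the PyTorch frontend, and fed the C x H x W x B layout that forward() expects:

import torch
import triton

# hypothetical example tensors; C=32 channels, layout C, H, W, B
x     = torch.randn(32, 8, 8, 16, device='cuda')
gamma = torch.ones(32, device='cuda')
beta  = torch.zeros(32, device='cuda')
y, mean, var = triton.ops.batchnorm(x, gamma, beta, 1e-5)  # mean/var have shape [C]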

View File

@@ -181,15 +181,16 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
     @staticmethod
     def forward(ctx, subscripts, a, b, bench = 0):
         ctx.save_for_backward(a, b)
+        # parse
         if type(subscripts) is str:
             einsum_a, einsum_bc = subscripts.split(",")
             einsum_b, einsum_c = einsum_bc.split("->")
         else:
             einsum_a, einsum_b, einsum_c = subscripts
         shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
             einsum_a, einsum_b, einsum_c,
             triton.shape(a), triton.shape(b))
+        # save for backward
         ctx.trans_a = ta
         ctx.trans_b = tb
         ctx.einsum_a = einsum_a
@@ -197,6 +198,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
         ctx.einsum_c = einsum_c
         ctx.bench = bench
         ctx.bmnk = bmnk
+        # run
         return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
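
A usage sketch for the string-subscripts path annotated above (illustrative; assumes the op is exposed as triton.ops.einsum and driven by the PyTorch frontend):

import torch
import triton

a = torch.randn(64, 32, device='cuda')
b = torch.randn(32, 128, device='cuda')
c = triton.ops.einsum('mk,kn->mn', a, b)  # handled by the `type(subscripts) is str` branch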