[TRITON][PYTHON] Cleaned up API

Philippe Tillet
2020-01-31 14:04:42 -05:00
parent 2fcf5cec5b
commit 3e92901bd5
7 changed files with 41 additions and 467 deletions


@@ -11,9 +11,9 @@ configs = []
# Matrix multiplication
MNK = [
- # (512, 512 ,512),
- # (2048, 2048, 2048),
- # (8192, 8192, 8192),
+ (512, 512 ,512),
+ (2048, 2048, 2048),
+ (8192, 8192, 8192),
# (64, 64, 64000),
# (64, 64, 128000),


@@ -177,12 +177,12 @@ def _make_framework_op(src, outputs, tmp, options):
else:
assert False
- def _make_grid(args) :
- scalars = [x for x in args[:-1] if isinstance(x, triton.utils.scalar)]
+ def _make_grid(grid, args) :
+ scalars = [x for x in args if isinstance(x, triton.utils.scalar)]
def grid(opt):
for x in scalars:
x.set_assume_initialized()
- result = args[-1](opt)
+ result = grid(opt)
for x in scalars:
x.unset_assume_initialized()
return result
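
Note: _make_grid now takes the grid callback as an explicit first parameter instead of fishing it out of args[-1]. A hypothetical pair of call sites, just to illustrate the change (names are illustrative, not from this patch):

# before: the grid lambda rode along as the last positional argument
launcher = _make_grid(tensor_args + [lambda opt: [grid0, grid1]])
# after: the callback is explicit and the argument list holds only data
launcher = _make_grid(lambda opt: [grid0, grid1], tensor_args)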
@@ -206,24 +206,37 @@ class kernel:
self.cst[name] = value
def __call__(self, *args, **kwargs):
+ ########################
+ # keyword arguments
+ ########################
+ num_warps = kwargs['num_warps'] if 'num_warps' in kwargs else [2, 4, 8]
+ defines = kwargs['defines'] if 'defines' in kwargs else dict()
+ bench = kwargs['bench'] if 'bench' in kwargs else 0
+ if 'grid' not in kwargs:
+ raise RuntimeError('Must provide grid for kernel launch')
+ grid = kwargs['grid']
#########################
# cache
########################
# create a new framework op when defines are different
- key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
+ key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()])
if key not in self.fw_id.keys():
# code generation options
- defines = []
- for k, v in kwargs.items():
+ macros = []
+ for k, v in defines.items():
cvt = lambda x: _cvt_to_def_str(x)
if(isinstance(v, list)):
values = list(map(cvt, v))
else:
values = [cvt(v)]
- defines.append((k, values))
+ macros.append((k, values))
opt = libtriton.options_space()
- opt.defines = defines
- opt.num_warps = [4]
+ opt.defines = macros
+ opt.num_warps = [2, 4, 8]
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id
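
Note: the framework-op cache is now keyed on defines alone rather than on every keyword argument, so unrelated kwargs such as bench or grid no longer force a new op. A quick sketch of the resulting key for a hypothetical defines dict:

defines = {'TM': 128, 'TYPE': 'float'}
key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()])
# key == 'TM-128-TYPE-float'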
@@ -238,22 +251,21 @@ class kernel:
# initialize
########################
op_id = self.fw_id[key]
- libtriton.register_grid(op_id, args[-1])
- bench = kwargs['bench'] if 'bench' in kwargs else 0
+ libtriton.register_grid(op_id, grid)
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
#########################
# call framework function
#########################
if fw.has_tensorflow():
- empty = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
+ empty = [x for x in args if isinstance(x, triton.utils.tf_empty_proxy)]
if len(empty) != len(self.outputs):
raise ValueError('Number of empty arguments does not match number of outputs provided')
# operands
- operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]]
+ operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args]
# output data types
kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
- for i, x in enumerate(args[:-1]):
+ for i, x in enumerate(args):
if isinstance(x, triton.utils.tf_empty_proxy):
kwargs['T' + str(i)] = x.dtype
# launch
@@ -277,7 +289,7 @@ class kernel:
# call torch function
############################
elif fw.has_torch():
- args = [x if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
+ args = [x if isinstance(x, fw.torch.Tensor) else x for x in args]
ret = self.fw_op(op_id, bench, bench_id, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
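
Note: taken together, these hunks replace the old trailing-positional grid (and all the args[:-1] slicing it required) with explicit keywords. A minimal sketch of a launch under the new convention, with hypothetical kernel and argument names:

c = my_kernel(a, b, c, M, N,
              grid = lambda opt: [triton.cdiv(M, opt.d('TM'))],
              defines = {'TM': [64, 128]},  # list values span the autotuning space
              num_warps = [2, 4, 8],        # optional; this is the default
              bench = 0)                    # optional; 0 disables benchmarking
# omitting grid raises RuntimeError('Must provide grid for kernel launch')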


@@ -1,4 +1,2 @@
- from .dot import _dot, dot
from .einsum import _einsum, einsum
- from .batchnorm import _batchnorm, batchnorm
- from .conv import _conv, conv
+ from .batchnorm import _batchnorm, batchnorm
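
Note: with dot and conv deleted (their files are removed below), the ops package now exports only two entry points; assuming the package is imported as triton.ops:

from triton.ops import einsum, batchnorm  # dot and conv are no longer available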


@@ -101,8 +101,8 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
var = triton.empty([C], dtype=dtype)
# execute kernels
_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
- lambda opt: [1, C],
- TM = 128)
+ grid = lambda opt: [1, C],
+ defines = {'TM': 128})
# save
ctx.save_for_backward(x, gamma, beta, mean, var)
ctx.eps = eps
@@ -122,8 +122,8 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
_batchnorm.bwd_kernel(dx, dgamma, dbeta, dy,
x, gamma, mean, var,
H*W*B, eps,
- lambda opt: [1, C],
- TM = 128)
+ grid = lambda opt: [1, C],
+ defines = {'TM': 128})
return dx, dgamma, dbeta, None
batchnorm = _batchnorm.apply
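
Note: because kernel.__call__ accepts list values inside defines (the values = list(map(cvt, v)) branch above), a caller could also autotune the tile size here. A hypothetical variant of the same launch:

_batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps,
                      grid = lambda opt: [1, C],
                      defines = {'TM': [64, 128, 256]})  # each TM value becomes a candidate config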


@@ -1,309 +0,0 @@
import triton
import numpy as np
class _conv(triton.function):
src = """
void convnd(A_TYPE *A,
B_TYPE *B,
float *C,
int M, int N, int K,
int AH, int AW,
int BH, int BW,
int CH, int CW,
int NC,
int lda_n, int lda_c, int lda_d, int lda_h, int lda_w,
int ldb_c, int ldb_t, int ldb_r, int ldb_s, int ldb_k,
int ldc_n, int ldc_k, int ldc_m, int ldc_p, int ldc_q,
int pad_h, int pad_w,
int stride_h, int stride_w,
int upsample_h, int upsample_w,
int off_uh, int off_uw,
int off_uah, int off_uaw,
int off_uch, int off_ucw,
int* ADELTA, int* ADIFF){
// range of indices along the reduction axis
int rxa[TM] = get_program_id(0) * TM + 0 ... TM;
int ryb[TN] = get_program_id(1) * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// initialize accumulator
float c[TM, TN] = 0;
// pointers for A
int rabh[TM] = rxa / CW;
int raw[TM] = rxa % CW;
int rab[TM] = rabh / CH;
int rah[TM] = rabh % CH;
rah = rah * UPAW - off_uah;
raw = raw * UPAH - off_uaw;
int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int rak[TK] = *(ADELTA + rk);
A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :];
// pointers for B
int rbk[TK] = rk;
int rbn[TN] = ryb;
B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_c;
// pointers for A look-up table
int rklut[TK] = rk % LUT_SIZE;
int* padiff[TK] = ADIFF + rklut;
int* padelta[TK] = ADELTA + TK + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w;
int adiff[TK] = *padiff;
int adelta[TK] = *padelta;
// reduction loop
A_TYPE a[TM, TK] = *pa;
B_TYPE b[TK, TN] = *pb;
for(int k = K; k > 0; k = k - TK){
c += a @ b;
pa += adelta[newaxis, :];
pb += TK * ldb_c;
// increment A look-up table
padelta = padelta + adiff;
adelta = *padelta;
padiff = padiff + adiff;
adiff = *padiff;
// pre-fetches
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
}
// write back
int rxc[TM] = get_program_id(0) * TM + 0 ... TM;
int rc1[TN] = get_program_id(1) * TN + 0 ... TN;
int rcn[TM] = rxc / (CH*CW);
int rcpq[TM] = rxc % (CH*CW);
int rcp[TM] = rcpq / CW;
int rcq[TM] = rcpq % CW;
rcp = rcp * upsample_h + off_uch;
rcq = rcq * upsample_w + off_ucw;
int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
float* pc[TM, TN] = C + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = rc1 < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
*?(checkc)pc = c;
}
"""
kernel = triton.kernel(src, ['C'])
@staticmethod
def _unpack(idx, order, shape_b):
_123 = idx // shape_b[order[0]]
_0 = idx % shape_b[order[0]]
_23 = _123 // shape_b[order[1]]
_1 = _123 % shape_b[order[1]]
_3 = _23 // shape_b[order[2]]
_2 = _23 % shape_b[order[2]]
return _0, _1, _2, _3
@staticmethod
def _roundup(x, div):
return (x + div - 1) // div * div
@staticmethod
def _delta_a(upsample_d, upsample_h, upsample_w,
bc, bd, bh, bw,
ac, ad, ah, aw,
stride_a, shape_b,
TK):
# Parse the axes so that the reduction is done
# from the innermost dimension outward
order = sorted([bc, bd, bh, bw], reverse = True)
c, d, h, w = [order.index(x) for x in [bc, bd, bh, bw]]
# Size of the lookup table is the product of the 3 innermost dimensions
K = _conv._roundup(TK, shape_b[order[0]] * shape_b[order[1]] * shape_b[order[2]])
# Allocate temporary arrays
ud = np.arange(upsample_d, dtype=np.int32)[:, np.newaxis, np.newaxis, np.newaxis]
uh = np.arange(upsample_h, dtype=np.int32)[np.newaxis, :, np.newaxis, np.newaxis]
uw = np.arange(upsample_w, dtype=np.int32)[np.newaxis, np.newaxis, :, np.newaxis]
k = np.arange(K , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
# Find reduction indices at the current and next reduction indices
currentk = _conv._unpack(k , order, shape_b)
nextk = _conv._unpack(k + TK, order, shape_b)
# Compute memory stride
result = 0
result += (nextk[c] - currentk[c]) * stride_a[ac]
result += (nextk[d] - currentk[d]) * stride_a[ad]
result += (nextk[h] - currentk[h]) * stride_a[ah]
result += (nextk[w] - currentk[w]) * stride_a[aw]
# Initial k
ki = np.arange(TK , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :]
currentk = _conv._unpack(ki, order, shape_b)
resulti = 0
resulti += currentk[c] * stride_a[ac]
resulti += currentk[d] * stride_a[ad]
resulti += currentk[h] * stride_a[ah]
resulti += currentk[w] * stride_a[aw]
return np.concatenate((resulti, result), axis=-1)
@staticmethod
def _extract_strides(shape):
rank = len(shape)
ret = [1] * rank
for i in range(rank - 1, 0, -1):
ret[i-1] = ret[i] * shape[i]
return ret
@staticmethod
def _call(a, b,
pad_d, pad_h, pad_w,
stride_d, stride_h, stride_w,
upsample_d, upsample_h, upsample_w,
a_layout, b_layout, c_layout):
# input shapes
shape_a = list(triton.shape(a))
shape_b = list(triton.shape(b))
dim = len(shape_a) - 2
# indices
an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw']
bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs']
cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw']
# extract shapes
if dim == 2:
shape_a.insert(ad, 1)
if dim == 2:
shape_b.insert(bd, 1)
# output shape
shape_c = [0] * 5
shape_c[cn] = shape_a[an]
shape_c[ck] = shape_b[bk]
shape_c[cd] = (shape_a[ad]*upsample_d - shape_b[bd] + 1 + 2*pad_d + stride_d - 1) // stride_d
shape_c[ch] = (shape_a[ah]*upsample_h - shape_b[bh] + 1 + 2*pad_h + stride_h - 1) // stride_h
shape_c[cw] = (shape_a[aw]*upsample_w - shape_b[bw] + 1 + 2*pad_w + stride_w - 1) // stride_w
# strides
stride_a = _conv._extract_strides(shape_a)
stride_b = _conv._extract_strides(shape_b)
stride_c = _conv._extract_strides(shape_c)
# tiling parameters
TM = [32]
TN = [32]
TK = 8
# pointer deltas for a
delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w,
bc, bd, bh, bw,
ac, ad, ah, aw,
stride_a, shape_b,
TK)
delta_a = triton.fw.torch.from_numpy(delta_a).cuda()
# delta increments for a
inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32)
inc_a = ((inc_a + TK) % inc_a.size) - inc_a
inc_a = triton.fw.torch.from_numpy(inc_a).cuda()
# allocate output
if dim == 2:
shape_c.pop(cd)
c = triton.empty(shape_c, dtype=a.dtype)
if dim == 2:
shape_c.insert(cd, 1)
# execute kernel
trans_b = False
is_wgrad = False
is_blut = False
macros = {
'UPAR': 'stride_h' if is_wgrad else '1',
'UPAS': 'stride_w' if is_wgrad else '1',
'UPAH': '' if is_wgrad else 'stride_h',
'UPAW': '' if is_wgrad else 'stride_w',
'LUT_SIZE': delta_a.shape[-1],
'TM': TM, 'TN': TN, 'TK': TK,
'A_TYPE': 'float', 'B_TYPE': 'float'
}
MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw]
MATMUL_N = shape_c[ck]
MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw]
_conv.kernel(a, b, c,
# matrix multiplication shapes
MATMUL_M, MATMUL_N, MATMUL_K,
# shapes for a
shape_a[ah], shape_a[aw],
# shapes for b
shape_b[bh], shape_b[bw],
# shapes for c
shape_c[ch], shape_c[cw], shape_c[cn],
# strides for a
stride_a[an], stride_a[ac], stride_a[ad + 0], stride_a[ad + 1], stride_a[ad + 2],
# strides for b
stride_b[bc], stride_b[bd + 0], stride_b[bd + 1], stride_b[bd + 2], stride_b[bk],
# strides for c
stride_c[cn], stride_c[ck], stride_c[cd], stride_c[cd + 1], stride_c[cd + 2],
# padding
pad_h, pad_w,
# striding
stride_h, stride_w,
# upsampling
upsample_h, upsample_w,
0, 0, 0, 0, 0, 0,
# look-up table
delta_a, inc_a,
lambda opt: [triton.cdiv(MATMUL_M, opt.d('TM')), triton.cdiv(MATMUL_N, opt.d('TN'))],
**macros)
return c
@staticmethod
def forward(ctx, x, w,
pad_d = 0, pad_h = 0, pad_w = 0,
stride_d = 1, stride_h = 1, stride_w = 1,
upsample_d = 1, upsample_h = 1, upsample_w = 1,
layout_a = 'ncdhw', layout_b = 'ktrsc', layout_c = 'nkdhw'):
# save for backward
ctx.save_for_backward(x, w)
ctx.pad_d = pad_d
ctx.pad_h = pad_h
ctx.pad_w = pad_w
ctx.stride_d = stride_d
ctx.stride_h = stride_h
ctx.stride_w = stride_w
ctx.upsample_d = upsample_d
ctx.upsample_h = upsample_h
ctx.upsample_w = upsample_w
ctx.layout_a = layout_a
ctx.layout_b = layout_b
ctx.layout_c = layout_c
# return
return _conv._call(x, w,
pad_d, pad_h, pad_w,
stride_d, stride_h, stride_w,
upsample_d, upsample_h, upsample_w,
layout_a, layout_b, layout_c)
@staticmethod
def backward(ctx, dy):
x, w = ctx.saved_tensors
pad_d = ctx.pad_d
pad_h = ctx.pad_h
pad_w = ctx.pad_w
stride_d = ctx.stride_d
stride_h = ctx.stride_h
stride_w = ctx.stride_w
upsample_d = ctx.upsample_d
upsample_h = ctx.upsample_h
upsample_w = ctx.upsample_w
layout_a = ctx.layout_a
layout_b = ctx.layout_b
layout_c = ctx.layout_c
# TODO: Deal with this
dx_pad_d = 1
dx_pad_h = 1
dx_pad_w = 1
dx = _conv.call(dy, w,
dw_pad_d, dw_pad_h, dw_pad_w,
upsample_w, upsample_h, upsample_w,
stride_d, stride_h, stride_w,
'ncdhw', 'cktrs', 'nkdhw')
ret = [None] * 14
ret[0] = None
ret[1] = dw
return None,
conv = _conv.apply


@@ -1,126 +0,0 @@
import triton
class _dot(triton.function):
src = """
void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C,
float alpha,
int M, int N, int K,
int lda __multipleof(8),
int ldb __multipleof(8),
int ldc) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// pointers to operands
TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
// prefetches operands
bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
TYPE a[SHAPE_A] = checka ? *pa : 0;
TYPE b[SHAPE_B] = checkb ? *pb : 0;
// reduction loop
float c[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
c += USE_A @ USE_B;
bool checka[SHAPE_A] = k > TK;
bool checkb[SHAPE_B] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
a = *?(checka)pa;
b = *?(checkb)pb;
}
//c = c * alpha;
// epilogue
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rxm[:, newaxis] * ldc + rxn[newaxis, :];
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
*?(checkc)pc = (TYPE[TM, TN])c;
}
"""
kernel = triton.kernel(src, ['C'])
@staticmethod
def _call(a, b, transpose_a, transpose_b, bench):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
M, Ka = shape_a[0], shape_a[1]
Kb, N = shape_b[0], shape_b[1]
# transpose shapes
if transpose_a:
M, Ka = Ka, M
if transpose_b:
Kb, N = N, Kb
# contiguous dimensions
lda = M if transpose_a else Ka
ldb = Kb if transpose_b else N
ldc = N
# data-type
dtype = a.dtype
# allocate output
c = triton.empty([M, N], dtype = dtype)
# compute
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
# macros -- not necessary but makes kernel source-code simpler
macros = {# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',
'STRIDE_AM' : '1' if transpose_a else 'lda',
'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
'SHAPE_A' : 'TK, TM' if transpose_a else 'TM, TK',
# handle B transposition
'USE_B' : '^b' if transpose_b else 'b',
'STRIDE_BK' : '1' if transpose_b else 'ldb',
'STRIDE_BN' : 'ldb' if transpose_b else '1',
'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'}
_dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc,
grid, bench=bench,
AT = transpose_a, BT = transpose_b, TYPE = dtype,
TM = [64], TN = [128], TK = [8], **macros)
return c
@staticmethod
def forward(ctx, a, b, transpose_a = False, transpose_b = False, bench = 0):
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
ctx.bench = bench
return _dot._call(a, b, transpose_a, transpose_b, bench)
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
t_a, t_b = ctx.t_a, ctx.t_b
bench = ctx.bench
if not t_a and not t_b:
da = _dot._call(dy, b, False, True, bench)
db = _dot._call(a, dy, True, False, bench)
elif not t_a and t_b:
da = _dot._call(dy, b, False, False, bench)
db = _dot._call(dy, a, True, False, bench)
elif t_a and not t_b:
da = _dot._call(b, dy, False, True, bench)
db = _dot._call(a, dy, False, False, bench)
elif t_a and t_b:
da = _dot._call(b, dy, True, True, bench)
db = _dot._call(dy, a, True, True, bench)
else:
assert False
return da, db, None, None, None
dot = _dot.apply


@@ -527,10 +527,10 @@ __global__ void {name}(
delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
self.args += [delta_b]
self.args += arrays
- self.args += [lambda opt: [triton.cdiv(M, opt.d('TM')) *
- triton.cdiv(N, opt.d('TN')),
- triton.cdiv(B, opt.d('TB')),
- opt.d('TZ')]]
+ self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
+ triton.cdiv(N, opt.d('TN')),
+ triton.cdiv(B, opt.d('TB')),
+ opt.d('TZ')]
# position of dynamic arguments
self.pos_a = 0
self.pos_b = 1
@@ -546,9 +546,8 @@ __global__ void {name}(
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
- #TM, TN, TB = [128], [64], [1]
- #print(TM, TN, TB)
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
+ # information on compute
self.dtype = dtype
self.flops = 2 * B * M * N * K
self.sym_a = sym_a
@@ -564,7 +563,7 @@ __global__ void {name}(
self.args[self.pos_a] = a
self.args[self.pos_b] = b
self.args[self.pos_c] = c
- self.kernel(*self.args, bench=bench, **self.macros)
+ self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
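
Note: the same convention lands here: the grid callback moves out of self.args and into self.grid, so the cached positions self.pos_a, self.pos_b and self.pos_c keep indexing plain data arguments. Old versus new launch, both lines taken from this hunk:

# before: macros splatted as **kwargs, grid hidden inside self.args
# self.kernel(*self.args, bench=bench, **self.macros)
# after: launch configuration is explicit and self-describing
self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)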