[PYTHON][TESTS][DOC] Various improvement of the API and code quality:
* Simplified `triton.kernel` API to achieve lower latency: > .data_ptr() must now be passed as kernel argument. No more implicit conversion from torch.tensor > compilation options are now constant attributes, i.e., opt.d('VAR') becomes opt.VAR > torch.device must now be passed explicitly to triton.kernel (no longer inferred from torch.tensor arguments) * C++ tests moved to `python/tests/` * C++ tutorial created in `tutorials/` * Python tutorial created in python/tutorials/ * Version changed to 1.0alpha * No longer copying C++ headers into the Python package * added python/triton/ops/ package for pre-written Triton ops
This commit is contained in:
@@ -1,39 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _add(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N) {
|
||||
|
||||
int pid = get_program_id(0);
|
||||
|
||||
int offset[TILE] = pid * TILE + 0 ... TILE;
|
||||
float* pz[TILE] = z + offset;
|
||||
float* px[TILE] = x + offset;
|
||||
float* py[TILE] = y + offset;
|
||||
|
||||
bool check[TILE] = offset < N;
|
||||
|
||||
*pz = *px + *py;
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
z = torch.empty_like(x).cuda()
|
||||
N = x.numel()
|
||||
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
|
||||
_add.kernel(z,x,y, N, grid=grid)
|
||||
return z
|
||||
|
||||
add = _add.apply
|
||||
|
||||
# test
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(900).cuda()
|
||||
y = torch.rand(900).cuda()
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
print(torch.allclose(za,zb))
|
@@ -1,202 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _conv(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
|
||||
TYPE *B __noalias __readonly __aligned(16),
|
||||
TYPE *C __noalias __aligned(16),
|
||||
float alpha,
|
||||
// equivalent matmul
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int K __retune,
|
||||
// convolution properties
|
||||
int pad_h, int pad_w, int stride_h, int stride_w,
|
||||
// pointer increment
|
||||
int *ADELTA,
|
||||
// memory strides
|
||||
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
|
||||
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
|
||||
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int ridz = get_program_id(2);
|
||||
int gridx = M / TM;
|
||||
int gridy = N / TN;
|
||||
int rid = ridx + ridy * gridx;
|
||||
ridx = rid / gridy;
|
||||
ridy = rid % gridy;
|
||||
int rm[TM] = ridx * TM + 0 ... TM;
|
||||
int rn[TN] = ridy * TN + 0 ... TN;
|
||||
// reduction splitting
|
||||
K = K / TZ;
|
||||
int rk[TK] = ridz * K + 0 ... TK;
|
||||
|
||||
// unpack aggregate rows
|
||||
// m = (z, p, q)
|
||||
int rq[TM] = rm % QQ;
|
||||
int rzp[TM] = rm / QQ;
|
||||
int rp[TM] = rzp % PP;
|
||||
int rz[TM] = rzp / PP;
|
||||
// unpack aggregate reduction
|
||||
// k = (ci, r, s)
|
||||
int rs [TK] = rk % SS;
|
||||
int rcir[TK] = rk / SS;
|
||||
int rr [TK] = rcir % RR;
|
||||
int rci [TK] = rcir / RR;
|
||||
|
||||
// padding / striding
|
||||
int rh_0[TM] = rp * stride_h - pad_h;
|
||||
int rw_0[TM] = rq * stride_w - pad_w;
|
||||
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
|
||||
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
|
||||
|
||||
// pointers to lhs
|
||||
int offa[TM, TK] = rz [:, newaxis] * lda_z +
|
||||
rci[newaxis, :] * lda_ci +
|
||||
rh * lda_h +
|
||||
rw * 1;
|
||||
TYPE* pa[TM, TK] = A + offa;
|
||||
int* padelta[TK] = ADELTA + rk;
|
||||
// pointers to rhs
|
||||
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
|
||||
rr [:, newaxis] * ldb_r +
|
||||
rs [:, newaxis] * ldb_s +
|
||||
rn [newaxis, :] * 1;
|
||||
TYPE* pb[TK, TN] = B + offb;
|
||||
|
||||
// prefetches operands
|
||||
bool checkam[TM, TK] = rm[:, newaxis] < M;
|
||||
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
int total = 0;
|
||||
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for(int k = K; k > 0; k -= TK){
|
||||
acc += a @ b;
|
||||
// increment A
|
||||
int adelta[TK] = *padelta;
|
||||
padelta += TK;
|
||||
pa += adelta[newaxis, :];
|
||||
// bounds-checking A
|
||||
rk += TK;
|
||||
rs = rk % SS;
|
||||
rcir = rk / SS;
|
||||
rr = rcir % RR;
|
||||
rh = rh_0[:, newaxis] + rr[newaxis, :];
|
||||
rw = rw_0[:, newaxis] + rs[newaxis, :];
|
||||
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||
// increment B
|
||||
pb += TK * ldb_s;
|
||||
// bounds-checking B
|
||||
bool checkb[TK, TN] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = *?(checkb)pb;
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
rm = ridx * TM + 0 ... TM;
|
||||
rn = ridy * TN + 0 ... TN;
|
||||
rq = rm % QQ;
|
||||
rzp = rm / QQ;
|
||||
rp = rzp % PP;
|
||||
rz = rzp / PP;
|
||||
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
|
||||
rn [newaxis, :] * ldc_co+
|
||||
rp [:, newaxis] * ldc_p +
|
||||
rq [:, newaxis] * 1;
|
||||
TYPE* pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
|
||||
|
||||
#if (TZ==1)
|
||||
*?(checkc) pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + rid;
|
||||
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
*?(checkc) pc = c;
|
||||
else
|
||||
*?(checkc) pc = c + *?(checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % TZ);
|
||||
atomic_xchg(plock, 0);
|
||||
#endif
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = dict()
|
||||
|
||||
@staticmethod
|
||||
def unpack(IDX, CI, R, S):
|
||||
s = IDX % S
|
||||
cr = IDX // S
|
||||
r = cr % R
|
||||
ci = cr // R
|
||||
return ci, r, s
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, pad, stride, time):
|
||||
# create kernel if necessary
|
||||
dtype = a.dtype
|
||||
# shapes
|
||||
Z, CI, H, W = a.shape
|
||||
_, R, S, CO = b.shape
|
||||
P = (H + 2*pad[0] - R)//stride[0] + 1
|
||||
Q = (W + 2*pad[1] - S)//stride[1] + 1
|
||||
# compile kernel
|
||||
if dtype not in _conv.kernel:
|
||||
TK = 8
|
||||
defines = {
|
||||
'TYPE' : dtype,
|
||||
'TM' : [16, 32, 64, 128],
|
||||
'TN' : [16, 32, 64, 128],
|
||||
'TK' : [TK],
|
||||
'TZ' : [1],
|
||||
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
|
||||
}
|
||||
idx = torch.arange(CI*R*S)
|
||||
ci, r, s = _conv.unpack(idx, CI, R, S)
|
||||
nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
|
||||
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
|
||||
delta = delta.type(torch.int32).cuda()
|
||||
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
|
||||
delta, kernel = _conv.kernel[dtype]
|
||||
# allocate output
|
||||
c = torch.empty([Z, CO, P, Q], dtype=dtype)
|
||||
# enqueue
|
||||
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
|
||||
triton.cdiv(CO, opt.d('TN'))]
|
||||
time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S,
|
||||
pad[0], pad[1], stride[0], stride[1],
|
||||
delta,
|
||||
a.stride(0), a.stride(1), a.stride(2), a.stride(3),
|
||||
b.stride(0), b.stride(1), b.stride(2), b.stride(3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
grid=grid, bench=100)
|
||||
return c
|
||||
|
||||
|
||||
|
||||
conv = _conv.apply
|
||||
torch.manual_seed(0)
|
||||
Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
|
||||
pad = (1, 1)
|
||||
stride = (1, 1)
|
||||
a = torch.rand((Z, CI, H, W)).cuda()
|
||||
b = torch.rand((CI, R, S, CO)).cuda()
|
||||
time = [None]
|
||||
cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1])
|
||||
c = conv(a, b, pad, stride, time)
|
||||
print((cc - c).abs().max() / max(cc.max(), c.max()))
|
||||
print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12)
|
||||
#zc = torch.matmul(a,b)
|
||||
#zc_ = dot(a,b)
|
@@ -1,70 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _copy(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void copy(TYPE * X, TYPE * Y,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int ldx __multipleof(8)) {
|
||||
// extract program ID
|
||||
int pidm = get_program_id(0); //(1)
|
||||
int pidn = get_program_id(1); //(2)
|
||||
|
||||
// create 1D range along the two matrix's axes
|
||||
int rm[TM] = pidm * TM + 0 ... TM; //(3)
|
||||
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
||||
|
||||
// create 2D array of pointers
|
||||
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; //(5)
|
||||
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldx; //(6)
|
||||
|
||||
*py = *px;
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = None ### initialize later when we know the sizes
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
|
||||
M, N = x.shape
|
||||
|
||||
ldx = N;
|
||||
|
||||
dtype = x.dtype
|
||||
|
||||
y = torch.empty((M,N)).cuda()
|
||||
|
||||
defines= {
|
||||
'TYPE' : dtype,
|
||||
'TM' : [32,64,128],
|
||||
'TN' : [32,64,128],
|
||||
}
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||
|
||||
if _copy.kernel is None:
|
||||
_copy.kernel = triton.kernel(_copy.src, defines=defines, num_warps=[4])
|
||||
|
||||
_copy.kernel(x, y, M, N, ldx, grid=grid)
|
||||
|
||||
return y
|
||||
|
||||
copy = _copy.apply
|
||||
|
||||
# test
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(8,4).cuda()
|
||||
|
||||
print(x)
|
||||
|
||||
ya = x
|
||||
yb = copy(x)
|
||||
|
||||
print()
|
||||
print(ya)
|
||||
print()
|
||||
print(yb)
|
||||
print(torch.allclose(ya, yb))
|
||||
|
||||
print(ya == yb)
|
@@ -1,143 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
src = """
|
||||
#define STM 4
|
||||
#define STN 4
|
||||
|
||||
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __aligned(16),
|
||||
float alpha,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int K __retune __multipleof(16),
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc __multipleof(8)) {
|
||||
// prologue
|
||||
int pid = get_program_id(0);
|
||||
int pidz = get_program_id(2);
|
||||
int gridm = M / TM;
|
||||
int gridn = N / TN;
|
||||
int stgridm = (gridm + STM - 1) / STM;
|
||||
int stgridn = (gridn + STN - 1) / STN;
|
||||
int stid = pid / (STM * STN);
|
||||
int laneid = pid % (STM * STN);
|
||||
int stm = stid / stgridn;
|
||||
int stn = stid % stgridn;
|
||||
int lanem = laneid / STN;
|
||||
int lanen = laneid % STN;
|
||||
int pidm = stm*STM + lanem;
|
||||
int pidn = stn*STN + lanen;
|
||||
int rm[TM] = pidm * TM + 0 ... TM;
|
||||
int rn[TN] = pidn * TN + 0 ... TN;
|
||||
|
||||
// reduction splitting
|
||||
K = K / TZ;
|
||||
int rk[TK] = pidz * K + 0 ... TK;
|
||||
|
||||
// pointers to operands
|
||||
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
|
||||
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
|
||||
TYPE* pa[TM, TK] = A + offa;
|
||||
TYPE* pb[TK, TN] = B + offb;
|
||||
|
||||
// prefetches operands
|
||||
bool checka[TM, TK] = rk[newaxis, :] < K;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for(int k = K; k > 0; k -= TK){
|
||||
bool checka[TM, TK] = k > TK;
|
||||
bool checkb[TK, TN] = k > TK;
|
||||
pa += TK * STRIDE_AK;
|
||||
pb += TK * STRIDE_BK;
|
||||
acc += a @ b;
|
||||
a = *?(checka)pa;
|
||||
b = *?(checkb)pb;
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
int rxm[TM] = pidm * TM + 0 ... TM;
|
||||
int rxn[TN] = pidn * TN + 0 ... TN;
|
||||
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
|
||||
TYPE* pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
|
||||
|
||||
#if (TZ==1)
|
||||
*?(checkc) pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + pid;
|
||||
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
*?(checkc) pc = c;
|
||||
else
|
||||
*?(checkc) pc = c + *?(checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % TZ);
|
||||
atomic_xchg(plock, 0);
|
||||
#endif
|
||||
}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
c = _dot._call(a,b)
|
||||
return c
|
||||
|
||||
|
||||
kernel = dict()
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b):
|
||||
# create kernel if necessary
|
||||
dtype = a.dtype
|
||||
if dtype not in _dot.kernel:
|
||||
defines = {
|
||||
'TYPE' : dtype,
|
||||
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
|
||||
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
|
||||
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
|
||||
'TM' : [128],
|
||||
'TN' : [128],
|
||||
'TK' : [32],
|
||||
'TZ' : [1]
|
||||
}
|
||||
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
|
||||
kernel = _dot.kernel[dtype]
|
||||
# allocate output
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
c = torch.empty([M,N], dtype=dtype, device=a.device)
|
||||
print(kernel.asm('sass', c.device))
|
||||
print(kernel.asm('ptx', c.device))
|
||||
# enqueue
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
|
||||
time = kernel(a, b, c, 1., M, N, K,
|
||||
a.stride(0), b.stride(0), c.stride(0), grid=grid)
|
||||
return c
|
||||
|
||||
|
||||
dot = _dot.apply
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
M, N, K = 4096, 4096, 4096
|
||||
a = torch.rand((M, K)).cuda().half()
|
||||
b = torch.rand((K, N)).cuda().half()
|
||||
|
||||
#a[:] = 1
|
||||
#b[:] = 1
|
||||
|
||||
zc = torch.matmul(a,b)
|
||||
zc_ = dot(a,b)
|
||||
print(torch.allclose(zc, zc_))
|
@@ -1,76 +0,0 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _transpose(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void transpose(TYPE * X, TYPE * Y,
|
||||
int M __retune,
|
||||
int N __retune,
|
||||
int ldx __multipleof(8), int ldy __multipleof(8)) {
|
||||
// extract program ID
|
||||
int pidm = get_program_id(0); //(1)
|
||||
int pidn = get_program_id(1); //(2)
|
||||
|
||||
// create 1D range along the two matrix's axes
|
||||
int rm[TM] = pidm * TM + 0 ... TM; //(3)
|
||||
int rn[TN] = pidn * TN + 0 ... TN; //(4)
|
||||
|
||||
// create 2D array of pointers
|
||||
TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :]; //(5)
|
||||
TYPE* py[TN, TM] = Y + rm[newaxis, :] + rn[:, newaxis] * ldy; //(6)
|
||||
|
||||
// create bounds-checking mask
|
||||
bool checkx[TM, TN] = (rm[:, newaxis] < M) && (rn[newaxis, :] < N); //(7a)
|
||||
bool checky[TN, TM] = (rn[:, newaxis] < N) && (rm[newaxis, :] < M); //(7b)
|
||||
|
||||
// conditional write-back using the conditional dereferencing operatior '*?()'
|
||||
*?(checky)py = ^(*?(checkx)px); //(7)
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = None ### initialize later when we know the sizes
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
|
||||
M, N = x.shape
|
||||
|
||||
ldx = N
|
||||
ldy = M
|
||||
|
||||
dtype = x.dtype
|
||||
|
||||
y = torch.empty((N,M)).cuda()
|
||||
|
||||
defines= {
|
||||
'TYPE' : dtype,
|
||||
'TM' : [32,64,128],
|
||||
'TN' : [32,64,128],
|
||||
}
|
||||
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||
|
||||
if _transpose.kernel is None:
|
||||
_transpose.kernel = triton.kernel(_transpose.src, defines=defines, num_warps=[4])
|
||||
|
||||
_transpose.kernel(x, y, M, N, ldx, ldy, grid=grid)
|
||||
|
||||
return y
|
||||
|
||||
transpose = _transpose.apply
|
||||
|
||||
# test
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(1024,128).cuda()
|
||||
|
||||
print(x)
|
||||
|
||||
ya = torch.t(x)
|
||||
yb = transpose(x)
|
||||
print()
|
||||
print(ya)
|
||||
print()
|
||||
print(yb)
|
||||
print(torch.allclose(ya, yb))
|
||||
|
||||
print(ya == yb)
|
Reference in New Issue
Block a user