[EXAMPLES] Added conv2d example
This commit is contained in:
committed by
Philippe Tillet
parent
46297a949f
commit
68c18238a9
200
python/examples/tutorials/conv2d.py
Normal file
200
python/examples/tutorials/conv2d.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _conv(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
|
||||
TYPE *B __noalias __readonly __aligned(16),
|
||||
TYPE *C __noalias __aligned(16),
|
||||
float alpha,
|
||||
// equivalent matmul
|
||||
int M, int N, int K,
|
||||
// convolution properties
|
||||
int pad_h, int pad_w, int stride_h, int stride_w,
|
||||
// pointer increment
|
||||
int *ADELTA,
|
||||
// memory strides
|
||||
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
|
||||
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
|
||||
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
|
||||
// prologue
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int ridz = get_program_id(2);
|
||||
/*
|
||||
int gridx = M / TM;
|
||||
int gridy = N / TN;
|
||||
int rid = ridx + ridy * gridx;
|
||||
ridx = rid / gridy;
|
||||
ridy = rid % gridy;
|
||||
*/
|
||||
int rm[TM] = ridx * TM + 0 ... TM;
|
||||
int rn[TN] = ridy * TN + 0 ... TN;
|
||||
// reduction splitting
|
||||
K = K / TZ;
|
||||
int rk[TK] = ridz * K + 0 ... TK;
|
||||
|
||||
// unpack aggregate rows
|
||||
// m = (z, p, q)
|
||||
int rq[TM] = rm % QQ;
|
||||
int rzp[TM] = rm / QQ;
|
||||
int rp[TM] = rzp % PP;
|
||||
int rz[TM] = rzp / PP;
|
||||
// unpack aggregate reduction
|
||||
// k = (ci, r, s)
|
||||
int rs [TK] = rk % SS;
|
||||
int rcir[TK] = rk / SS;
|
||||
int rr [TK] = rcir % RR;
|
||||
int rci [TK] = rcir / RR;
|
||||
|
||||
// padding / striding
|
||||
int rh_0[TM] = rp * stride_h - pad_h;
|
||||
int rw_0[TM] = rq * stride_w - pad_w;
|
||||
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
|
||||
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
|
||||
|
||||
// pointers to lhs
|
||||
int offa[TM, TK] = rz [:, newaxis] * lda_z +
|
||||
rci[newaxis, :] * lda_ci +
|
||||
rh * lda_h +
|
||||
rw * 1;
|
||||
TYPE* pa[TM, TK] = A + offa;
|
||||
int* padelta[TK] = ADELTA + rk;
|
||||
// pointers to rhs
|
||||
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
|
||||
rr [:, newaxis] * ldb_r +
|
||||
rs [:, newaxis] * ldb_s +
|
||||
rn [newaxis, :] * 1;
|
||||
TYPE* pb[TK, TN] = B + offb;
|
||||
|
||||
// prefetches operands
|
||||
bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||
|
||||
// reduction loop
|
||||
float acc[TM, TN] = 0;
|
||||
for(int k = K; k > 0; k -= TK){
|
||||
acc += a @ b;
|
||||
// increment A
|
||||
int adelta[TK] = *padelta;
|
||||
padelta += TK;
|
||||
pa += adelta[newaxis, :];
|
||||
// increment B
|
||||
pb += TK * ldb_s;
|
||||
// bounds-checking A
|
||||
rk += TK;
|
||||
rs = rk % SS;
|
||||
rcir = rk / SS;
|
||||
rr = rcir % RR;
|
||||
rh = rh_0[:, newaxis] + rr[newaxis, :];
|
||||
rw = rw_0[:, newaxis] + rs[newaxis, :];
|
||||
bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||
// bounds-checking B
|
||||
bool checkb[TK, TN] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = *?(checkb)pb;
|
||||
}
|
||||
acc = acc * alpha;
|
||||
TYPE c[TM, TN] = acc;
|
||||
|
||||
// epilogue
|
||||
rm = ridx * TM + 0 ... TM;
|
||||
rn = ridy * TN + 0 ... TN;
|
||||
rq = rm % QQ;
|
||||
rzp = rm / QQ;
|
||||
rp = rzp % PP;
|
||||
rz = rzp / PP;
|
||||
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
|
||||
rn [newaxis, :] * ldc_co+
|
||||
rp [:, newaxis] * ldc_p +
|
||||
rq [:, newaxis] * 1;
|
||||
TYPE* pc[TM, TN] = C + offc;
|
||||
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
|
||||
|
||||
#if (TZ==1)
|
||||
*?(checkc) pc = c;
|
||||
#else
|
||||
// accumulate partial result using spin-locks
|
||||
int *plock = locks + rid;
|
||||
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||
int count = *pcount;
|
||||
if(count == 0)
|
||||
*?(checkc) pc = c;
|
||||
else
|
||||
*?(checkc) pc = c + *?(checkc)pc;
|
||||
atomic_xchg(pcount, (count + 1) % TZ);
|
||||
atomic_xchg(plock, 0);
|
||||
#endif
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = dict()
|
||||
|
||||
@staticmethod
|
||||
def unpack(IDX, CI, R, S):
|
||||
s = IDX % S
|
||||
cr = IDX // S
|
||||
r = cr % R
|
||||
ci = cr // R
|
||||
return ci, r, s
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, pad, stride, time):
|
||||
# create kernel if necessary
|
||||
dtype = a.dtype
|
||||
# shapes
|
||||
Z, CI, H, W = a.shape
|
||||
_, R, S, CO = b.shape
|
||||
P = (H + 2*pad[0] - R)//stride[0] + 1
|
||||
Q = (W + 2*pad[1] - S)//stride[1] + 1
|
||||
# compile kernel
|
||||
if dtype not in _conv.kernel:
|
||||
defines = {
|
||||
'TYPE' : dtype,
|
||||
'TM' : [64, 128],
|
||||
'TN' : [64, 128],
|
||||
'TK' : [8],
|
||||
'TZ' : [1],
|
||||
'LUTSIZE' : 4*CI*R*S,
|
||||
'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
|
||||
}
|
||||
idx = torch.arange(CI*R*S)
|
||||
ci, r, s = _conv.unpack(idx, CI, R, S)
|
||||
nci, nr, ns = _conv.unpack(idx + 8, CI, R, S)
|
||||
delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
|
||||
delta = delta.type(torch.int32).cuda()
|
||||
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
|
||||
delta, kernel = _conv.kernel[dtype]
|
||||
# allocate output
|
||||
c = triton.empty([Z, CO, P, Q], dtype=dtype)
|
||||
# enqueue
|
||||
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
|
||||
triton.cdiv(CO, opt.d('TN'))]
|
||||
time[0] = kernel(a, b, c, 1., Z*P*Q, CO, CI*R*S,
|
||||
pad[0], pad[1], stride[0], stride[1],
|
||||
delta,
|
||||
a.stride(0), a.stride(1), a.stride(2), a.stride(3),
|
||||
b.stride(0), b.stride(1), b.stride(2), b.stride(3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
grid=grid, bench=100)
|
||||
return c
|
||||
|
||||
|
||||
|
||||
conv = _conv.apply
|
||||
torch.manual_seed(0)
|
||||
Z, H, W, CI, CO, R, S = 1, 32, 64, 256, 2048, 3, 3
|
||||
pad = (1, 1)
|
||||
stride = (1, 1)
|
||||
a = torch.rand((Z, CI, H, W)).cuda()
|
||||
b = torch.rand((CI, R, S, CO)).cuda()
|
||||
time = [None]
|
||||
cc = torch.nn.functional.conv2d(a, b.permute(3,0,1,2), None, stride, pad, [1, 1])
|
||||
c = conv(a, b, pad, stride, time)
|
||||
print((cc - c).abs().max() / max(cc.max(), c.max()))
|
||||
print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12)
|
||||
#zc = torch.matmul(a,b)
|
||||
#zc_ = dot(a,b)
|
Reference in New Issue
Block a user