[GENERAL] Merged v1.0alpha into master. Added features are:

- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without
padding
- Complete overhaul of the LLVM code generation in
codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
This commit is contained in:
Philippe Tillet
2021-01-11 19:20:34 -05:00
parent c0bc7ed8b0
commit 083bbd1e8d
75 changed files with 2688 additions and 4512 deletions

View File

@@ -171,7 +171,7 @@ class _conv(torch.autograd.Function):
_conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
delta, kernel = _conv.kernel[dtype]
# allocate output
c = triton.empty([Z, CO, P, Q], dtype=dtype)
c = torch.empty([Z, CO, P, Q], dtype=dtype)
# enqueue
grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')),
triton.cdiv(CO, opt.d('TN'))]

View File

@@ -3,6 +3,9 @@ import triton
class _dot(torch.autograd.Function):
src = """
#define STM 4
#define STN 4
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE * B __noalias __readonly __aligned(16),
TYPE * C __noalias __aligned(16),
@@ -14,20 +17,26 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
int ldb __multipleof(8),
int ldc __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int gridx = M / TM;
int gridy = N / TN;
int rid = ridx + ridy * gridx;
ridx = rid / gridy;
ridy = rid % gridy;
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
int pid = get_program_id(0);
int pidz = get_program_id(2);
int gridm = M / TM;
int gridn = N / TN;
int stgridm = (gridm + STM - 1) / STM;
int stgridn = (gridn + STN - 1) / STN;
int stid = pid / (STM * STN);
int laneid = pid % (STM * STN);
int stm = stid / stgridn;
int stn = stid % stgridn;
int lanem = laneid / STN;
int lanen = laneid % STN;
int pidm = stm*STM + lanem;
int pidn = stn*STN + lanen;
int rm[TM] = pidm * TM + 0 ... TM;
int rn[TN] = pidn * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
int rk[TK] = pidz * K + 0 ... TK;
// pointers to operands
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
@@ -44,11 +53,11 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
bool checka[TM, TK] = k > TK;
bool checkb[TK, TN] = k > TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
acc += a @ b;
a = *?(checka)pa;
b = *?(checkb)pb;
}
@@ -56,8 +65,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
TYPE c[TM, TN] = acc;
// epilogue
int rxm[TM] = ridx * TM + 0 ... TM;
int rxn[TN] = ridy * TN + 0 ... TN;
int rxm[TM] = pidm * TM + 0 ... TM;
int rxn[TN] = pidn * TN + 0 ... TN;
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
@@ -66,7 +75,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *plock = locks + pid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
@@ -100,7 +109,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
'TM' : [128],
'TN' : [128],
'TK' : [16],
'TK' : [32],
'TZ' : [1]
}
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
@@ -109,9 +118,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
M, K = a.shape
K, N = b.shape
c = torch.empty([M,N], dtype=dtype, device=a.device)
print(kernel.asm('sass', c.device))
print(kernel.asm('ptx', c.device))
# enqueue
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
triton.cdiv(N, opt.d('TN'))]
grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))]
time = kernel(a, b, c, 1., M, N, K,
a.stride(0), b.stride(0), c.stride(0), grid=grid)
return c
@@ -130,6 +140,4 @@ b = torch.rand((K, N)).cuda().half()
zc = torch.matmul(a,b)
zc_ = dot(a,b)
print(torch.allclose(zc, zc_))