[EXAMPLES] Improved mat_mul example
This commit is contained in:
committed by
Philippe Tillet
parent
68c18238a9
commit
89e456107b
@@ -3,34 +3,79 @@ import triton
|
|||||||
|
|
||||||
class _dot(torch.autograd.Function):
|
class _dot(torch.autograd.Function):
|
||||||
src = """
|
src = """
|
||||||
__global__ void dot(TYPE *A, TYPE *B, TYPE *C, int M, int N, int K,
|
__global__ void dot(TYPE *A __noalias __readonly __aligned(16),
|
||||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
TYPE *B __noalias __readonly __aligned(16),
|
||||||
int pm = get_program_id(0);
|
TYPE *C __noalias __aligned(16),
|
||||||
int pn = get_program_id(1);
|
float alpha,
|
||||||
|
int M, int N, int K,
|
||||||
|
int lda __multipleof(8),
|
||||||
|
int ldb __multipleof(8),
|
||||||
|
int ldc __multipleof(8)) {
|
||||||
|
// prologue
|
||||||
|
int ridx = get_program_id(0);
|
||||||
|
int ridy = get_program_id(1);
|
||||||
|
int ridz = get_program_id(2);
|
||||||
|
int gridx = M / TM;
|
||||||
|
int gridy = N / TN;
|
||||||
|
int rid = ridx + ridy * gridx;
|
||||||
|
ridx = rid / gridy;
|
||||||
|
ridy = rid % gridy;
|
||||||
|
int rm[TM] = ridx * TM + 0 ... TM;
|
||||||
|
int rn[TN] = ridy * TN + 0 ... TN;
|
||||||
|
|
||||||
// ranges
|
// reduction splitting
|
||||||
int rm[TM] = pm * TM + 0 ... TM;
|
K = K / TZ;
|
||||||
int rn[TN] = pn * TN + 0 ... TN;
|
int rk[TK] = ridz * K + 0 ... TK;
|
||||||
int rk[TK] = 0 ... TK;
|
|
||||||
|
|
||||||
// accumulator
|
// pointers to operands
|
||||||
float c[TM, TN] = 0;
|
int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM;
|
||||||
|
int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN;
|
||||||
|
TYPE* pa[TM, TK] = A + offa;
|
||||||
|
TYPE* pb[TK, TN] = B + offb;
|
||||||
|
|
||||||
// pointers
|
// prefetches operands
|
||||||
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
|
bool checka[TM, TK] = rk[newaxis, :] < K;
|
||||||
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
|
bool checkb[TK, TN] = rk[:, newaxis] < K;
|
||||||
|
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||||
|
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||||
|
|
||||||
for(int k=K; k>0; k-=TK) {
|
// reduction loop
|
||||||
TYPE a[TM, TK] = *pa;
|
float acc[TM, TN] = 0;
|
||||||
TYPE b[TK, TN] = *pb;
|
for(int k = K; k > 0; k -= TK){
|
||||||
c += a @ b;
|
acc += a @ b;
|
||||||
pa = pa + TK * 1;
|
bool checka[TM, TK] = k > TK;
|
||||||
pb = pb + TK * ldb;
|
bool checkb[TK, TN] = k > TK;
|
||||||
}
|
pa += TK * STRIDE_AK;
|
||||||
|
pb += TK * STRIDE_BK;
|
||||||
|
a = *?(checka)pa;
|
||||||
|
b = *?(checkb)pb;
|
||||||
|
}
|
||||||
|
acc = acc * alpha;
|
||||||
|
TYPE c[TM, TN] = acc;
|
||||||
|
|
||||||
TYPE* pc[TM,TN] = C + rn[newaxis, :] + rm[:,newaxis] * ldc;
|
// epilogue
|
||||||
*pc = c;
|
int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
|
||||||
}
|
int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
|
||||||
|
int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
|
||||||
|
TYPE* pc[TM, TN] = C + offc;
|
||||||
|
bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
|
||||||
|
|
||||||
|
#if (TZ==1)
|
||||||
|
*?(checkc) pc = c;
|
||||||
|
#else
|
||||||
|
// accumulate partial result using spin-locks
|
||||||
|
int *plock = locks + rid;
|
||||||
|
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||||
|
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||||
|
int count = *pcount;
|
||||||
|
if(count == 0)
|
||||||
|
*?(checkc) pc = c;
|
||||||
|
else
|
||||||
|
*?(checkc) pc = c + *?(checkc)pc;
|
||||||
|
atomic_xchg(pcount, (count + 1) % TZ);
|
||||||
|
atomic_xchg(plock, 0);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -43,30 +88,31 @@ class _dot(torch.autograd.Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _call(a, b):
|
def _call(a, b):
|
||||||
# shapes
|
|
||||||
M, K = a.shape
|
|
||||||
K, N = b.shape
|
|
||||||
# leading dimension
|
|
||||||
lda = K
|
|
||||||
ldb = N
|
|
||||||
ldc = N
|
|
||||||
dtype = a.dtype
|
|
||||||
# create kernel if necessary
|
# create kernel if necessary
|
||||||
|
dtype = a.dtype
|
||||||
if dtype not in _dot.kernel:
|
if dtype not in _dot.kernel:
|
||||||
defines = {
|
defines = {
|
||||||
'TYPE' : dtype,
|
'TYPE' : dtype,
|
||||||
|
'STRIDE_AM': '1', 'STRIDE_AK': 'lda',
|
||||||
|
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
|
||||||
'TM' : [64, 128],
|
'TM' : [64, 128],
|
||||||
'TN' : [64, 128],
|
'TN' : [64, 128],
|
||||||
'TK' : [8, 16],
|
'TK' : [8, 16],
|
||||||
|
'TZ' : [1]
|
||||||
}
|
}
|
||||||
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[2, 4], defines=defines)
|
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
|
||||||
kernel = _dot.kernel[dtype]
|
kernel = _dot.kernel[dtype]
|
||||||
# allocate output
|
# allocate output
|
||||||
|
M, K = a.shape
|
||||||
|
K, N = b.shape
|
||||||
c = triton.empty([M,N], dtype=dtype)
|
c = triton.empty([M,N], dtype=dtype)
|
||||||
# enqueue
|
# enqueue
|
||||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
|
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
|
||||||
triton.cdiv(N, opt.d('TN'))]
|
triton.cdiv(N, opt.d('TN'))]
|
||||||
kernel(a, b, c, M, N, K, lda, ldb, ldc, grid=grid)
|
time = kernel(a, b, c, 1., M, N, K,
|
||||||
|
a.stride(0), b.stride(0), c.stride(0),
|
||||||
|
grid=grid, bench=100)
|
||||||
|
print(time)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
@@ -74,12 +120,12 @@ dot = _dot.apply
|
|||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
M, N, K = 128, 512, 256
|
M, N, K = 2048, 2048, 2048
|
||||||
a = torch.rand((M, K)).cuda()
|
a = torch.rand((M, K)).cuda()
|
||||||
b = torch.rand((K, N)).cuda()
|
b = torch.rand((K, N)).cuda()
|
||||||
|
|
||||||
|
|
||||||
zc = torch.matmul(a,b)
|
#zc = torch.matmul(a,b)
|
||||||
zc_ = dot(a,b)
|
zc_ = dot(a,b)
|
||||||
|
|
||||||
print(torch.allclose(zc, zc_))
|
#print(torch.allclose(zc, zc_))
|
Reference in New Issue
Block a user