[EXAMPLES][TUTORIAL] Changed to new triton.kernel API
This commit is contained in:
committed by
Philippe Tillet
parent
c33d6d15f5
commit
4ccd78f1a6
@@ -16,16 +16,14 @@ class _dot(torch.autograd.Function):
|
|||||||
// accumulator
|
// accumulator
|
||||||
float c[TM, TN] = 0;
|
float c[TM, TN] = 0;
|
||||||
|
|
||||||
//pointers
|
// pointers
|
||||||
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
|
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
|
||||||
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
|
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
|
||||||
|
|
||||||
for(int k=K; k>0; k-=TK) {
|
for(int k=K; k>0; k-=TK) {
|
||||||
TYPE a[TM, TK] = *pa;
|
TYPE a[TM, TK] = *pa;
|
||||||
TYPE b[TK, TN] = *pb;
|
TYPE b[TK, TN] = *pb;
|
||||||
|
|
||||||
c += a @ b;
|
c += a @ b;
|
||||||
|
|
||||||
pa = pa + TK * 1;
|
pa = pa + TK * 1;
|
||||||
pb = pb + TK * ldb;
|
pb = pb + TK * ldb;
|
||||||
}
|
}
|
||||||
@@ -40,32 +38,35 @@ class _dot(torch.autograd.Function):
|
|||||||
c = _dot._call(a,b)
|
c = _dot._call(a,b)
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
kernel = dict()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _call(a, b):
|
def _call(a, b):
|
||||||
|
# shapes
|
||||||
M, K = a.shape
|
M, K = a.shape
|
||||||
K, N = b.shape
|
K, N = b.shape
|
||||||
|
# leading dimension
|
||||||
lda = K
|
lda = K
|
||||||
ldb = N
|
ldb = N
|
||||||
ldc = N
|
ldc = N
|
||||||
|
|
||||||
dtype = a.dtype
|
dtype = a.dtype
|
||||||
|
# create kernel if necessary
|
||||||
|
if dtype not in _dot.kernel:
|
||||||
|
defines = {
|
||||||
|
'TYPE' : dtype,
|
||||||
|
'TM' : [64, 128],
|
||||||
|
'TN' : [64, 128],
|
||||||
|
'TK' : [8, 16],
|
||||||
|
}
|
||||||
|
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[2, 4], defines=defines)
|
||||||
|
kernel = _dot.kernel[dtype]
|
||||||
|
# allocate output
|
||||||
c = triton.empty([M,N], dtype=dtype)
|
c = triton.empty([M,N], dtype=dtype)
|
||||||
|
# enqueue
|
||||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
|
||||||
|
triton.cdiv(N, opt.d('TN'))]
|
||||||
defines= {
|
kernel(a, b, c, M, N, K, lda, ldb, ldc, grid=grid)
|
||||||
'TYPE' : dtype,
|
|
||||||
'TM' : [32,64,128],
|
|
||||||
'TN' : [32,64,128],
|
|
||||||
'TK' : [8],
|
|
||||||
}
|
|
||||||
|
|
||||||
_dot.kernel = triton.kernel(_dot.src, defines=defines)
|
|
||||||
_dot.kernel(a, b, c, M, N, K, lda, ldb, ldc,
|
|
||||||
grid=grid, num_warps=4, defines=defines)
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
@@ -81,4 +82,4 @@ b = torch.rand((K, N)).cuda()
|
|||||||
zc = torch.matmul(a,b)
|
zc = torch.matmul(a,b)
|
||||||
zc_ = dot(a,b)
|
zc_ = dot(a,b)
|
||||||
|
|
||||||
print(torch.allclose(zc, zc_))
|
print(torch.allclose(zc, zc_))
|
Reference in New Issue
Block a user