diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/mat_mul.py
index a8230207f..75650a6cb 100644
--- a/python/examples/tutorials/mat_mul.py
+++ b/python/examples/tutorials/mat_mul.py
@@ -16,16 +16,14 @@ class _dot(torch.autograd.Function):
       // accumulator
       float c[TM, TN] = 0;
-      //pointers
+      // pointers
       TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
       TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
       for(int k=K; k>0; k-=TK) {
         TYPE a[TM, TK] = *pa;
         TYPE b[TK, TN] = *pb;
-        c += a @ b;
-        pa = pa + TK * 1;
         pb = pb + TK * ldb;
       }
@@ -40,32 +38,35 @@ class _dot(torch.autograd.Function):
         c = _dot._call(a,b)
         return c
+
+    kernel = dict()
 
     @staticmethod
     def _call(a, b):
+        # shapes
         M, K = a.shape
         K, N = b.shape
-
+        # leading dimension
         lda = K
         ldb = N
         ldc = N
-
         dtype = a.dtype
-
+        # create kernel if necessary
+        if dtype not in _dot.kernel:
+            defines = {
+                'TYPE' : dtype,
+                'TM' : [64, 128],
+                'TN' : [64, 128],
+                'TK' : [8, 16],
+            }
+            _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[2, 4], defines=defines)
+        kernel = _dot.kernel[dtype]
+        # allocate output
         c = triton.empty([M,N], dtype=dtype)
-
-        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
-
-        defines= {
-            'TYPE' : dtype,
-            'TM' : [32,64,128],
-            'TN' : [32,64,128],
-            'TK' : [8],
-        }
-
-        _dot.kernel = triton.kernel(_dot.src, defines=defines)
-        _dot.kernel(a, b, c, M, N, K, lda, ldb, ldc,
-                    grid=grid, num_warps=4, defines=defines)
+        # enqueue
+        grid = lambda opt: [triton.cdiv(M, opt.d('TM')),
+                            triton.cdiv(N, opt.d('TN'))]
+        kernel(a, b, c, M, N, K, lda, ldb, ldc, grid=grid)
         return c
@@ -81,4 +82,4 @@
 b = torch.rand((K, N)).cuda()
 zc = torch.matmul(a,b)
 zc_ = dot(a,b)
-print(torch.allclose(zc, zc_))
+print(torch.allclose(zc, zc_))
\ No newline at end of file
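The `# leading dimension` values that `_call` passes to the kernel (lda = K, ldb = N, ldc = N) encode row-major storage: the kernel's pointer tile `pa = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda` places element (i, k) of A at flat offset i * lda + k. A minimal sketch of that addressing scheme, using hypothetical shapes and a plain Python list in place of a tensor:

    # Row-major addressing as assumed by the kernel's pointer arithmetic.
    # M, K, i, k are hypothetical values chosen for illustration only.
    M, K = 4, 3
    lda = K                     # elements between consecutive rows of A
    A = list(range(M * K))      # flat row-major buffer standing in for tensor A
    i, k = 2, 1
    assert A[i * lda + k] == 7  # row 2 begins at offset 6; column 1 is offset 7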
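The restructured `_call` memoizes compiled kernels in the class-level `kernel` dict, keyed by element type, instead of rebuilding `triton.kernel(...)` on every invocation. A minimal sketch of the caching pattern in isolation, with `build` as a hypothetical stand-in for the expensive `triton.kernel` construction:

    import torch

    _kernels = {}  # dtype -> compiled kernel, mirroring `kernel = dict()` above

    def _get_kernel(dtype, build=lambda dt: '<kernel specialized for %s>' % dt):
        # compile at most once per dtype; later calls are dict lookups
        if dtype not in _kernels:
            _kernels[dtype] = build(dtype)
        return _kernels[dtype]

    print(_get_kernel(torch.float32))  # built on first use
    print(_get_kernel(torch.float32))  # cache hit, same object returned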
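Passing lists for `TM`, `TN`, `TK` and `num_warps` gives Triton a space of candidate configurations to tune over, which is why the launch grid is a lambda: it is re-evaluated with whichever tile sizes (`opt.d('TM')`, `opt.d('TN')`) the autotuner selects, so the grid always covers C with one program per output tile. A worked sketch of the grid computation, with hypothetical matrix sizes:

    def cdiv(x, y):
        # ceiling division, the rounding performed by triton.cdiv
        return (x + y - 1) // y

    M, N = 2048, 2048   # hypothetical problem size
    TM, TN = 128, 64    # one candidate tile shape from the defines lists
    print([cdiv(M, TM), cdiv(N, TN)])  # [16, 32]: one program per TM-by-TN tile of C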