[DOCS] Matmul and vecadd working examples

This commit is contained in:
jack-willturner
2020-05-04 16:25:17 +01:00
committed by Philippe Tillet
parent ce4a4728f5
commit 32819dea51
6 changed files with 159 additions and 27 deletions

View File

@@ -58,7 +58,7 @@ As you will see, a wrapper for the above Triton function can be created in just
"""
# create callable kernel for the source-code
# options: 4 warps and a -DTILE=1024
kernel = triton.kernel(src, defines = {'TILE': 1024}; num_warps = [4])
kernel = triton.kernel(src, defines = {'TILE': 1024}, num_warps = [4])
# Forward pass
@staticmethod
@@ -88,6 +88,7 @@ As you will see, a wrapper for the above Triton function can be created in just
zb = add(x, y)
diff = (za - zb).abs().max()
print(diff)
print(torch.allclose(za,zb))
Executing the above code will:
@@ -97,3 +98,5 @@ Executing the above code will:
- Call the resulting custom op
In other words, the first program run will generate and cache a bunch of files in $HOME/.triton/cache, but subsequent runs should be just as fast as using a handwritten custom operation.
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/vec_add.py>`_.

View File

@@ -35,7 +35,7 @@ Matrix multiplications of the form `C = A x B` can be implemented in Triton-C fa
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
// matrix-multiply accumulate
c += dot(a, b); //(11)
c += a @ b; //(11)
// increment pointers
pa = pa + TK * 1; //(12)
pb = pb + TK * ldb; //(13)
@@ -88,7 +88,7 @@ The purpose of pre-fetching is to overlap the update of the accumulator `c` with
TYPE a[TM, TK] = *pa; //(9)
TYPE b[TK, TN] = *pb; //(10)
for(int k = K; k > 0; k-= TK){
c += dot(a, b);
c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
// don't prefetch last iteration
@@ -144,7 +144,7 @@ It is common for optimized matrix-multiplication implementations (e.g., BLAS) to
TYPE b[SHAPE_B] = (*pb);
// reduction loop
for(int k = K; k > 0; k-= TK){
c += dot(USE_A, USE_B);
c += USE_A @ USE_B;
pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK;
a = *pa;
@@ -182,3 +182,5 @@ Auto-tuning can also be handled using pre-processor macros:
// Auto-tuning TM and TN in {32, 64, 128}; TK in {8, 16}
-DTM=[32, 64, 128] -DTN=[32, 64, 128] -DTK=[8, 16]
A runnable version of this kernel is available `here <https://github.com/ptillet/triton/tree/master/python/examples/tutorials/mat_mul.py>`_.

View File

@@ -95,7 +95,7 @@ Bounds Checking
==========================
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle thie situation, as shown below:
You might have noticed that the above code will fail when `M` and `N` are not multiples of `TM` and `TN` respectively. Fortunately, the above kernel can be slightly modified to handle this situation, as shown below:
.. code-block:: C

View File

@@ -110,9 +110,9 @@ However, in practice only A, B are provided by the user, and all the other :code
'TYPE' : dtype,
'AT' : transpose_a,
'BT' : transpose_b,
'TM' : [32, 64, 128]
'TN' : [32, 64, 128]
'TK' : [8]
'TM' : [32, 64, 128],
'TN' : [32, 64, 128],
'TK' : [8],
# handle A transposition
'USE_A' : '^a' if transpose_a else 'a',
'STRIDE_AK' : 'lda' if transpose_a else '1',

View File

@@ -0,0 +1,84 @@
import torch
import triton
class _dot(torch.autograd.Function):
    """Matrix multiplication C = A @ B implemented with a Triton-C kernel.

    Wraps the tiled matmul kernel below in a ``torch.autograd.Function`` so it
    can be called like a regular op via ``_dot.apply`` (only the forward pass
    is implemented here).
    """

    # Triton-C source: each program instance computes one TM x TN tile of C,
    # accumulating over the K dimension in steps of TK.
    src = """
__global__ void dot(TYPE *A, TYPE *B, TYPE *C, int M, int N, int K,
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
int pm = get_program_id(0);
int pn = get_program_id(1);
// ranges
int rm[TM] = pm * TM + 0 ... TM;
int rn[TN] = pn * TN + 0 ... TN;
int rk[TK] = 0 ... TK;
// accumulator
float c[TM, TN] = 0;
//pointers
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
for(int k=K; k>0; k-=TK) {
TYPE a[TM, TK] = *pa;
TYPE b[TK, TN] = *pb;
c += a @ b;
pa = pa + TK * 1;
pb = pb + TK * ldb;
}
TYPE* pc[TM,TN] = C + rn[newaxis, :] + rm[:,newaxis] * ldc;
*pc = c;
}
"""

    # Cache of compiled kernels keyed by dtype.  The original code rebuilt the
    # kernel with triton.kernel(...) on EVERY call, which defeats compilation
    # caching; compiling once per dtype keeps subsequent calls fast.
    kernel = dict()

    @staticmethod
    def forward(ctx, a, b):
        """Return the product a @ b computed by the Triton kernel."""
        c = _dot._call(a, b)
        return c

    @staticmethod
    def _call(a, b):
        """Launch the matmul kernel on row-major inputs a (M,K) and b (K,N)."""
        M, K = a.shape
        K, N = b.shape
        # leading dimensions for row-major layout
        lda = K
        ldb = N
        ldc = N
        dtype = a.dtype
        c = triton.empty([M, N], dtype=dtype)
        # one program instance per TM x TN output tile
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
        defines = {
            'TYPE': dtype,
            'TM': [32, 64, 128],   # auto-tuned tile sizes
            'TN': [32, 64, 128],
            'TK': [8],
        }
        # compile at most once per dtype, then reuse the cached kernel
        if dtype not in _dot.kernel:
            _dot.kernel[dtype] = triton.kernel(_dot.src, defines=defines)
        _dot.kernel[dtype](a, b, c, M, N, K, lda, ldb, ldc,
                           grid=grid, num_warps=4, defines=defines)
        return c
# Expose the autograd wrapper as a plain callable.
dot = _dot.apply

# Quick numerical check of the Triton matmul against torch.matmul.
torch.manual_seed(0)
M, N, K = 128, 512, 256
a = torch.rand((M, K)).cuda()
b = torch.rand((K, N)).cuda()
zc = torch.matmul(a, b)   # reference result
zc_ = dot(a, b)           # Triton result
print(torch.allclose(zc, zc_))

View File

@@ -0,0 +1,43 @@
import torch
import triton
class _add(torch.autograd.Function):
    """Element-wise vector addition z = x + y backed by a Triton-C kernel.

    Only the forward pass is implemented; use via ``_add.apply``.
    """

    # Triton-C source: each program instance handles one TILE-wide chunk,
    # with a bounds check so N need not be a multiple of TILE.
    src = """
__global__ void add(float* z, float* x, float* y, int N) {
int pid = get_program_id(0);
int offset[TILE] = pid * TILE + 0 ... TILE;
float* pz[TILE] = z + offset;
float* px[TILE] = x + offset;
float* py[TILE] = y + offset;
bool check[TILE] = offset < N;
*?(check)pz = *?(check)px + *?(check)py;
}
"""

    # Compiled once at class-definition time and reused for every call.
    kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])

    @staticmethod
    def forward(ctx, x, y):
        """Return x + y, computed on the GPU by the cached kernel."""
        out = torch.empty_like(x).cuda()
        n = x.numel()

        # one program instance per TILE elements
        def grid(opt):
            return (triton.cdiv(n, opt.d('TILE')),)

        _add.kernel(out, x, y, n, grid=grid)
        return out
# Expose the autograd wrapper as a plain callable.
add = _add.apply

# test: compare the Triton kernel against eager torch addition
torch.manual_seed(0)
x = torch.rand(98432).cuda()
y = torch.rand(98432).cuda()
za = x + y      # reference result
zb = add(x, y)  # Triton result
print(torch.allclose(za, zb))