[DOCS] Matmul and vecadd working examples
This commit is contained in:
committed by
Philippe Tillet
parent
ce4a4728f5
commit
32819dea51
84
python/examples/tutorials/mat_mul.py
Normal file
84
python/examples/tutorials/mat_mul.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _dot(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void dot(TYPE *A, TYPE *B, TYPE *C, int M, int N, int K,
|
||||
int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) {
|
||||
int pm = get_program_id(0);
|
||||
int pn = get_program_id(1);
|
||||
|
||||
// ranges
|
||||
int rm[TM] = pm * TM + 0 ... TM;
|
||||
int rn[TN] = pn * TN + 0 ... TN;
|
||||
int rk[TK] = 0 ... TK;
|
||||
|
||||
// accumulator
|
||||
float c[TM, TN] = 0;
|
||||
|
||||
//pointers
|
||||
TYPE* pa[TM, TK] = A + rk[newaxis, :] * 1 + rm[:, newaxis] * lda;
|
||||
TYPE* pb[TK, TN] = B + rk[:, newaxis] * ldb + rn[newaxis, :] * 1;
|
||||
|
||||
for(int k=K; k>0; k-=TK) {
|
||||
TYPE a[TM, TK] = *pa;
|
||||
TYPE b[TK, TN] = *pb;
|
||||
|
||||
c += a @ b;
|
||||
|
||||
pa = pa + TK * 1;
|
||||
pb = pb + TK * ldb;
|
||||
}
|
||||
|
||||
TYPE* pc[TM,TN] = C + rn[newaxis, :] + rm[:,newaxis] * ldc;
|
||||
*pc = c;
|
||||
}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
c = _dot._call(a,b)
|
||||
return c
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b):
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
|
||||
lda = K
|
||||
ldb = N
|
||||
ldc = N
|
||||
|
||||
dtype = a.dtype
|
||||
|
||||
c = triton.empty([M,N], dtype=dtype)
|
||||
|
||||
grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
|
||||
|
||||
defines= {
|
||||
'TYPE' : dtype,
|
||||
'TM' : [32,64,128],
|
||||
'TN' : [32,64,128],
|
||||
'TK' : [8],
|
||||
}
|
||||
|
||||
_dot.kernel = triton.kernel(_dot.src, defines=defines)
|
||||
_dot.kernel(a, b, c, M, N, K, lda, ldb, ldc,
|
||||
grid=grid, num_warps=4, defines=defines)
|
||||
return c
|
||||
|
||||
|
||||
dot = _dot.apply
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
M, N, K = 128, 512, 256
|
||||
a = torch.rand((M, K)).cuda()
|
||||
b = torch.rand((K, N)).cuda()
|
||||
|
||||
|
||||
zc = torch.matmul(a,b)
|
||||
zc_ = dot(a,b)
|
||||
|
||||
print(torch.allclose(zc, zc_))
|
43
python/examples/tutorials/vec_add.py
Normal file
43
python/examples/tutorials/vec_add.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _add(torch.autograd.Function):
|
||||
src = """
|
||||
__global__ void add(float* z, float* x, float* y, int N) {
|
||||
|
||||
int pid = get_program_id(0);
|
||||
|
||||
int offset[TILE] = pid * TILE + 0 ... TILE;
|
||||
float* pz[TILE] = z + offset;
|
||||
float* px[TILE] = x + offset;
|
||||
float* py[TILE] = y + offset;
|
||||
|
||||
bool check[TILE] = offset < N;
|
||||
|
||||
*?(check)pz = *?(check)px + *?(check)py;
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = triton.kernel(src, defines={'TILE': 1024}, num_warps=[4])
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
z = torch.empty_like(x).cuda()
|
||||
|
||||
N = x.numel()
|
||||
grid = lambda opt: (triton.cdiv(N, opt.d('TILE')),)
|
||||
|
||||
_add.kernel(z,x,y, N, grid=grid)
|
||||
|
||||
return z
|
||||
|
||||
add = _add.apply
|
||||
|
||||
# test
|
||||
torch.manual_seed(0)
|
||||
x = torch.rand(98432).cuda()
|
||||
y = torch.rand(98432).cuda()
|
||||
za = x + y
|
||||
zb = add(x, y)
|
||||
|
||||
print(torch.allclose(za,zb))
|
Reference in New Issue
Block a user