[python] [op] added Triton NT einsum

This commit is contained in:
Philippe Tillet
2019-10-21 23:37:39 -04:00
parent 099918b3c0
commit 943bf41b5c
3 changed files with 44 additions and 81 deletions

View File

@@ -177,7 +177,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
std::unique_ptr<driver::module> bin;
try{
bin = make_bin(*ir, stream->context(), opt);
}catch(...){
}catch(const std::runtime_error& e){
return;
}
// kernel uses too much resources

View File

@@ -2,83 +2,6 @@ import numpy as np
import torch
import triton
class _dot(triton.function):
src = """
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
int sb, int sh, int sa, int sk, int sn) {
// program id
int pidx = get_program_id(0);
int pidy = get_program_id(1);
int pidz = get_program_id(2);
// ranges
int rxa[TM] = pidx * TM + 0 ... TM;
int ryb[TN] = pidy * TN + 0 ... TN;
int rza[TZ] = pidz * TZ + 0 ... TZ;
int rzb[TZ] = pidz * TZ + 0 ... TZ;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
// accumulator
float c[TM, TN, TZ] = 0;
// pointers to A
TYPE* pa[TM, TK, TZ] = A + rka[newaxis, :, newaxis] * 1 // reduction
+ rxa[:, newaxis, newaxis] * sk * sa * sh // outer
+ rza[newaxis, newaxis, :] * sk; // batch
// pointers to B
TYPE* pb[TK, TN, TZ] = B + rkb[:, newaxis, newaxis] * 1 // reduction
+ ryb[newaxis, :, newaxis] * sk // outer
+ rzb[newaxis, newaxis, :] * sk * sn; // batch
// reduction loop
for(int k = sk; k > 0; k -= TK){
TYPE a[TM, TK, TZ] = *pa;
TYPE b[TK, TN, TZ] = *pb;
c += a @ b;
pa += TK;
pb += TK;
}
// epilogue
int rxc[TM] = pidx * TM + 0 ... TM;
int ryc[TN] = pidy * TN + 0 ... TN;
int rzc[TZ] = pidz * TZ + 0 ... TZ;
TYPE* pc[TM, TN, TZ] = C + rxc[:, newaxis, newaxis] * sn * sa * sh // outer[0]
+ ryc[newaxis, :, newaxis] * 1 // outer[1]
+ rzc[newaxis, newaxis, :] * sn;
*pc = c;
}
"""
kernel = triton.kernel(src, ['C'])
@staticmethod
def _call(a, b, transpose_a, transpose_b):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
B, H, A, K = shape_a[0], shape_a[1], shape_a[2], shape_a[3]
H, A, N, K = shape_b[0], shape_b[1], shape_b[2], shape_b[3]
# allocate output
dtype = a.dtype
c = triton.empty([B, H, A, N], dtype = dtype)
# SPMD grid
grid = lambda opt: [triton.cdiv(B, opt.d('TM')),
triton.cdiv(N, opt.d('TN')),
triton.cdiv(H*A, opt.d('TZ'))]
# launch kernel
return _dot.kernel(a, b, c, B, H, A, K, N, grid,
AT = transpose_a, BT = transpose_b, TYPE = dtype,
TM = [32], TN = [32], TK = [8], TZ = [8])
@staticmethod
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
return _dot._call(a, b, transpose_a, transpose_b)
dot = _dot.apply
batch_dim = 16
ctx_dim = 32
head_dim = 8
@@ -92,6 +15,7 @@ x_shape = (bs, state_dim)
qw_shape = (state_dim, head_dim * key_dim)
kw_shape = (head_dim, 2, n_keys, key_dim // 2)
np.random.seed(0)
x = np.random.uniform(-1.0, 1.0, x_shape).astype(np.float32) # layer input
qw = np.random.uniform(-1.0, 1.0, qw_shape).astype(np.float32) # query weights
kw = np.random.uniform(-1.0, 1.0, kw_shape).astype(np.float32) # key weights
@@ -108,6 +32,7 @@ qk = np.einsum("bhak,hank->bhan", q, kw)
tq = torch.from_numpy(q).contiguous().cuda()
tkw = torch.from_numpy(kw).contiguous().cuda()
tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
diff = qk - tqk.cpu().numpy()
diff = np.abs(qk - tqk.cpu().numpy())
print(np.max(diff))
print(np.min(diff))

View File

@@ -7,8 +7,43 @@ class _einsum(triton.function):
int dim_M, int dim_N, int dim_K,
int std_A0, int std_B0, int std_C0,
int std_A1, int std_B1, int std_C1) {
// program id
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pid2 = get_program_id(2);
// range
int rma[TM] = pid0 * TM + 0 ... TM;
int rnb[TN] = pid1 * TN + 0 ... TN;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
int rba[TB] = pid2 * TB + 0 ... TB;
int rbb[TB] = pid2 * TB + 0 ... TB;
// accumulator
TYPE c[TM, TN, TB] = 0;
// pointers to a
TYPE *pa[TM, TK, TB] = A + rka[newaxis, :, newaxis] * 1
+ rma[:, newaxis, newaxis] * std_A1
+ rba[newaxis, newaxis, :] * std_A0;
// pointers to b
TYPE *pb[TK, TN, TB] = B + rkb[:, newaxis, newaxis] * 1
+ rnb[newaxis, :, newaxis] * std_B1
+ rbb[newaxis, newaxis, :] * std_B0;
// accumulation
for(int k = dim_K; k > 0; k -= TK) {
TYPE a[TM, TK, TB] = *pa;
TYPE b[TK, TN, TB] = *pb;
c += a @ b;
pa += TK;
pb += TK;
}
// write-back
int rmc[TM] = pid0 * TM + 0 ... TM;
int rnc[TN] = pid1 * TN + 0 ... TN;
int rbc[TB] = pid2 * TB + 0 ... TB;
TYPE *pc[TM, TN, TB] = C + rmc[:, newaxis, newaxis] * std_C1
+ rnc[newaxis, :, newaxis] * 1
+ rbc[newaxis, newaxis, :] * std_C0;
*pc = c;
}
"""
@@ -100,13 +135,16 @@ class _einsum(triton.function):
std0, std1, einsum_a, einsum_b, einsum_c):
dtype = a.dtype
c = triton.empty(shape_c, dtype)
grid = lambda opt: (1, 1, 1)
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
triton.cdiv(bmnk[2], opt.d('TN')),
triton.cdiv(bmnk[0], opt.d('TB'))]
#print(std0, std1)
return _einsum.kernel(a, b, c,
bmnk[1], bmnk[2], bmnk[3],
std0[0], std0[1], std0[2],
std1[0], std1[1], std1[2],
grid,
TYPE=['float'])
TYPE='float', TM=32, TN=32, TK=8, TB=8)
@staticmethod