[python] [op] added Triton NT einsum
This commit is contained in:
@@ -177,7 +177,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
|
||||
std::unique_ptr<driver::module> bin;
|
||||
try{
|
||||
bin = make_bin(*ir, stream->context(), opt);
|
||||
}catch(...){
|
||||
}catch(const std::runtime_error& e){
|
||||
return;
|
||||
}
|
||||
// kernel uses too much resources
|
||||
|
@@ -2,83 +2,6 @@ import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
|
||||
class _dot(triton.function):
|
||||
|
||||
src = """
|
||||
__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
int sb, int sh, int sa, int sk, int sn) {
|
||||
// program id
|
||||
int pidx = get_program_id(0);
|
||||
int pidy = get_program_id(1);
|
||||
int pidz = get_program_id(2);
|
||||
// ranges
|
||||
int rxa[TM] = pidx * TM + 0 ... TM;
|
||||
int ryb[TN] = pidy * TN + 0 ... TN;
|
||||
int rza[TZ] = pidz * TZ + 0 ... TZ;
|
||||
int rzb[TZ] = pidz * TZ + 0 ... TZ;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
// accumulator
|
||||
float c[TM, TN, TZ] = 0;
|
||||
// pointers to A
|
||||
TYPE* pa[TM, TK, TZ] = A + rka[newaxis, :, newaxis] * 1 // reduction
|
||||
+ rxa[:, newaxis, newaxis] * sk * sa * sh // outer
|
||||
+ rza[newaxis, newaxis, :] * sk; // batch
|
||||
// pointers to B
|
||||
TYPE* pb[TK, TN, TZ] = B + rkb[:, newaxis, newaxis] * 1 // reduction
|
||||
+ ryb[newaxis, :, newaxis] * sk // outer
|
||||
+ rzb[newaxis, newaxis, :] * sk * sn; // batch
|
||||
// reduction loop
|
||||
for(int k = sk; k > 0; k -= TK){
|
||||
TYPE a[TM, TK, TZ] = *pa;
|
||||
TYPE b[TK, TN, TZ] = *pb;
|
||||
c += a @ b;
|
||||
pa += TK;
|
||||
pb += TK;
|
||||
}
|
||||
// epilogue
|
||||
int rxc[TM] = pidx * TM + 0 ... TM;
|
||||
int ryc[TN] = pidy * TN + 0 ... TN;
|
||||
int rzc[TZ] = pidz * TZ + 0 ... TZ;
|
||||
TYPE* pc[TM, TN, TZ] = C + rxc[:, newaxis, newaxis] * sn * sa * sh // outer[0]
|
||||
+ ryc[newaxis, :, newaxis] * 1 // outer[1]
|
||||
+ rzc[newaxis, newaxis, :] * sn;
|
||||
*pc = c;
|
||||
}
|
||||
"""
|
||||
|
||||
kernel = triton.kernel(src, ['C'])
|
||||
|
||||
@staticmethod
|
||||
def _call(a, b, transpose_a, transpose_b):
|
||||
# extract shapes
|
||||
shape_a = triton.shape(a)
|
||||
shape_b = triton.shape(b)
|
||||
B, H, A, K = shape_a[0], shape_a[1], shape_a[2], shape_a[3]
|
||||
H, A, N, K = shape_b[0], shape_b[1], shape_b[2], shape_b[3]
|
||||
# allocate output
|
||||
dtype = a.dtype
|
||||
c = triton.empty([B, H, A, N], dtype = dtype)
|
||||
# SPMD grid
|
||||
grid = lambda opt: [triton.cdiv(B, opt.d('TM')),
|
||||
triton.cdiv(N, opt.d('TN')),
|
||||
triton.cdiv(H*A, opt.d('TZ'))]
|
||||
# launch kernel
|
||||
return _dot.kernel(a, b, c, B, H, A, K, N, grid,
|
||||
AT = transpose_a, BT = transpose_b, TYPE = dtype,
|
||||
TM = [32], TN = [32], TK = [8], TZ = [8])
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, transpose_a = False, transpose_b = False):
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.t_a = transpose_a
|
||||
ctx.t_b = transpose_b
|
||||
return _dot._call(a, b, transpose_a, transpose_b)
|
||||
|
||||
|
||||
dot = _dot.apply
|
||||
|
||||
|
||||
batch_dim = 16
|
||||
ctx_dim = 32
|
||||
head_dim = 8
|
||||
@@ -92,6 +15,7 @@ x_shape = (bs, state_dim)
|
||||
qw_shape = (state_dim, head_dim * key_dim)
|
||||
kw_shape = (head_dim, 2, n_keys, key_dim // 2)
|
||||
|
||||
np.random.seed(0)
|
||||
x = np.random.uniform(-1.0, 1.0, x_shape).astype(np.float32) # layer input
|
||||
qw = np.random.uniform(-1.0, 1.0, qw_shape).astype(np.float32) # query weights
|
||||
kw = np.random.uniform(-1.0, 1.0, kw_shape).astype(np.float32) # key weights
|
||||
@@ -108,6 +32,7 @@ qk = np.einsum("bhak,hank->bhan", q, kw)
|
||||
tq = torch.from_numpy(q).contiguous().cuda()
|
||||
tkw = torch.from_numpy(kw).contiguous().cuda()
|
||||
tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
|
||||
diff = qk - tqk.cpu().numpy()
|
||||
diff = np.abs(qk - tqk.cpu().numpy())
|
||||
print(np.max(diff))
|
||||
print(np.min(diff))
|
||||
|
||||
|
@@ -7,8 +7,43 @@ class _einsum(triton.function):
|
||||
int dim_M, int dim_N, int dim_K,
|
||||
int std_A0, int std_B0, int std_C0,
|
||||
int std_A1, int std_B1, int std_C1) {
|
||||
// program id
|
||||
int pid0 = get_program_id(0);
|
||||
int pid1 = get_program_id(1);
|
||||
int pid2 = get_program_id(2);
|
||||
// range
|
||||
int rma[TM] = pid0 * TM + 0 ... TM;
|
||||
int rnb[TN] = pid1 * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
int rba[TB] = pid2 * TB + 0 ... TB;
|
||||
int rbb[TB] = pid2 * TB + 0 ... TB;
|
||||
// accumulator
|
||||
TYPE c[TM, TN, TB] = 0;
|
||||
// pointers to a
|
||||
TYPE *pa[TM, TK, TB] = A + rka[newaxis, :, newaxis] * 1
|
||||
+ rma[:, newaxis, newaxis] * std_A1
|
||||
+ rba[newaxis, newaxis, :] * std_A0;
|
||||
// pointers to b
|
||||
TYPE *pb[TK, TN, TB] = B + rkb[:, newaxis, newaxis] * 1
|
||||
+ rnb[newaxis, :, newaxis] * std_B1
|
||||
+ rbb[newaxis, newaxis, :] * std_B0;
|
||||
// accumulation
|
||||
for(int k = dim_K; k > 0; k -= TK) {
|
||||
TYPE a[TM, TK, TB] = *pa;
|
||||
TYPE b[TK, TN, TB] = *pb;
|
||||
c += a @ b;
|
||||
pa += TK;
|
||||
pb += TK;
|
||||
}
|
||||
// write-back
|
||||
int rmc[TM] = pid0 * TM + 0 ... TM;
|
||||
int rnc[TN] = pid1 * TN + 0 ... TN;
|
||||
int rbc[TB] = pid2 * TB + 0 ... TB;
|
||||
TYPE *pc[TM, TN, TB] = C + rmc[:, newaxis, newaxis] * std_C1
|
||||
+ rnc[newaxis, :, newaxis] * 1
|
||||
+ rbc[newaxis, newaxis, :] * std_C0;
|
||||
*pc = c;
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -100,13 +135,16 @@ class _einsum(triton.function):
|
||||
std0, std1, einsum_a, einsum_b, einsum_c):
|
||||
dtype = a.dtype
|
||||
c = triton.empty(shape_c, dtype)
|
||||
grid = lambda opt: (1, 1, 1)
|
||||
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
|
||||
triton.cdiv(bmnk[2], opt.d('TN')),
|
||||
triton.cdiv(bmnk[0], opt.d('TB'))]
|
||||
#print(std0, std1)
|
||||
return _einsum.kernel(a, b, c,
|
||||
bmnk[1], bmnk[2], bmnk[3],
|
||||
std0[0], std0[1], std0[2],
|
||||
std1[0], std1[1], std1[2],
|
||||
grid,
|
||||
TYPE=['float'])
|
||||
TYPE='float', TM=32, TN=32, TK=8, TB=8)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user