[python] [op] added Triton NT einsum
@@ -177,7 +177,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
   std::unique_ptr<driver::module> bin;
   try{
     bin = make_bin(*ir, stream->context(), opt);
-  }catch(...){
+  }catch(const std::runtime_error& e){
     return;
   }
   // kernel uses too much resources
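The hunk above narrows the autotuner's exception handling: only std::runtime_error (raised when a candidate kernel fails to build) causes the configuration to be skipped, while unrelated exceptions now propagate instead of being swallowed by catch(...). A hedged Python sketch of that skip-on-compile-failure pattern; the function names and signatures here are illustrative, not Triton's API:

def autotune(configs, compile_fn, bench_fn):
    # Try every configuration; skip the ones whose kernels fail to compile
    # (e.g. the tile sizes use too many resources), mirroring the
    # catch(const std::runtime_error&) + return above.
    best = None
    for cfg in configs:
        try:
            binary = compile_fn(cfg)
        except RuntimeError:
            continue  # kernel uses too much resources: skip this config
        t = bench_fn(binary)
        if best is None or t < best[0]:
            best = (t, cfg)
    return best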
@@ -2,83 +2,6 @@ import numpy as np
 import torch
 import triton
 
-class _dot(triton.function):
-
-  src = """
-__global__ void dot(TYPE * A, TYPE * B, TYPE * C,
-                    int sb, int sh, int sa, int sk, int sn) {
-  // program id
-  int pidx = get_program_id(0);
-  int pidy = get_program_id(1);
-  int pidz = get_program_id(2);
-  // ranges
-  int rxa[TM] = pidx * TM + 0 ... TM;
-  int ryb[TN] = pidy * TN + 0 ... TN;
-  int rza[TZ] = pidz * TZ + 0 ... TZ;
-  int rzb[TZ] = pidz * TZ + 0 ... TZ;
-  int rka[TK] = 0 ... TK;
-  int rkb[TK] = 0 ... TK;
-  // accumulator
-  float c[TM, TN, TZ] = 0;
-  // pointers to A
-  TYPE* pa[TM, TK, TZ] = A + rka[newaxis, :, newaxis] * 1            // reduction
-                           + rxa[:, newaxis, newaxis] * sk * sa * sh // outer
-                           + rza[newaxis, newaxis, :] * sk;          // batch
-  // pointers to B
-  TYPE* pb[TK, TN, TZ] = B + rkb[:, newaxis, newaxis] * 1            // reduction
-                           + ryb[newaxis, :, newaxis] * sk           // outer
-                           + rzb[newaxis, newaxis, :] * sk * sn;     // batch
-  // reduction loop
-  for(int k = sk; k > 0; k -= TK){
-    TYPE a[TM, TK, TZ] = *pa;
-    TYPE b[TK, TN, TZ] = *pb;
-    c += a @ b;
-    pa += TK;
-    pb += TK;
-  }
-  // epilogue
-  int rxc[TM] = pidx * TM + 0 ... TM;
-  int ryc[TN] = pidy * TN + 0 ... TN;
-  int rzc[TZ] = pidz * TZ + 0 ... TZ;
-  TYPE* pc[TM, TN, TZ] = C + rxc[:, newaxis, newaxis] * sn * sa * sh // outer[0]
-                           + ryc[newaxis, :, newaxis] * 1            // outer[1]
-                           + rzc[newaxis, newaxis, :] * sn;
-  *pc = c;
-}
-"""
-
-  kernel = triton.kernel(src, ['C'])
-
-  @staticmethod
-  def _call(a, b, transpose_a, transpose_b):
-    # extract shapes
-    shape_a = triton.shape(a)
-    shape_b = triton.shape(b)
-    B, H, A, K = shape_a[0], shape_a[1], shape_a[2], shape_a[3]
-    H, A, N, K = shape_b[0], shape_b[1], shape_b[2], shape_b[3]
-    # allocate output
-    dtype = a.dtype
-    c = triton.empty([B, H, A, N], dtype = dtype)
-    # SPMD grid
-    grid = lambda opt: [triton.cdiv(B, opt.d('TM')),
-                        triton.cdiv(N, opt.d('TN')),
-                        triton.cdiv(H*A, opt.d('TZ'))]
-    # launch kernel
-    return _dot.kernel(a, b, c, B, H, A, K, N, grid,
-                       AT = transpose_a, BT = transpose_b, TYPE = dtype,
-                       TM = [32], TN = [32], TK = [8], TZ = [8])
-
-  @staticmethod
-  def forward(ctx, a, b, transpose_a = False, transpose_b = False):
-    ctx.save_for_backward(a, b)
-    ctx.t_a = transpose_a
-    ctx.t_b = transpose_b
-    return _dot._call(a, b, transpose_a, transpose_b)
-
-
-dot = _dot.apply
-
-
 batch_dim = 16
 ctx_dim = 32
 head_dim = 8
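The deleted _dot class above hand-rolled a batched matmul kernel; the commit replaces it with the generic triton.ops.einsum exercised further down. As a hedged reading (not part of the commit) of what one TK step of the deleted kernel accumulated, in NumPy:

import numpy as np

# c[m, n, z] = sum_k a[m, k, z] * b[k, n, z], with the batch in the trailing
# TZ axis -- the `c += a @ b` line applied to tiles of shape (TM, TK, TZ)
# and (TK, TN, TZ). Tile sizes below match the deleted launch parameters.
TM, TN, TK, TZ = 32, 32, 8, 8
a = np.random.rand(TM, TK, TZ).astype(np.float32)
b = np.random.rand(TK, TN, TZ).astype(np.float32)
c = np.einsum("mkz,knz->mnz", a, b)
print(c.shape)  # (32, 32, 8)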
@@ -92,6 +15,7 @@ x_shape = (bs, state_dim)
 qw_shape = (state_dim, head_dim * key_dim)
 kw_shape = (head_dim, 2, n_keys, key_dim // 2)
 
+np.random.seed(0)
 x  = np.random.uniform(-1.0, 1.0, x_shape).astype(np.float32)  # layer input
 qw = np.random.uniform(-1.0, 1.0, qw_shape).astype(np.float32) # query weights
 kw = np.random.uniform(-1.0, 1.0, kw_shape).astype(np.float32) # key weights
@@ -108,6 +32,7 @@ qk = np.einsum("bhak,hank->bhan", q, kw)
 tq = torch.from_numpy(q).contiguous().cuda()
 tkw = torch.from_numpy(kw).contiguous().cuda()
 tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
-diff = qk - tqk.cpu().numpy()
+diff = np.abs(qk - tqk.cpu().numpy())
 print(np.max(diff))
+print(np.min(diff))
 
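For reference, a hedged sketch of what the "bhak,hank->bhan" contraction computes, plus an assert-style check with an explicit tolerance rather than eyeballed extrema. The array names are the ones built above; the loop form and the tolerance values are assumptions for illustration:

# qk[b, h, a, n] = sum_k q[b, h, a, k] * kw[h, a, n, k]
ref = np.einsum("bhak,hank->bhan", q, kw)
out = tqk.cpu().numpy()
# slow loop form of the same contraction, for clarity only
naive = np.empty_like(ref)
for b in range(q.shape[0]):
    for h in range(q.shape[1]):
        for a in range(q.shape[2]):
            naive[b, h, a] = kw[h, a] @ q[b, h, a]   # (n, k) @ (k,) -> (n,)
assert np.allclose(ref, naive, atol=1e-4)
assert np.allclose(ref, out, atol=1e-3, rtol=1e-3), np.max(np.abs(ref - out))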
@@ -7,8 +7,43 @@ class _einsum(triton.function):
                     int dim_M, int dim_N, int dim_K,
                     int std_A0, int std_B0, int std_C0,
                     int std_A1, int std_B1, int std_C1) {
+    // program id
     int pid0 = get_program_id(0);
     int pid1 = get_program_id(1);
+    int pid2 = get_program_id(2);
+    // range
+    int rma[TM] = pid0 * TM + 0 ... TM;
+    int rnb[TN] = pid1 * TN + 0 ... TN;
+    int rka[TK] = 0 ... TK;
+    int rkb[TK] = 0 ... TK;
+    int rba[TB] = pid2 * TB + 0 ... TB;
+    int rbb[TB] = pid2 * TB + 0 ... TB;
+    // accumulator
+    TYPE c[TM, TN, TB] = 0;
+    // pointers to a
+    TYPE *pa[TM, TK, TB] = A + rka[newaxis, :, newaxis] * 1
+                             + rma[:, newaxis, newaxis] * std_A1
+                             + rba[newaxis, newaxis, :] * std_A0;
+    // pointers to b
+    TYPE *pb[TK, TN, TB] = B + rkb[:, newaxis, newaxis] * 1
+                             + rnb[newaxis, :, newaxis] * std_B1
+                             + rbb[newaxis, newaxis, :] * std_B0;
+    // accumulation
+    for(int k = dim_K; k > 0; k -= TK) {
+      TYPE a[TM, TK, TB] = *pa;
+      TYPE b[TK, TN, TB] = *pb;
+      c += a @ b;
+      pa += TK;
+      pb += TK;
+    }
+    // write-back
+    int rmc[TM] = pid0 * TM + 0 ... TM;
+    int rnc[TN] = pid1 * TN + 0 ... TN;
+    int rbc[TB] = pid2 * TB + 0 ... TB;
+    TYPE *pc[TM, TN, TB] = C + rmc[:, newaxis, newaxis] * std_C1
+                             + rnc[newaxis, :, newaxis] * 1
+                             + rbc[newaxis, newaxis, :] * std_C0;
+    *pc = c;
 }
 """
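The new kernel body addresses each (TM, TK, TB) tile of A through flat offsets rka * 1 + rma * std_A1 + rba * std_A0, built by broadcasting. A hedged NumPy sketch of that index computation; the strides, tile sizes, and program ids below are illustrative values, not taken from the commit:

import numpy as np

TM, TK, TB = 4, 2, 2
std_A0, std_A1 = 64, 8          # batch and outer strides (illustrative)
pid0 = pid2 = 0                 # one program instance
rma = pid0 * TM + np.arange(TM)
rka = np.arange(TK)
rba = pid2 * TB + np.arange(TB)
# same broadcasting as the kernel's newaxis-indexed pointer arithmetic
off_a = (rka[np.newaxis, :, np.newaxis] * 1
         + rma[:, np.newaxis, np.newaxis] * std_A1
         + rba[np.newaxis, np.newaxis, :] * std_A0)
print(off_a.shape)              # (4, 2, 2): one flat offset per tile element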
@@ -100,13 +135,16 @@ class _einsum(triton.function):
                std0, std1, einsum_a, einsum_b, einsum_c):
     dtype = a.dtype
     c = triton.empty(shape_c, dtype)
-    grid = lambda opt: (1, 1, 1)
+    grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
+                        triton.cdiv(bmnk[2], opt.d('TN')),
+                        triton.cdiv(bmnk[0], opt.d('TB'))]
+    #print(std0, std1)
     return _einsum.kernel(a, b, c,
                           bmnk[1], bmnk[2], bmnk[3],
                           std0[0], std0[1], std0[2],
                           std1[0], std1[1], std1[2],
                           grid,
-                          TYPE=['float'])
+                          TYPE='float', TM=32, TN=32, TK=8, TB=8)
 
 
   @staticmethod
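The final hunk replaces the placeholder single-program launch with one program per output tile. A hedged sketch of the grid arithmetic, where cdiv stands in for triton.cdiv (ceiling division) and the bmnk values are illustrative:

def cdiv(x, y):
    # ceiling division, as triton.cdiv computes it
    return (x + y - 1) // y

bmnk = (256, 16, 128, 64)          # (batch, M, N, K), illustrative values
TM, TN, TB = 32, 32, 8             # tile sizes fixed by the launch above
grid = (cdiv(bmnk[1], TM),         # programs along M
        cdiv(bmnk[2], TN),         # programs along N
        cdiv(bmnk[0], TB))         # programs along the batch dimension
print(grid)                        # (1, 4, 32)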