[PYTHON][EINSUM] Now handling reduction sizes that are not a multiple of TK
committed by Philippe Tillet
parent fa4ec7ea65
commit 3816f2f259
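In rough terms, the kernel now covers a reduction of size matmul_k that is not divisible by TK by issuing one masked partial tile of rem_k = matmul_k % TK elements first and then stepping through full TK-sized tiles. A minimal standalone sketch of that partitioning (plain Python with made-up sizes, not the Triton kernel itself):

    def reduction_tiles(K, TK):
        # one (possibly partial) remainder tile first, then full TK-sized tiles
        rem = K % TK
        rem = rem if rem > 0 else TK  # if K divides evenly, the first tile is already full
        tiles = [list(range(0, rem))]
        for start in range(rem, K, TK):
            tiles.append(list(range(start, start + TK)))
        return tiles

    print(reduction_tiles(10, 4))  # [[0, 1], [2, 3, 4, 5], [6, 7, 8, 9]]

The diff below follows this idea in three places: the generated kernel masks the first tile (prefetch_k, checkk) and advances the pointers by a dedicated remainder delta (rem_delta_a / rem_delta_b), the lookup-table construction emits that remainder delta, and the masked-load lowering in the LLVM generator is updated accordingly.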
@@ -369,7 +369,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
      ((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
      Value *result_false = false_values->get_value(idx);
      if(result_then->getType()->isVectorTy())
        result_false = builder_->CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType()));
        result_false = builder_->CreateVectorSplat(vector_size, result_false);
      ((PHINode*)current_result)->addIncoming(result_false, current_bb);
    }
    else
@@ -255,6 +255,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
  if(allocation.allocated_size() > context->device()->max_shared_memory())
    return std::unique_ptr<driver::module>();
  barriers.run(module);
  // ir::print(module, std::cout);
  isel.visit(module, *llvm);
  // return binary
  std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
@@ -5,6 +5,8 @@ import numpy as np
#import utils
from time import time

torch.manual_seed(0)

#torch.backends.cudnn.benchmark = True

configs = []
@@ -31,15 +33,15 @@ MNK = [

          # (127008, 768, 576)
       ]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b)
    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a.t(), b)
    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
for M, N, K in MNK:
    matmul = lambda a, b: torch.matmul(a, b.t())
    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
#for M, N, K in MNK:
#    matmul = lambda a, b: torch.matmul(a, b)
#    configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
#for M, N, K in MNK:
#    matmul = lambda a, b: torch.matmul(a.t(), b)
#    configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
#for M, N, K in MNK:
#    matmul = lambda a, b: torch.matmul(a, b.t())
#    configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]

# Relative attention
NTHSE = [
@@ -70,16 +72,16 @@ NTHSE = [
          # (128, 1024, 8, 256, 256),
          #(128, 1024, 8, 256, 512)
        ]
for N, T, H, S, E in NTHSE:
    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
for N, T, H, S, E in NTHSE:
    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
for N, T, H, S, E in NTHSE:
    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
#for N, T, H, S, E in NTHSE:
#    configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]

# 1D Dense convolution
NCHKR = [
          (1, 1152, 12602, 512, 3)
          #(1, 1152, 12602, 512, 3)
        ]
for N, C, H, K, R in NCHKR:
    torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
@@ -93,17 +95,17 @@ for N, C, H, K, R in NCHKR:
# 2D Dense convolution
NCHWKRS = [
            #(8, 64, 128, 128, 768, 3, 3),
            (8, 128, 64, 64, 256, 3, 3),
            (8, 256, 32, 32, 512, 3, 3),
            (128, 3, 32, 32, 64, 3, 3),
            #(8, 256, 32, 32, 512, 3, 3),
            #(8, 512, 32, 32, 1024, 3, 3)
          ]
for N, C, H, W, K, R, S in NCHWKRS:
    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
    torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
    configs += [([N, C, H, W],
                 [C, R, S, K],
                 [K, C, R, S],
                 [N, K, H - R + 1, W - R + 1],
                 torch_fn,
                 'nc(h+r)(w+s),crsk->nkhw',
                 'nc(h+r)(w+s),kcrs->nkhw',
                 dict())]

# 3D Dense Convolution
@@ -173,6 +175,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
    a = torch.rand(*a_shape).type(dtype).cuda()
    b = torch.rand(*b_shape).type(dtype).cuda()
    # triton output
    print(a.size(), b.size())
    tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
    # reference output
    if torch_fn:
@@ -182,7 +185,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
    # performance relative to equivalent matrix multiplication
    ctx = triton.ctx_registry[tc]
    B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
    cmp_eqbmm = True
    cmp_eqbmm = False
    if cmp_eqbmm:
        a = torch.rand(B, M, K).type(dtype).cuda()
        b = torch.rand(B, K, N).type(dtype).cuda()
@@ -210,7 +210,7 @@ class kernel:
        macros.append((k, values))
    opt = libtriton.options_space()
    opt.defines = macros
    opt.num_warps = [2, 4, 8]
    opt.num_warps = num_warps
    # create unique id for this op
    op_id = libtriton.make_op_id()
    self.fw_id[key] = op_id
@@ -110,13 +110,16 @@ __global__ void {name}(
src += "\n "
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
    src += f", int stride_a_inner __multipleof({multipleof_a})"
    src += f", int rem_delta_a __multipleof({multipleof_a})"
elif lut_mode_a == _einsum.LUT_MODE.DRAM:
    src += ", int* AD __noalias __readonly __aligned(16)"
src += "\n "
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
    src += f", int stride_b_inner __multipleof({multipleof_b})"
    src += f", int rem_delta_b __multipleof({multipleof_b})"
elif lut_mode_b == _einsum.LUT_MODE.DRAM:
    src += ", int* BD"
src += "\n"
for ptr in subscripted:
    src += f", int* {ptr}"
src += """) {
@@ -142,6 +145,7 @@ __global__ void {name}(
    int off_k = pid_z * div_z;
    matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
    int rem_k = matmul_k % TK;

    // create ranges
"""
@@ -204,13 +208,33 @@ __global__ void {name}(
src += f"""

    // prefetch
    int prefetch_k = select(rem_k > 0, rem_k, TK);
    bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
    bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
    bool checkk[TK] = {rk} < matmul_k + off_k;
    bool checkk[TK] = {rk} < prefetch_k;
    bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
    TYPE a[TM, TK, TB] = checka ? *pa : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;
    TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""

if lut_mode_a == _einsum.LUT_MODE.SCALAR:
    src += """
    pa += rem_delta_a;"""
else:
    src += """
    pa += incda;
    padelta += TK;
    incda = (*padelta)[newaxis, :, newaxis];"""

if lut_mode_b == _einsum.LUT_MODE.SCALAR:
    src += """
    pb += rem_delta_b;"""
else:
    src += """
    pb += (*pbdelta)[:, newaxis, newaxis];
    pbdelta += TK;"""

src += f"""
    // accumulate
    float acc[TM, TN, TB] = 0;
    for(int k = matmul_k; k > 0; k -= TK) {{
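A hedged NumPy illustration of the prefetch masking above, with assumed sizes (this is not generated kernel code): only the first prefetch_k lanes of the first tile pass checkk, and the masked-off lanes read as 0 through the "checka ? *pa : 0" select.

    import numpy as np

    TK, matmul_k = 4, 10
    rem_k = matmul_k % TK                    # 2
    prefetch_k = rem_k if rem_k > 0 else TK  # 2
    rk = np.arange(TK)                       # [0 1 2 3]
    checkk = rk < prefetch_k                 # [ True  True False False]
    lanes = np.array([7, 8, 9, 10])          # stand-in for the values behind *pa
    a = np.where(checkk, lanes, 0)           # mirrors "checka ? *pa : 0"
    print(a)                                 # [7 8 0 0]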
@@ -219,16 +243,13 @@ __global__ void {name}(
    uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
    acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
#endif
"""

if not use_lut_a or not use_lut_b:
    src += f"""
    {rk} += TK;
"""
    src += _einsum.unpack_cc(tile, axes_k, 'r', True)
    checkk = k > TK;
    checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
    a = *?(checka)pa;
    b = *?(checkb)pb;"""

if use_lut_a:
    if lut_mode_a == _einsum.LUT_MODE.SCALAR:
        src += """
    pa += stride_a_inner;"""
@@ -237,18 +258,6 @@ __global__ void {name}(
    pa += incda;
    padelta += TK;
    incda = (*padelta)[newaxis, :, newaxis];"""
else:
    src += """
    offa = """
for i, sym in enumerate(expr_a):
    ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
    stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
    if i > 0:
        src += ' + '
    src += f"({ccode}) * {stride}\n "
src += """;
    TYPE *pa[TM, TK, TB] = A + offa;"""


if lut_mode_b == _einsum.LUT_MODE.SCALAR:
@@ -260,11 +269,6 @@ __global__ void {name}(
    pbdelta += TK;"""

src += f"""
    checkk = k > TK;
    checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
    checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
    a = *?(checka)pa;
    b = *?(checkb)pb;
    }}
    TYPE c[TM, TN, TB] = acc;
@@ -367,16 +371,27 @@ __global__ void {name}(
fn = sp.lambdify(args, delta, 'numpy')
# inner axes values
inner = [dims[d] for d in axes]
k = np.arange(np.prod(inner), dtype=np.int32)
inner = np.prod(inner)
rem = inner % step
rem = rem if rem > 0 else step
# k = [0, 1, ..., step,
#      rem, rem + 1, ... rem + inner]
k = np.concatenate((np.arange(step),
                    np.arange(rem, inner))).astype(np.int32)
# nextk = [rem, 1 + rem, ..., step + rem,
#          rem + step, rem + 1 + step, ..., inner + step]
nextk = np.concatenate((k[:step] + rem,
                        k[step:] + step))
# offsets
off = _einsum.unpack_offset(k, axes, dims)
nextoff = _einsum.unpack_offset(k + step, axes, dims)
nextoff = _einsum.unpack_offset(nextk, axes, dims)
# evaluate deltas
args = [s for s in stride]
args += [off[sk] for sk in axes]
args += [nextoff[sk] for sk in axes]
args += [x for _, x in arrays]
delta = fn(*args)
return delta, _einsum.lut_mode(delta[:-step])
return delta, _einsum.lut_mode(delta[step:-step])
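A worked example of the index arrays built above, under assumed sizes (inner = 10, step = TK = 4, hence rem = 2). For a single contiguous axis with unit stride the delta is simply nextk - k: the first step entries give the remainder jump and everything from index step onward gives the steady-state advance, which is why the LUT mode is now judged on delta[step:-step] and also lines up with SCALAR mode passing delta_a[TK] and delta_a[0] further down.

    import numpy as np

    step, inner = 4, 10
    rem = inner % step
    rem = rem if rem > 0 else step
    k = np.concatenate((np.arange(step), np.arange(rem, inner)))
    nextk = np.concatenate((k[:step] + rem, k[step:] + step))
    print(k)          # [0 1 2 3 2 3 4 5 6 7 8 9]
    print(nextk)      # [ 2  3  4  5  6  7  8  9 10 11 12 13]
    print(nextk - k)  # [2 2 2 2 4 4 4 4 4 4 4 4]  -> remainder jump first, then the full step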

############################
## Einsum parsing
@@ -525,12 +540,17 @@ __global__ void {name}(
             alpha, M, N, K, div_m] +\
             dim_m + dim_n + dim_k + dim_b +\
             stride_a + stride_b + stride_c
if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
    delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
    self.args += [delta_a]
if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
    delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
    self.args += [delta_b]
# LUT for A
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
    self.args += [delta_a[TK], delta_a[0]]
if lut_mode_a == _einsum.LUT_MODE.DRAM:
    self.args += [torch.from_numpy(delta_a).cuda()]
# LUT for B
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
    self.args += [delta_b[TK], delta_b[0]]
if lut_mode_b == _einsum.LUT_MODE.DRAM:
    self.args += [torch.from_numpy(delta_b).cuda()]
# Einsum dependents
self.args += arrays
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
                         triton.cdiv(N, opt.d('TN')),
@@ -551,6 +571,7 @@ __global__ void {name}(
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
      if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
TM, TN, TB, TZ = 64, 64, 1, 1
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
if mask:
    self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
@@ -613,6 +634,7 @@ __global__ void {name}(
ctx.matmul_K = instance.matmul_K
ctx.bench = bench
ctx.forward_ms = speed
ctx.mask = mask
ctx.save_for_backward(a, b)
return c
@@ -621,7 +643,7 @@ __global__ void {name}(
############################

@staticmethod
def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask):
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
    for i, expr in enumerate(sym_x):
        if expr.is_symbol:
            continue
@@ -652,9 +674,9 @@ __global__ void {name}(
expr_a = _einsum.sym_to_expr(sym_a)
expr_b = _einsum.sym_to_expr(sym_b)
expr_c = _einsum.sym_to_expr(sym_c)
expr = f'{expr_c},{expr_b}->{expr_a}'
da = einsum(expr, dy, b, a.shape, False)
return None, da, None, None, None
da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask)
db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask)
return None, da, db, None, None


einsum = _einsum.apply