[PYTHON][EINSUM] Now handling reduction sizes that are not a multiple of

TK
This commit is contained in:
Philippe Tillet
2020-02-14 12:41:47 -05:00
committed by Philippe Tillet
parent fa4ec7ea65
commit 3816f2f259
5 changed files with 99 additions and 73 deletions

View File

@@ -369,7 +369,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
((PHINode*)current_result)->addIncoming(result_then, mask_then_bb);
Value *result_false = false_values->get_value(idx);
if(result_then->getType()->isVectorTy())
result_false = builder_->CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType()));
result_false = builder_->CreateVectorSplat(vector_size, result_false);
((PHINode*)current_result)->addIncoming(result_false, current_bb);
}
else

View File

@@ -255,6 +255,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
barriers.run(module);
// ir::print(module, std::cout);
isel.visit(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));

View File

@@ -5,6 +5,8 @@ import numpy as np
#import utils
from time import time
torch.manual_seed(0)
#torch.backends.cudnn.benchmark = True
configs = []
@@ -31,15 +33,15 @@ MNK = [
# (127008, 768, 576)
]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b)
configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a.t(), b)
configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b.t())
configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b)
# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a.t(), b)
# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())]
#for M, N, K in MNK:
# matmul = lambda a, b: torch.matmul(a, b.t())
# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())]
# Relative attention
NTHSE = [
@@ -70,16 +72,16 @@ NTHSE = [
# (128, 1024, 8, 256, 256),
#(128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
for N, T, H, S, E in NTHSE:
configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
#for N, T, H, S, E in NTHSE:
# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())]
#for N, T, H, S, E in NTHSE:
# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())]
# 1D Dense convolution
NCHKR = [
(1, 1152, 12602, 512, 3)
#(1, 1152, 12602, 512, 3)
]
for N, C, H, K, R in NCHKR:
torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1))
@@ -93,17 +95,17 @@ for N, C, H, K, R in NCHKR:
# 2D Dense convolution
NCHWKRS = [
#(8, 64, 128, 128, 768, 3, 3),
(8, 128, 64, 64, 256, 3, 3),
(8, 256, 32, 32, 512, 3, 3),
(128, 3, 32, 32, 64, 3, 3),
#(8, 256, 32, 32, 512, 3, 3),
#(8, 512, 32, 32, 1024, 3, 3)
]
for N, C, H, W, K, R, S in NCHWKRS:
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2))
torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b)
configs += [([N, C, H, W],
[C, R, S, K],
[K, C, R, S],
[N, K, H - R + 1, W - R + 1],
torch_fn,
'nc(h+r)(w+s),crsk->nkhw',
'nc(h+r)(w+s),kcrs->nkhw',
dict())]
# 3D Dense Convolution
@@ -173,6 +175,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()
# triton output
print(a.size(), b.size())
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
# reference output
if torch_fn:
@@ -182,7 +185,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
# performance relative to equivalent matrix multiplication
ctx = triton.ctx_registry[tc]
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
cmp_eqbmm = True
cmp_eqbmm = False
if cmp_eqbmm:
a = torch.rand(B, M, K).type(dtype).cuda()
b = torch.rand(B, K, N).type(dtype).cuda()

View File

@@ -210,7 +210,7 @@ class kernel:
macros.append((k, values))
opt = libtriton.options_space()
opt.defines = macros
opt.num_warps = [2, 4, 8]
opt.num_warps = num_warps
# create unique id for this op
op_id = libtriton.make_op_id()
self.fw_id[key] = op_id

View File

@@ -110,13 +110,16 @@ __global__ void {name}(
src += "\n "
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += f", int stride_a_inner __multipleof({multipleof_a})"
src += f", int rem_delta_a __multipleof({multipleof_a})"
elif lut_mode_a == _einsum.LUT_MODE.DRAM:
src += ", int* AD __noalias __readonly __aligned(16)"
src += "\n "
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += f", int stride_b_inner __multipleof({multipleof_b})"
src += f", int rem_delta_b __multipleof({multipleof_b})"
elif lut_mode_b == _einsum.LUT_MODE.DRAM:
src += ", int* BD"
src += "\n"
for ptr in subscripted:
src += f", int* {ptr}"
src += """) {
@@ -142,6 +145,7 @@ __global__ void {name}(
int off_k = pid_z * div_z;
matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z);
#endif
int rem_k = matmul_k % TK;
// create ranges
"""
@@ -204,13 +208,33 @@ __global__ void {name}(
src += f"""
// prefetch
int prefetch_k = select(rem_k > 0, rem_k, TK);
bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m;
bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n;
bool checkk[TK] = {rk} < matmul_k + off_k;
bool checkk[TK] = {rk} < prefetch_k;
bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
TYPE a[TM, TK, TB] = checka ? *pa : 0;
TYPE b[TK, TN, TB] = checkb ? *pb : 0;
TYPE b[TK, TN, TB] = checkb ? *pb : 0;"""
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += rem_delta_a;"""
else:
src += """
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
src += """
pb += rem_delta_b;"""
else:
src += """
pb += (*pbdelta)[:, newaxis, newaxis];
pbdelta += TK;"""
src += f"""
// accumulate
float acc[TM, TN, TB] = 0;
for(int k = matmul_k; k > 0; k -= TK) {{
@@ -219,16 +243,13 @@ __global__ void {name}(
uint32 bits[TM, TN, TB] = bitcast<uint32[TM,TN,TB]>(acc);
acc = bitcast<TYPE[TM, TN, TB]>(bits & MASK);
#endif
"""
if not use_lut_a or not use_lut_b:
src += f"""
{rk} += TK;
"""
src += _einsum.unpack_cc(tile, axes_k, 'r', True)
checkk = k > TK;
checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
a = *?(checka)pa;
b = *?(checkb)pb;"""
if use_lut_a:
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
src += """
pa += stride_a_inner;"""
@@ -237,18 +258,6 @@ __global__ void {name}(
pa += incda;
padelta += TK;
incda = (*padelta)[newaxis, :, newaxis];"""
else:
src += """
offa = """
for i, sym in enumerate(expr_a):
ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b)
stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1'
if i > 0:
src += ' + '
src += f"({ccode}) * {stride}\n "
src += """;
TYPE *pa[TM, TK, TB] = A + offa;"""
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
@@ -260,11 +269,6 @@ __global__ void {name}(
pbdelta += TK;"""
src += f"""
checkk = k > TK;
checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis];
checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis];
a = *?(checka)pa;
b = *?(checkb)pb;
}}
TYPE c[TM, TN, TB] = acc;
@@ -367,16 +371,27 @@ __global__ void {name}(
fn = sp.lambdify(args, delta, 'numpy')
# inner axes values
inner = [dims[d] for d in axes]
k = np.arange(np.prod(inner), dtype=np.int32)
inner = np.prod(inner)
rem = inner % step
rem = rem if rem > 0 else step
# k = [0, 1, ..., step,
# rem, rem + 1, ... rem + inner]
k = np.concatenate((np.arange(step),
np.arange(rem, inner))).astype(np.int32)
# nextk = [rem, 1 + rem, ..., step + rem,
# rem + step, rem + 1 + step, ..., inner + step]
nextk = np.concatenate((k[:step] + rem,
k[step:] + step))
# offsets
off = _einsum.unpack_offset(k, axes, dims)
nextoff = _einsum.unpack_offset(k + step, axes, dims)
nextoff = _einsum.unpack_offset(nextk, axes, dims)
# evaluate deltas
args = [s for s in stride]
args += [off[sk] for sk in axes]
args += [nextoff[sk] for sk in axes]
args += [x for _, x in arrays]
delta = fn(*args)
return delta, _einsum.lut_mode(delta[:-step])
return delta, _einsum.lut_mode(delta[step:-step])
############################
## Einsum parsing
@@ -525,12 +540,17 @@ __global__ void {name}(
alpha, M, N, K, div_m] +\
dim_m + dim_n + dim_k + dim_b +\
stride_a + stride_b + stride_c
if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
self.args += [delta_a]
if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
self.args += [delta_b]
# LUT for A
if lut_mode_a == _einsum.LUT_MODE.SCALAR:
self.args += [delta_a[TK], delta_a[0]]
if lut_mode_a == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_a).cuda()]
# LUT for B
if lut_mode_b == _einsum.LUT_MODE.SCALAR:
self.args += [delta_b[TK], delta_b[0]]
if lut_mode_b == _einsum.LUT_MODE.DRAM:
self.args += [torch.from_numpy(delta_b).cuda()]
# Einsum dependents
self.args += arrays
self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) *
triton.cdiv(N, opt.d('TN')),
@@ -551,6 +571,7 @@ __global__ void {name}(
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
TM, TN, TB, TZ = 64, 64, 1, 1
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
if mask:
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
@@ -613,6 +634,7 @@ __global__ void {name}(
ctx.matmul_K = instance.matmul_K
ctx.bench = bench
ctx.forward_ms = speed
ctx.mask = mask
ctx.save_for_backward(a, b)
return c
@@ -621,7 +643,7 @@ __global__ void {name}(
############################
@staticmethod
def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask):
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
for i, expr in enumerate(sym_x):
if expr.is_symbol:
continue
@@ -652,9 +674,9 @@ __global__ void {name}(
expr_a = _einsum.sym_to_expr(sym_a)
expr_b = _einsum.sym_to_expr(sym_b)
expr_c = _einsum.sym_to_expr(sym_c)
expr = f'{expr_c},{expr_b}->{expr_a}'
da = einsum(expr, dy, b, a.shape, False)
return None, da, None, None, None
da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask)
db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask)
return None, da, db, None, None
einsum = _einsum.apply