From 3816f2f2595a9803c0c70180226cd2a01fff9b09 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 14 Feb 2020 12:41:47 -0500 Subject: [PATCH] [PYTHON][EINSUM] Now handling reduction sizes that are not a multiple of TK --- lib/codegen/selection/generator.cc | 2 +- lib/runtime/function.cc | 1 + python/examples/einsum.py | 47 +++++------ python/triton/kernel.py | 2 +- python/triton/ops/einsum.py | 120 +++++++++++++++++------------ 5 files changed, 99 insertions(+), 73 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index ae7d0e876..8321a9948 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -369,7 +369,7 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) { ((PHINode*)current_result)->addIncoming(result_then, mask_then_bb); Value *result_false = false_values->get_value(idx); if(result_then->getType()->isVectorTy()) - result_false = builder_->CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType())); + result_false = builder_->CreateVectorSplat(vector_size, result_false); ((PHINode*)current_result)->addIncoming(result_false, current_bb); } else diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 70e4df12f..e77e5772b 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -255,6 +255,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c if(allocation.allocated_size() > context->device()->max_shared_memory()) return std::unique_ptr(); barriers.run(module); +// ir::print(module, std::cout); isel.visit(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); diff --git a/python/examples/einsum.py b/python/examples/einsum.py index 340298be5..a971c6683 100644 --- a/python/examples/einsum.py +++ b/python/examples/einsum.py @@ -5,6 +5,8 @@ import numpy as np #import utils from time import time +torch.manual_seed(0) + #torch.backends.cudnn.benchmark = True configs = [] @@ -31,15 +33,15 @@ MNK = [ # (127008, 768, 576) ] -for M, N, K in MNK: - matmul = lambda a, b: torch.matmul(a, b) - configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] -for M, N, K in MNK: - matmul = lambda a, b: torch.matmul(a.t(), b) - configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())] -for M, N, K in MNK: - matmul = lambda a, b: torch.matmul(a, b.t()) - configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())] +#for M, N, K in MNK: +# matmul = lambda a, b: torch.matmul(a, b) +# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] +#for M, N, K in MNK: +# matmul = lambda a, b: torch.matmul(a.t(), b) +# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())] +#for M, N, K in MNK: +# matmul = lambda a, b: torch.matmul(a, b.t()) +# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())] # Relative attention NTHSE = [ @@ -70,16 +72,16 @@ NTHSE = [ # (128, 1024, 8, 256, 256), #(128, 1024, 8, 256, 512) ] -for N, T, H, S, E in NTHSE: - configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())] -for N, T, H, S, E in NTHSE: - configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())] -for N, T, H, S, E in NTHSE: - configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())] +#for N, T, H, S, E in NTHSE: +# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())] +#for N, T, H, S, E in NTHSE: +# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 
'nhte,nths->hes', dict())] +#for N, T, H, S, E in NTHSE: +# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())] # 1D Dense convolution NCHKR = [ - (1, 1152, 12602, 512, 3) + #(1, 1152, 12602, 512, 3) ] for N, C, H, K, R in NCHKR: torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1)) @@ -93,17 +95,17 @@ for N, C, H, K, R in NCHKR: # 2D Dense convolution NCHWKRS = [ #(8, 64, 128, 128, 768, 3, 3), - (8, 128, 64, 64, 256, 3, 3), - (8, 256, 32, 32, 512, 3, 3), + (128, 3, 32, 32, 64, 3, 3), + #(8, 256, 32, 32, 512, 3, 3), #(8, 512, 32, 32, 1024, 3, 3) ] for N, C, H, W, K, R, S in NCHWKRS: - torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2)) + torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b) configs += [([N, C, H, W], - [C, R, S, K], + [K, C, R, S], [N, K, H - R + 1, W - R + 1], torch_fn, - 'nc(h+r)(w+s),crsk->nkhw', + 'nc(h+r)(w+s),kcrs->nkhw', dict())] # 3D Dense Convolution @@ -173,6 +175,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: a = torch.rand(*a_shape).type(dtype).cuda() b = torch.rand(*b_shape).type(dtype).cuda() # triton output + print(a.size(), b.size()) tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True) # reference output if torch_fn: @@ -182,7 +185,7 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: # performance relative to equivalent matrix multiplication ctx = triton.ctx_registry[tc] B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K - cmp_eqbmm = True + cmp_eqbmm = False if cmp_eqbmm: a = torch.rand(B, M, K).type(dtype).cuda() b = torch.rand(B, K, N).type(dtype).cuda() diff --git a/python/triton/kernel.py b/python/triton/kernel.py index edec0af12..ab7c3f49e 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -210,7 +210,7 @@ class kernel: macros.append((k, values)) opt = libtriton.options_space() opt.defines = macros - opt.num_warps = [2, 4, 8] + opt.num_warps = num_warps # create unique id for this op op_id = libtriton.make_op_id() self.fw_id[key] = op_id diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 4bce9729d..39f0c78b5 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -110,13 +110,16 @@ __global__ void {name}( src += "\n " if lut_mode_a == _einsum.LUT_MODE.SCALAR: src += f", int stride_a_inner __multipleof({multipleof_a})" + src += f", int rem_delta_a __multipleof({multipleof_a})" elif lut_mode_a == _einsum.LUT_MODE.DRAM: src += ", int* AD __noalias __readonly __aligned(16)" src += "\n " if lut_mode_b == _einsum.LUT_MODE.SCALAR: src += f", int stride_b_inner __multipleof({multipleof_b})" + src += f", int rem_delta_b __multipleof({multipleof_b})" elif lut_mode_b == _einsum.LUT_MODE.DRAM: src += ", int* BD" + src += "\n" for ptr in subscripted: src += f", int* {ptr}" src += """) { @@ -142,6 +145,7 @@ __global__ void {name}( int off_k = pid_z * div_z; matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z); #endif + int rem_k = matmul_k % TK; // create ranges """ @@ -204,13 +208,33 @@ __global__ void {name}( src += f""" // prefetch + int prefetch_k = select(rem_k > 0, rem_k, TK); bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m; bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n; - bool checkk[TK] = {rk} < matmul_k + off_k; + bool checkk[TK] = {rk} < prefetch_k; bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis]; bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] 
&& checkn[newaxis, :, newaxis]; TYPE a[TM, TK, TB] = checka ? *pa : 0; - TYPE b[TK, TN, TB] = checkb ? *pb : 0; + TYPE b[TK, TN, TB] = checkb ? *pb : 0;""" + + if lut_mode_a == _einsum.LUT_MODE.SCALAR: + src += """ + pa += rem_delta_a;""" + else: + src += """ + pa += incda; + padelta += TK; + incda = (*padelta)[newaxis, :, newaxis];""" + + if lut_mode_b == _einsum.LUT_MODE.SCALAR: + src += """ + pb += rem_delta_b;""" + else: + src += """ + pb += (*pbdelta)[:, newaxis, newaxis]; + pbdelta += TK;""" + + src += f""" // accumulate float acc[TM, TN, TB] = 0; for(int k = matmul_k; k > 0; k -= TK) {{ @@ -219,36 +243,21 @@ __global__ void {name}( uint32 bits[TM, TN, TB] = bitcast(acc); acc = bitcast(bits & MASK); #endif - """ - if not use_lut_a or not use_lut_b: - src += f""" - {rk} += TK; -""" - src += _einsum.unpack_cc(tile, axes_k, 'r', True) - + checkk = k > TK; + checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis]; + checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis]; + a = *?(checka)pa; + b = *?(checkb)pb;""" - if use_lut_a: - if lut_mode_a == _einsum.LUT_MODE.SCALAR: - src += """ - pa += stride_a_inner;""" - else: - src += """ - pa += incda; - padelta += TK; - incda = (*padelta)[newaxis, :, newaxis];""" + if lut_mode_a == _einsum.LUT_MODE.SCALAR: + src += """ + pa += stride_a_inner;""" else: src += """ - offa = """ - for i, sym in enumerate(expr_a): - ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b) - stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1' - if i > 0: - src += ' + ' - src += f"({ccode}) * {stride}\n " - src += """; - TYPE *pa[TM, TK, TB] = A + offa;""" - + pa += incda; + padelta += TK; + incda = (*padelta)[newaxis, :, newaxis];""" if lut_mode_b == _einsum.LUT_MODE.SCALAR: @@ -260,11 +269,6 @@ __global__ void {name}( pbdelta += TK;""" src += f""" - checkk = k > TK; - checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis]; - checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis]; - a = *?(checka)pa; - b = *?(checkb)pb; }} TYPE c[TM, TN, TB] = acc; @@ -367,16 +371,27 @@ __global__ void {name}( fn = sp.lambdify(args, delta, 'numpy') # inner axes values inner = [dims[d] for d in axes] - k = np.arange(np.prod(inner), dtype=np.int32) + inner = np.prod(inner) + rem = inner % step + rem = rem if rem > 0 else step + # k = [0, 1, ..., step, + # rem, rem + 1, ... 
rem + inner] + k = np.concatenate((np.arange(step), + np.arange(rem, inner))).astype(np.int32) + # nextk = [rem, 1 + rem, ..., step + rem, + # rem + step, rem + 1 + step, ..., inner + step] + nextk = np.concatenate((k[:step] + rem, + k[step:] + step)) + # offsets off = _einsum.unpack_offset(k, axes, dims) - nextoff = _einsum.unpack_offset(k + step, axes, dims) + nextoff = _einsum.unpack_offset(nextk, axes, dims) # evaluate deltas args = [s for s in stride] args += [off[sk] for sk in axes] args += [nextoff[sk] for sk in axes] args += [x for _, x in arrays] delta = fn(*args) - return delta, _einsum.lut_mode(delta[:-step]) + return delta, _einsum.lut_mode(delta[step:-step]) ############################ ## Einsum parsing @@ -525,17 +540,22 @@ __global__ void {name}( alpha, M, N, K, div_m] +\ dim_m + dim_n + dim_k + dim_b +\ stride_a + stride_b + stride_c - if lut_mode_a != _einsum.LUT_MODE.CONSTANT: - delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda() - self.args += [delta_a] - if lut_mode_b != _einsum.LUT_MODE.CONSTANT: - delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda() - self.args += [delta_b] + # LUT for A + if lut_mode_a == _einsum.LUT_MODE.SCALAR: + self.args += [delta_a[TK], delta_a[0]] + if lut_mode_a == _einsum.LUT_MODE.DRAM: + self.args += [torch.from_numpy(delta_a).cuda()] + # LUT for B + if lut_mode_b == _einsum.LUT_MODE.SCALAR: + self.args += [delta_b[TK], delta_b[0]] + if lut_mode_b == _einsum.LUT_MODE.DRAM: + self.args += [torch.from_numpy(delta_b).cuda()] + # Einsum dependents self.args += arrays self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) * - triton.cdiv(N, opt.d('TN')), - triton.cdiv(B, opt.d('TB')), - opt.d('TZ')] + triton.cdiv(N, opt.d('TN')), + triton.cdiv(B, opt.d('TB')), + opt.d('TZ')] # position of dynamic arguments self.pos_a = 0 self.pos_b = 1 @@ -551,6 +571,7 @@ __global__ void {name}( TZ = [x for x in [1, 2, 4, 8, 16, 32] \ if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256] TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2] + TM, TN, TB, TZ = 64, 64, 1, 1 self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} if mask: self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10) @@ -613,6 +634,7 @@ __global__ void {name}( ctx.matmul_K = instance.matmul_K ctx.bench = bench ctx.forward_ms = speed + ctx.mask = mask ctx.save_for_backward(a, b) return c @@ -621,7 +643,7 @@ __global__ void {name}( ############################ @staticmethod - def sym_invert(sym_c, sym_x, prefix, renamed, inverse, mask): + def sym_invert(sym_c, sym_x, prefix, renamed, inverse): for i, expr in enumerate(sym_x): if expr.is_symbol: continue @@ -652,9 +674,9 @@ __global__ void {name}( expr_a = _einsum.sym_to_expr(sym_a) expr_b = _einsum.sym_to_expr(sym_b) expr_c = _einsum.sym_to_expr(sym_c) - expr = f'{expr_c},{expr_b}->{expr_a}' - da = einsum(expr, dy, b, a.shape, False) - return None, da, None, None, None + da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask) + db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask) + return None, da, db, None, None einsum = _einsum.apply \ No newline at end of file
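
A minimal NumPy sketch (not part of the patch) of the remainder-first reduction scheme the einsum kernel now uses: when matmul_k is not a multiple of TK, the first (prefetch) tile covers only rem_k = matmul_k % TK elements under a mask, and every later tile is a full TK step. Function and variable names below are illustrative only, not Triton APIs.

import numpy as np

def matmul_remainder_first(A, B, TK=16):
    # A: (M, K), B: (K, N); accumulate in float32 like the kernel's acc.
    M, K = A.shape
    _, N = B.shape
    rem = K % TK
    first = rem if rem > 0 else TK          # size of the peeled "prefetch" tile
    acc = np.zeros((M, N), dtype=np.float32)
    # peeled tile covers [0, first); in the kernel this is the masked prefetch
    acc += A[:, :first] @ B[:first, :]
    # every remaining tile is a full TK step, so no per-iteration remainder logic
    for k0 in range(first, K, TK):
        acc += A[:, k0:k0 + TK] @ B[k0:k0 + TK, :]
    return acc

# K = 37 is not a multiple of TK = 16; result matches a plain matmul
A = np.random.rand(8, 37).astype(np.float32)
B = np.random.rand(37, 5).astype(np.float32)
assert np.allclose(matmul_remainder_first(A, B), A @ B, atol=1e-3)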
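
Also illustrative rather than part of the patch: a reconstruction of the k/nextk index tables built in make_delta, assuming a single contiguous axis with stride 1 (the real code unpacks multi-axis offsets through sympy). It shows why, in SCALAR LUT mode, delta[0] is passed as the remainder increment (rem_delta_*) and delta[TK] as the steady-state increment (stride_*_inner).

import numpy as np

def delta_table(inner, step, stride=1):
    # inner: total reduction size, step: TK, stride: pointer stride (assumed scalar)
    rem = inner % step
    rem = rem if rem > 0 else step
    # k     = [0, 1, ..., step-1,          rem, rem+1, ..., inner-1]
    # nextk = [rem, rem+1, ..., rem+step-1, rem+step,   ..., inner-1+step]
    k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32)
    nextk = np.concatenate((k[:step] + rem, k[step:] + step))
    # pointer increment from each index to the one rem (first tile) or TK (later tiles) ahead
    return (nextk - k) * stride

delta = delta_table(inner=37, step=16)
print(delta[0])    # remainder jump, i.e. rem_delta in SCALAR mode: 5
print(delta[16])   # full-tile jump, i.e. stride_inner in SCALAR mode: 16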