diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index aa455b7fb..1bd318b13 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -13,7 +13,7 @@ import re
 from sympy.printing.ccode import C89CodePrinter
 
 
-class _einsum(triton.function):
+class _einsum(torch.autograd.Function):
 
 
    #############################
@@ -495,7 +495,7 @@ __global__ void {name}(
         dims.update(dims_b)
         dims.update(dims_c)
         # look-up tables
-        TK = 16 if dtype == triton.fw.torch.float16 else 8
+        TK = 16 if dtype == torch.float16 else 8
         arrays = [(x, arrays[x]) for x in subscripted]
         delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
         delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
@@ -580,17 +580,15 @@ __global__ void {name}(
         self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
         if mask:
             self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
-        # information on compute
-        self.dtype = dtype
-        self.flops = 2 * B * M * N * K
-        self.sym_a = sym_a
-        self.sym_b = sym_b
-        self.sym_c = sym_c
-        # save equivalent mat-mul dimensions
+        # save information on the operation
+        self.expr_a = expr_a
+        self.expr_b = expr_b
+        self.expr_c = expr_c
         self.matmul_B = B
         self.matmul_M = M
         self.matmul_N = N
         self.matmul_K = K
+        self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])
 
     def run(self, a, b, c, bench):
         self.args[self.pos_a] = a
@@ -606,79 +604,63 @@ __global__ void {name}(
 
    ############################
 
     instance_cache = dict()
 
     @staticmethod
-    def forward(ctx, einsum, a, b, **kwargs):
-        bench  = kwargs['bench']  if 'bench'  in kwargs else False
-        arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
-        mask   = kwargs['mask']   if 'mask'   in kwargs else None
-        output = kwargs['output'] if 'output' in kwargs else None
+    def forward(ctx, expr, a, b, output, mask=None, arrays=dict(), bench=False):
         # compile einsum instance
         cache = _einsum.instance_cache
-        key = (einsum, a.dtype,
+        key = (expr, a.dtype,
                a.stride(), b.stride(), output.stride(),
                a.shape, b.shape, output.shape, mask)
         if key not in cache:
-            cache[key] = _einsum.instance(einsum, a.dtype,
+            cache[key] = _einsum.instance(expr, a.dtype,
                                           a.stride(), b.stride(), output.stride(),
                                           a.shape, b.shape, arrays,
                                           mask, output.shape)
         instance = cache[key]
-        speed = instance.run(a, b, output, bench)
+        # run the kernel and mark the in-place-modified output as dirty
+        perf = instance.run(a, b, output, bench)
+        ctx.mark_dirty(output)
         # save information in context
-        ctx.flops = instance.flops
-        ctx.sym_a = instance.sym_a
-        ctx.sym_b = instance.sym_b
-        ctx.sym_c = instance.sym_c
+        ctx.is_extended = instance.is_extended
+        ctx.expr_a = instance.expr_a
+        ctx.expr_b = instance.expr_b
+        ctx.expr_c = instance.expr_c
         ctx.matmul_B = instance.matmul_B
         ctx.matmul_M = instance.matmul_M
         ctx.matmul_N = instance.matmul_N
         ctx.matmul_K = instance.matmul_K
-        ctx.bench = bench
-        ctx.forward_ms = speed
-        ctx.mask = mask
+        ctx.perf = perf
         ctx.save_for_backward(a, b)
         return output
 
+
    ############################
    ## Backward
    ############################
 
-    @staticmethod
-    def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
-        for i, expr in enumerate(sym_x):
-            if expr.is_symbol:
-                continue
-            sc = [x for x in expr.free_symbols if x in sym_c][0]
-            sx = sp.symbols(f'{prefix}{i}')
-            renamed[expr] = sx
-            inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
-
-    @staticmethod
-    def sym_to_expr(sym):
-        res = [f'({x})' for x in sym]
-        res = ''.join(res)
-        return res
-
     @staticmethod
     def backward(ctx, dy):
+        if ctx.is_extended:
+            raise NotImplementedError('Automatic differentiation for extended einsum not yet implemented;'
+                                      ' please write your own autograd function')
         a, b = ctx.saved_tensors
-        sym_a = ctx.sym_a
-        sym_b = ctx.sym_b
-        sym_c = ctx.sym_c
-        inverse = dict()
-        renamed = dict()
-        _einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
-        _einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
-        sym_a = [renamed[x] if x in renamed else x for x in sym_a]
-        sym_b = [renamed[x] if x in renamed else x for x in sym_b]
-        sym_c = [inverse[x] if x in inverse else x for x in sym_c]
-        expr_a = _einsum.sym_to_expr(sym_a)
-        expr_b = _einsum.sym_to_expr(sym_b)
-        expr_c = _einsum.sym_to_expr(sym_c)
-        da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask)
-        db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask)
-        return None, da, db, None, None
+        expr_a = ctx.expr_a
+        expr_b = ctx.expr_b
+        expr_c = ctx.expr_c
+        # gradient of first argument
+        da = None
+        if ctx.needs_input_grad[1]:
+            da = torch.empty_like(a)
+            einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, da)
+        # gradient of second argument
+        db = None
+        if ctx.needs_input_grad[2]:
+            db = torch.empty_like(b)
+            einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
+        return None, da, db, None, None, None, None
 
-einsum = _einsum.apply
\ No newline at end of file
+def einsum(expr, a, b, output,
+           mask=None, arrays=dict(),
+           bench=False):
+    return _einsum.apply(expr, a, b, output, mask, arrays, bench)
\ No newline at end of file
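
For reviewers, a minimal sketch of the calling convention this patch moves to, and a check of the gradient identity the new backward() relies on. The shapes and subscripts below are invented for illustration, and torch.einsum stands in for the Triton kernel:

import torch

# New calling convention (hypothetical shapes): the caller now allocates
# the output buffer, which forward() fills in-place and marks dirty:
#   a = torch.randn(8, 16, 32, device='cuda', dtype=torch.float16)
#   b = torch.randn(8, 32, 64, device='cuda', dtype=torch.float16)
#   c = torch.empty(8, 16, 64, device='cuda', dtype=torch.float16)
#   einsum('bmk,bkn->bmn', a, b, c)

# backward() relies on the fact that, for a non-extended einsum, each
# gradient is itself an einsum with the output expression substituted
# into the slot of the differentiated operand:
#   da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, da)
#   db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
a = torch.randn(8, 16, 32, dtype=torch.float64, requires_grad=True)
b = torch.randn(8, 32, 64, dtype=torch.float64, requires_grad=True)
c = torch.einsum('bmk,bkn->bmn', a, b)
dy = torch.randn_like(c)
c.backward(dy)
da = torch.einsum('bmn,bkn->bmk', dy, b)  # expr_c, expr_b -> expr_a
db = torch.einsum('bmk,bmn->bkn', a, dy)  # expr_a, expr_c -> expr_b
assert torch.allclose(a.grad, da)
assert torch.allclose(b.grad, db)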