[PYTHON][OPS][EINSUM] Now throwing error for automatic differentiation of extended einsum

Author: Philippe Tillet
Committer: Philippe Tillet
Date: 2020-02-20 17:29:01 -05:00
Parent: 4e50ef4076
Commit: 30f77e9ec5

@@ -13,7 +13,7 @@ import re
from sympy.printing.ccode import C89CodePrinter
class _einsum(triton.function):
class _einsum(torch.autograd.Function):
#############################
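The class now derives directly from torch.autograd.Function instead of Triton's own framework-agnostic function wrapper, so PyTorch drives the forward/backward dispatch. For orientation, a minimal generic sketch of that pattern (toy example, not part of this diff):

import torch

class _square(torch.autograd.Function):
    # toy op y = x * x, showing the forward/backward contract used below
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # tensors are stashed on the context
        return x * x

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        return 2 * x * dy             # one gradient per forward argument

y = _square.apply(torch.randn(4, requires_grad=True))
y.sum().backward()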
@@ -495,7 +495,7 @@ __global__ void {name}(
dims.update(dims_b)
dims.update(dims_c)
# look-up tables
TK = 16 if dtype == triton.fw.torch.float16 else 8
TK = 16 if dtype == torch.float16 else 8
arrays = [(x, arrays[x]) for x in subscripted]
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
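The half-precision check now compares against torch.float16 directly rather than going through triton.fw. In isolation, the tile-depth selection amounts to the following (helper name hypothetical):

import torch

def k_tile(dtype: torch.dtype) -> int:
    # deeper reduction tile for fp16, shallower one otherwise
    return 16 if dtype == torch.float16 else 8

assert k_tile(torch.float16) == 16
assert k_tile(torch.float32) == 8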
@@ -580,17 +580,15 @@ __global__ void {name}(
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
if mask:
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
# information on compute
self.dtype = dtype
self.flops = 2 * B * M * N * K
self.sym_a = sym_a
self.sym_b = sym_b
self.sym_c = sym_c
# save equivalent mat-mul dimensions
# save information on the operation
self.expr_a = expr_a
self.expr_b = expr_b
self.expr_c = expr_c
self.matmul_B = B
self.matmul_M = M
self.matmul_N = N
self.matmul_K = K
self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])
def run(self, a, b, c, bench):
self.args[self.pos_a] = a
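The new is_extended flag marks equations whose subscripts are full index expressions rather than plain symbols; it is what the backward pass checks before refusing to differentiate. A small illustration of the any(not x.is_symbol ...) test, assuming subscripts are parsed into sympy expressions as elsewhere in this file:

import sympy as sp

i, j, k = sp.symbols('i j k')
plain    = [i, k]           # ordinary einsum operand, e.g. "ik"
extended = [i * 2 + j, k]   # extended operand with an index expression

print(any(not x.is_symbol for x in plain))     # False -> autograd supported
print(any(not x.is_symbol for x in extended))  # True  -> backward raises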
@@ -606,79 +604,63 @@ __global__ void {name}(
############################
instance_cache = dict()
@staticmethod
def forward(ctx, einsum, a, b, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
mask = kwargs['mask'] if 'mask' in kwargs else None
output = kwargs['output'] if 'output' in kwargs else None
def forward(ctx, expr, a, b, output, mask=None, arrays=dict(), bench=False):
# compile einsum instance
cache = _einsum.instance_cache
key = (einsum, a.dtype,
key = (expr, a.dtype,
a.stride(), b.stride(), output.stride(),
a.shape, b.shape, output.shape, mask)
if key not in cache:
cache[key] = _einsum.instance(einsum, a.dtype,
cache[key] = _einsum.instance(expr, a.dtype,
a.stride(), b.stride(), output.stride(),
a.shape, b.shape, arrays,
mask, output.shape)
instance = cache[key]
speed = instance.run(a, b, output, bench)
# run the kernel and mark the in-place-modified output as dirty
perf = instance.run(a, b, output, bench)
ctx.mark_dirty(output)
# save information in context
ctx.flops = instance.flops
ctx.sym_a = instance.sym_a
ctx.sym_b = instance.sym_b
ctx.sym_c = instance.sym_c
ctx.is_extended = instance.is_extended
ctx.expr_a = instance.expr_a
ctx.expr_b = instance.expr_b
ctx.expr_c = instance.expr_c
ctx.matmul_B = instance.matmul_B
ctx.matmul_M = instance.matmul_M
ctx.matmul_N = instance.matmul_N
ctx.matmul_K = instance.matmul_K
ctx.bench = bench
ctx.forward_ms = speed
ctx.mask = mask
ctx.perf = perf
ctx.save_for_backward(a, b)
return output
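Compiled instances are memoized on everything that affects code generation, so repeated calls with the same expression, dtype, shapes, strides and mask skip recompilation. A sketch of the same caching pattern in isolation (compile_kernel is a hypothetical stand-in for _einsum.instance):

_cache = dict()

def get_instance(expr, a, b, out, mask=None):
    # key on every property that changes the generated kernel
    key = (expr, a.dtype,
           a.stride(), b.stride(), out.stride(),
           a.shape, b.shape, out.shape, mask)
    if key not in _cache:
        _cache[key] = compile_kernel(expr, a, b, out, mask)  # hypothetical compile step
    return _cache[key]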
############################
## Backward
############################
@staticmethod
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
for i, expr in enumerate(sym_x):
if expr.is_symbol:
continue
sc = [x for x in expr.free_symbols if x in sym_c][0]
sx = sp.symbols(f'{prefix}{i}')
renamed[expr] = sx
inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
@staticmethod
def sym_to_expr(sym):
res = [f'({x})' for x in sym]
res = ''.join(res)
return res
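sym_invert solves each non-symbol subscript of an operand for the output symbol it contains, renaming the subscript so the equation can be re-expressed from the operand's point of view. The core step is a single sympy solve; a hedged illustration:

import sympy as sp

i, a0 = sp.symbols('i a0')
expr = 2 * i + 1                        # an extended subscript of operand a
inv = sp.solve(sp.Eq(expr, a0), i)[0]   # solve expr == a0 for i
print(inv)                              # a0/2 - 1/2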
@staticmethod
def backward(ctx, dy):
if ctx.is_extended:
raise NotImplementedError('Automatic differentiation for extended einsum not yet implemented;'
' please write your own autograd function')
a, b = ctx.saved_tensors
sym_a = ctx.sym_a
sym_b = ctx.sym_b
sym_c = ctx.sym_c
inverse = dict()
renamed = dict()
_einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
_einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
sym_a = [renamed[x] if x in renamed else x for x in sym_a]
sym_b = [renamed[x] if x in renamed else x for x in sym_b]
sym_c = [inverse[x] if x in inverse else x for x in sym_c]
expr_a = _einsum.sym_to_expr(sym_a)
expr_b = _einsum.sym_to_expr(sym_b)
expr_c = _einsum.sym_to_expr(sym_c)
da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask)
db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask)
return None, da, db, None, None
expr_a = ctx.expr_a
expr_b = ctx.expr_b
expr_c = ctx.expr_c
# gradient of first argument
da = None
if ctx.needs_input_grad[1]:
da = torch.empty_like(a)
einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, da)
# gradient of second argument
db = None
if ctx.needs_input_grad[2]:
db = torch.empty_like(b)
einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
return None, da, db, None, None, None, None
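The gradients follow the standard einsum identity: if C = einsum('a,b->c')(A, B), then dA = einsum('c,b->a')(dC, B) and dB = einsum('a,c->b')(A, dC). A quick check of that identity with torch.einsum on a plain matrix product (reference only, not using the Triton kernel):

import torch

a = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, 4, requires_grad=True)
c = torch.einsum('mk,kn->mn', a, b)
c.sum().backward()

dy = torch.ones_like(c)
da = torch.einsum('mn,kn->mk', dy, b)   # expr_c, expr_b -> expr_a
db = torch.einsum('mk,mn->kn', a, dy)   # expr_a, expr_c -> expr_b
assert torch.allclose(da, a.grad) and torch.allclose(db, b.grad)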
einsum = _einsum.apply
def einsum(expr, a, b, output,
mask=None, arrays=dict(),
bench=False):
return _einsum.apply(expr, a, b, output, mask, arrays, bench)
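For reference, a hedged usage sketch of the new entry point, assuming a CUDA device and that plain letter subscripts such as 'mk,kn->mn' are accepted; the caller allocates the output buffer, which is filled in place:

import torch

a = torch.randn(64, 32, device='cuda', dtype=torch.float16, requires_grad=True)
b = torch.randn(32, 48, device='cuda', dtype=torch.float16, requires_grad=True)
c = torch.empty(64, 48, device='cuda', dtype=torch.float16)

c = einsum('mk,kn->mn', a, b, c)   # writes into c and returns it
c.sum().backward()                 # non-extended einsum, so autograd is supported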