[PYTHON][OPS][EINSUM] Now throwing error for automatic differentiation
of extended einsum
This commit is contained in:
committed by
Philippe Tillet
parent
4e50ef4076
commit
30f77e9ec5
@@ -13,7 +13,7 @@ import re
|
||||
from sympy.printing.ccode import C89CodePrinter
|
||||
|
||||
|
||||
class _einsum(triton.function):
|
||||
class _einsum(torch.autograd.Function):
|
||||
|
||||
|
||||
#############################
|
||||
@@ -495,7 +495,7 @@ __global__ void {name}(
|
||||
dims.update(dims_b)
|
||||
dims.update(dims_c)
|
||||
# look-up tables
|
||||
TK = 16 if dtype == triton.fw.torch.float16 else 8
|
||||
TK = 16 if dtype == torch.float16 else 8
|
||||
arrays = [(x, arrays[x]) for x in subscripted]
|
||||
delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
|
||||
delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
|
||||
@@ -580,17 +580,15 @@ __global__ void {name}(
|
||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
|
||||
if mask:
|
||||
self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
|
||||
# information on compute
|
||||
self.dtype = dtype
|
||||
self.flops = 2 * B * M * N * K
|
||||
self.sym_a = sym_a
|
||||
self.sym_b = sym_b
|
||||
self.sym_c = sym_c
|
||||
# save equivalent mat-mul dimensions
|
||||
# save information on the operation
|
||||
self.expr_a = expr_a
|
||||
self.expr_b = expr_b
|
||||
self.expr_c = expr_c
|
||||
self.matmul_B = B
|
||||
self.matmul_M = M
|
||||
self.matmul_N = N
|
||||
self.matmul_K = K
|
||||
self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])
|
||||
|
||||
def run(self, a, b, c, bench):
|
||||
self.args[self.pos_a] = a
|
||||
@@ -606,79 +604,63 @@ __global__ void {name}(
|
||||
############################
|
||||
|
||||
instance_cache = dict()
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, einsum, a, b, **kwargs):
|
||||
bench = kwargs['bench'] if 'bench' in kwargs else False
|
||||
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
|
||||
mask = kwargs['mask'] if 'mask' in kwargs else None
|
||||
output = kwargs['output'] if 'output' in kwargs else None
|
||||
def forward(ctx, expr, a, b, output, mask=None, arrays=dict(), bench=False):
|
||||
# compile einsum instance
|
||||
cache = _einsum.instance_cache
|
||||
key = (einsum, a.dtype,
|
||||
key = (expr, a.dtype,
|
||||
a.stride(), b.stride(), output.stride(),
|
||||
a.shape, b.shape, output.shape, mask)
|
||||
if key not in cache:
|
||||
cache[key] = _einsum.instance(einsum, a.dtype,
|
||||
cache[key] = _einsum.instance(expr, a.dtype,
|
||||
a.stride(), b.stride(), output.stride(),
|
||||
a.shape, b.shape, arrays,
|
||||
mask, output.shape)
|
||||
instance = cache[key]
|
||||
speed = instance.run(a, b, output, bench)
|
||||
# run and mark as dirty output modified in-place
|
||||
perf = instance.run(a, b, output, bench)
|
||||
ctx.mark_dirty(output)
|
||||
# save information in context
|
||||
ctx.flops = instance.flops
|
||||
ctx.sym_a = instance.sym_a
|
||||
ctx.sym_b = instance.sym_b
|
||||
ctx.sym_c = instance.sym_c
|
||||
ctx.is_extended = instance.is_extended
|
||||
ctx.expr_a = instance.expr_a
|
||||
ctx.expr_b = instance.expr_b
|
||||
ctx.expr_c = instance.expr_c
|
||||
ctx.matmul_B = instance.matmul_B
|
||||
ctx.matmul_M = instance.matmul_M
|
||||
ctx.matmul_N = instance.matmul_N
|
||||
ctx.matmul_K = instance.matmul_K
|
||||
ctx.bench = bench
|
||||
ctx.forward_ms = speed
|
||||
ctx.mask = mask
|
||||
ctx.perf = perf
|
||||
ctx.save_for_backward(a, b)
|
||||
return output
|
||||
|
||||
|
||||
############################
|
||||
## Backward
|
||||
############################
|
||||
|
||||
@staticmethod
|
||||
def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
|
||||
for i, expr in enumerate(sym_x):
|
||||
if expr.is_symbol:
|
||||
continue
|
||||
sc = [x for x in expr.free_symbols if x in sym_c][0]
|
||||
sx = sp.symbols(f'{prefix}{i}')
|
||||
renamed[expr] = sx
|
||||
inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
|
||||
|
||||
@staticmethod
|
||||
def sym_to_expr(sym):
|
||||
res = [f'({x})' for x in sym]
|
||||
res = ''.join(res)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy):
|
||||
if ctx.is_extended:
|
||||
raise NotImplementedError('Automatic differentiation for extended einsum not yet implemented;'
|
||||
' print write your own autograd function')
|
||||
a, b = ctx.saved_tensors
|
||||
sym_a = ctx.sym_a
|
||||
sym_b = ctx.sym_b
|
||||
sym_c = ctx.sym_c
|
||||
inverse = dict()
|
||||
renamed = dict()
|
||||
_einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
|
||||
_einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
|
||||
sym_a = [renamed[x] if x in renamed else x for x in sym_a]
|
||||
sym_b = [renamed[x] if x in renamed else x for x in sym_b]
|
||||
sym_c = [inverse[x] if x in inverse else x for x in sym_c]
|
||||
expr_a = _einsum.sym_to_expr(sym_a)
|
||||
expr_b = _einsum.sym_to_expr(sym_b)
|
||||
expr_c = _einsum.sym_to_expr(sym_c)
|
||||
da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask)
|
||||
db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask)
|
||||
return None, da, db, None, None
|
||||
expr_a = ctx.expr_a
|
||||
expr_b = ctx.expr_b
|
||||
expr_c = ctx.expr_c
|
||||
# gradient of first argument
|
||||
da = None
|
||||
if ctx.needs_input_grad[1]:
|
||||
da = torch.empty_like(a)
|
||||
einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, da)
|
||||
# gradient of second argument
|
||||
db = None
|
||||
if ctx.needs_input_grad[2]:
|
||||
db = torch.empty_like(b)
|
||||
einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
|
||||
return None, da, db, None, None, None, None
|
||||
|
||||
|
||||
einsum = _einsum.apply
|
||||
def einsum(expr, a, b, output,
|
||||
mask=None, arrays=dict(),
|
||||
bench=False):
|
||||
return _einsum.apply(expr, a, b, output, mask, arrays, bench)
|
Reference in New Issue
Block a user