[PYTHON][OPS][EINSUM] Now throwing error for automatic differentiation of extended einsum
committed by Philippe Tillet
parent 4e50ef4076
commit 30f77e9ec5
@@ -13,7 +13,7 @@ import re
 from sympy.printing.ccode import C89CodePrinter


-class _einsum(triton.function):
+class _einsum(torch.autograd.Function):


     #############################
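The substantive change in this hunk is that _einsum now subclasses torch.autograd.Function directly instead of the triton.function wrapper. For context, a minimal self-contained sketch of the contract such a subclass must satisfy; the _mul class below is illustrative only and not part of this commit:

import torch

class _mul(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, w):
        # stash the tensors backward() will need on the context object
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        # return one gradient (or None) per forward() argument, in order
        dx = dy * w if ctx.needs_input_grad[0] else None
        dw = dy * x if ctx.needs_input_grad[1] else None
        return dx, dw

x = torch.ones(3, requires_grad=True)
w = torch.full((3,), 2.0, requires_grad=True)
_mul.apply(x, w).sum().backward()   # x.grad is dy*w, w.grad is dy*x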
@@ -495,7 +495,7 @@ __global__ void {name}(
             dims.update(dims_b)
             dims.update(dims_c)
             # look-up tables
-            TK = 16 if dtype == triton.fw.torch.float16 else 8
+            TK = 16 if dtype == torch.float16 else 8
             arrays = [(x, arrays[x]) for x in subscripted]
             delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays)
             delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays)
@@ -580,17 +580,15 @@ __global__ void {name}(
             self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
             if mask:
                 self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
-            # information on compute
-            self.dtype = dtype
-            self.flops = 2 * B * M * N * K
-            self.sym_a = sym_a
-            self.sym_b = sym_b
-            self.sym_c = sym_c
-            # save equivalent mat-mul dimensions
+            # save information on the operation
+            self.expr_a = expr_a
+            self.expr_b = expr_b
+            self.expr_c = expr_c
             self.matmul_B = B
             self.matmul_M = M
             self.matmul_N = N
             self.matmul_K = K
+            self.is_extended = any([not x.is_symbol for x in sym_a + sym_b])

         def run(self, a, b, c, bench):
             self.args[self.pos_a] = a
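The new is_extended flag marks an instance as "extended" when any subscript of either input is a compound sympy expression (an index expression) rather than a bare symbol. A small sketch of that test, assuming the subscripts have already been parsed into sympy objects the way sym_a/sym_b are in this file; the symbol names below are illustrative:

import sympy as sp

h, r = sp.symbols('h r')
sym_a = [h, r]          # plain subscripts: bare symbols
sym_b = [2*h + r, r]    # extended subscripts: compound index expressions

is_extended = any([not x.is_symbol for x in sym_a + sym_b])
print(is_extended)      # True, because 2*h + r is not a bare symbol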
@@ -606,79 +604,63 @@ __global__ void {name}(
     ############################

     instance_cache = dict()

     @staticmethod
-    def forward(ctx, einsum, a, b, **kwargs):
-        bench = kwargs['bench'] if 'bench' in kwargs else False
-        arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
-        mask = kwargs['mask'] if 'mask' in kwargs else None
-        output = kwargs['output'] if 'output' in kwargs else None
+    def forward(ctx, expr, a, b, output, mask=None, arrays=dict(), bench=False):
         # compile einsum instance
         cache = _einsum.instance_cache
-        key = (einsum, a.dtype,
+        key = (expr, a.dtype,
                a.stride(), b.stride(), output.stride(),
                a.shape, b.shape, output.shape, mask)
         if key not in cache:
-            cache[key] = _einsum.instance(einsum, a.dtype,
+            cache[key] = _einsum.instance(expr, a.dtype,
                                           a.stride(), b.stride(), output.stride(),
                                           a.shape, b.shape, arrays,
                                           mask, output.shape)
         instance = cache[key]
-        speed = instance.run(a, b, output, bench)
+        # run, and mark the in-place-modified output as dirty
+        perf = instance.run(a, b, output, bench)
+        ctx.mark_dirty(output)
         # save information in context
-        ctx.flops = instance.flops
-        ctx.sym_a = instance.sym_a
-        ctx.sym_b = instance.sym_b
-        ctx.sym_c = instance.sym_c
+        ctx.is_extended = instance.is_extended
+        ctx.expr_a = instance.expr_a
+        ctx.expr_b = instance.expr_b
+        ctx.expr_c = instance.expr_c
         ctx.matmul_B = instance.matmul_B
         ctx.matmul_M = instance.matmul_M
         ctx.matmul_N = instance.matmul_N
         ctx.matmul_K = instance.matmul_K
-        ctx.bench = bench
-        ctx.forward_ms = speed
-        ctx.mask = mask
+        ctx.perf = perf
         ctx.save_for_backward(a, b)
         return output


     ############################
     ## Backward
     ############################

-    @staticmethod
-    def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
-        for i, expr in enumerate(sym_x):
-            if expr.is_symbol:
-                continue
-            sc = [x for x in expr.free_symbols if x in sym_c][0]
-            sx = sp.symbols(f'{prefix}{i}')
-            renamed[expr] = sx
-            inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
-
-    @staticmethod
-    def sym_to_expr(sym):
-        res = [f'({x})' for x in sym]
-        res = ''.join(res)
-        return res
-
     @staticmethod
     def backward(ctx, dy):
+        if ctx.is_extended:
+            raise NotImplementedError('Automatic differentiation for extended einsum not yet implemented;'
+                                      ' please write your own autograd function')
         a, b = ctx.saved_tensors
-        sym_a = ctx.sym_a
-        sym_b = ctx.sym_b
-        sym_c = ctx.sym_c
-        inverse = dict()
-        renamed = dict()
-        _einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
-        _einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
-        sym_a = [renamed[x] if x in renamed else x for x in sym_a]
-        sym_b = [renamed[x] if x in renamed else x for x in sym_b]
-        sym_c = [inverse[x] if x in inverse else x for x in sym_c]
-        expr_a = _einsum.sym_to_expr(sym_a)
-        expr_b = _einsum.sym_to_expr(sym_b)
-        expr_c = _einsum.sym_to_expr(sym_c)
-        da = einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, a.shape, mask=ctx.mask)
-        db = einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, b.shape, mask=ctx.mask)
-        return None, da, db, None, None
+        expr_a = ctx.expr_a
+        expr_b = ctx.expr_b
+        expr_c = ctx.expr_c
+        # gradient of first argument
+        da = None
+        if ctx.needs_input_grad[1]:
+            da = torch.empty_like(a)
+            einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, da)
+        # gradient of second argument
+        db = None
+        if ctx.needs_input_grad[2]:
+            db = torch.empty_like(b)
+            einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
+        return None, da, db, None, None, None, None


-einsum = _einsum.apply
+def einsum(expr, a, b, output,
+           mask=None, arrays=dict(),
+           bench=False):
+    return _einsum.apply(expr, a, b, output, mask, arrays, bench)
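For the non-extended path that remains differentiable, backward computes da and db with two further einsum calls using the standard einsum/mat-mul gradient identity; extended expressions now raise NotImplementedError instead. A quick check of that identity with plain torch.einsum, with illustrative shapes and labels:

import torch

a = torch.randn(2, 3, 4, dtype=torch.float64, requires_grad=True)
b = torch.randn(2, 4, 5, dtype=torch.float64, requires_grad=True)
c = torch.einsum('bmk,bkn->bmn', a, b)

dy = torch.ones_like(c)
c.backward(dy)

# da follows '{expr_c},{expr_b}->{expr_a}', db follows '{expr_a},{expr_c}->{expr_b}'
assert torch.allclose(a.grad, torch.einsum('bmn,bkn->bmk', dy, b))
assert torch.allclose(b.grad, torch.einsum('bmk,bmn->bkn', a, dy))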