[PYTHON][OP][EINSUM] simplified API

This commit is contained in:
Philippe Tillet
2020-02-19 23:42:22 -05:00
committed by Philippe Tillet
parent 26fd884d96
commit 4e50ef4076

View File

@@ -467,7 +467,7 @@ __global__ void {name}(
locks = None
kernel_cache = dict()
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask):
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c):
# parse symbols
expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->")
@@ -608,26 +608,23 @@ __global__ void {name}(
instance_cache = dict()
@staticmethod
def forward(ctx, einsum, a, b, shape_c, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False
def forward(ctx, einsum, a, b, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
mask = kwargs['mask'] if 'mask' in kwargs else None
# allocate output
dtype = a.dtype
c = triton.empty(shape_c, dtype=dtype)
mask = kwargs['mask'] if 'mask' in kwargs else None
output = kwargs['output'] if 'output' in kwargs else None
# compile einsum instance
cache = _einsum.instance_cache
key = (einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape,
mask)
key = (einsum, a.dtype,
a.stride(), b.stride(), output.stride(),
a.shape, b.shape, output.shape, mask)
if key not in cache:
cache[key] = _einsum.instance(einsum, dtype,
a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape, arrays,
mask)
cache[key] = _einsum.instance(einsum, a.dtype,
a.stride(), b.stride(), output.stride(),
a.shape, b.shape, arrays,
mask, output.shape)
instance = cache[key]
speed = instance.run(a, b, c, bench)
speed = instance.run(a, b, output, bench)
# save information in context
ctx.flops = instance.flops
ctx.sym_a = instance.sym_a
@@ -641,7 +638,7 @@ __global__ void {name}(
ctx.forward_ms = speed
ctx.mask = mask
ctx.save_for_backward(a, b)
return c
return output
############################
## Backward