[PYTHON][OP][EINSUM] simplified API

This commit is contained in:
Philippe Tillet
2020-02-19 23:42:22 -05:00
committed by Philippe Tillet
parent 26fd884d96
commit 4e50ef4076

View File

@@ -467,7 +467,7 @@ __global__ void {name}(
locks = None locks = None
kernel_cache = dict() kernel_cache = dict()
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask): def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c):
# parse symbols # parse symbols
expr_a, expr_bc = einsum.split(",") expr_a, expr_bc = einsum.split(",")
expr_b, expr_c = expr_bc.split("->") expr_b, expr_c = expr_bc.split("->")
@@ -608,26 +608,23 @@ __global__ void {name}(
instance_cache = dict() instance_cache = dict()
@staticmethod @staticmethod
def forward(ctx, einsum, a, b, shape_c, **kwargs): def forward(ctx, einsum, a, b, **kwargs):
bench = kwargs['bench'] if 'bench' in kwargs else False bench = kwargs['bench'] if 'bench' in kwargs else False
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict() arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
mask = kwargs['mask'] if 'mask' in kwargs else None mask = kwargs['mask'] if 'mask' in kwargs else None
# allocate output output = kwargs['output'] if 'output' in kwargs else None
dtype = a.dtype
c = triton.empty(shape_c, dtype=dtype)
# compile einsum instance # compile einsum instance
cache = _einsum.instance_cache cache = _einsum.instance_cache
key = (einsum, dtype, key = (einsum, a.dtype,
a.stride(), b.stride(), c.stride(), a.stride(), b.stride(), output.stride(),
a.shape, b.shape, c.shape, a.shape, b.shape, output.shape, mask)
mask)
if key not in cache: if key not in cache:
cache[key] = _einsum.instance(einsum, dtype, cache[key] = _einsum.instance(einsum, a.dtype,
a.stride(), b.stride(), c.stride(), a.stride(), b.stride(), output.stride(),
a.shape, b.shape, c.shape, arrays, a.shape, b.shape, arrays,
mask) mask, output.shape)
instance = cache[key] instance = cache[key]
speed = instance.run(a, b, c, bench) speed = instance.run(a, b, output, bench)
# save information in context # save information in context
ctx.flops = instance.flops ctx.flops = instance.flops
ctx.sym_a = instance.sym_a ctx.sym_a = instance.sym_a
@@ -641,7 +638,7 @@ __global__ void {name}(
ctx.forward_ms = speed ctx.forward_ms = speed
ctx.mask = mask ctx.mask = mask
ctx.save_for_backward(a, b) ctx.save_for_backward(a, b)
return c return output
############################ ############################
## Backward ## Backward