[PYTHON][OP][EINSUM] simplified API
This commit is contained in:
committed by
Philippe Tillet
parent
26fd884d96
commit
4e50ef4076
@@ -467,7 +467,7 @@ __global__ void {name}(
|
||||
locks = None
|
||||
kernel_cache = dict()
|
||||
|
||||
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask):
|
||||
def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c):
|
||||
# parse symbols
|
||||
expr_a, expr_bc = einsum.split(",")
|
||||
expr_b, expr_c = expr_bc.split("->")
|
||||
@@ -608,26 +608,23 @@ __global__ void {name}(
|
||||
instance_cache = dict()
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, einsum, a, b, shape_c, **kwargs):
|
||||
bench = kwargs['bench'] if 'bench' in kwargs else False
|
||||
def forward(ctx, einsum, a, b, **kwargs):
|
||||
bench = kwargs['bench'] if 'bench' in kwargs else False
|
||||
arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
|
||||
mask = kwargs['mask'] if 'mask' in kwargs else None
|
||||
# allocate output
|
||||
dtype = a.dtype
|
||||
c = triton.empty(shape_c, dtype=dtype)
|
||||
mask = kwargs['mask'] if 'mask' in kwargs else None
|
||||
output = kwargs['output'] if 'output' in kwargs else None
|
||||
# compile einsum instance
|
||||
cache = _einsum.instance_cache
|
||||
key = (einsum, dtype,
|
||||
a.stride(), b.stride(), c.stride(),
|
||||
a.shape, b.shape, c.shape,
|
||||
mask)
|
||||
key = (einsum, a.dtype,
|
||||
a.stride(), b.stride(), output.stride(),
|
||||
a.shape, b.shape, output.shape, mask)
|
||||
if key not in cache:
|
||||
cache[key] = _einsum.instance(einsum, dtype,
|
||||
a.stride(), b.stride(), c.stride(),
|
||||
a.shape, b.shape, c.shape, arrays,
|
||||
mask)
|
||||
cache[key] = _einsum.instance(einsum, a.dtype,
|
||||
a.stride(), b.stride(), output.stride(),
|
||||
a.shape, b.shape, arrays,
|
||||
mask, output.shape)
|
||||
instance = cache[key]
|
||||
speed = instance.run(a, b, c, bench)
|
||||
speed = instance.run(a, b, output, bench)
|
||||
# save information in context
|
||||
ctx.flops = instance.flops
|
||||
ctx.sym_a = instance.sym_a
|
||||
@@ -641,7 +638,7 @@ __global__ void {name}(
|
||||
ctx.forward_ms = speed
|
||||
ctx.mask = mask
|
||||
ctx.save_for_backward(a, b)
|
||||
return c
|
||||
return output
|
||||
|
||||
############################
|
||||
## Backward
|
||||
|
Reference in New Issue
Block a user