From 4e50ef4076958dbe85325b96ac2dc37a849e31db Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 19 Feb 2020 23:42:22 -0500 Subject: [PATCH] [PYTHON][OP][EINSUM] simplified API --- python/triton/ops/einsum.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 7e7b2189f..aa455b7fb 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -467,7 +467,7 @@ __global__ void {name}( locks = None kernel_cache = dict() - def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays, mask): + def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c): # parse symbols expr_a, expr_bc = einsum.split(",") expr_b, expr_c = expr_bc.split("->") @@ -608,26 +608,23 @@ __global__ void {name}( instance_cache = dict() @staticmethod - def forward(ctx, einsum, a, b, shape_c, **kwargs): - bench = kwargs['bench'] if 'bench' in kwargs else False + def forward(ctx, einsum, a, b, **kwargs): + bench = kwargs['bench'] if 'bench' in kwargs else False arrays = kwargs['arrays'] if 'arrays' in kwargs else dict() - mask = kwargs['mask'] if 'mask' in kwargs else None - # allocate output - dtype = a.dtype - c = triton.empty(shape_c, dtype=dtype) + mask = kwargs['mask'] if 'mask' in kwargs else None + output = kwargs['output'] if 'output' in kwargs else None # compile einsum instance cache = _einsum.instance_cache - key = (einsum, dtype, - a.stride(), b.stride(), c.stride(), - a.shape, b.shape, c.shape, - mask) + key = (einsum, a.dtype, + a.stride(), b.stride(), output.stride(), + a.shape, b.shape, output.shape, mask) if key not in cache: - cache[key] = _einsum.instance(einsum, dtype, - a.stride(), b.stride(), c.stride(), - a.shape, b.shape, c.shape, arrays, - mask) + cache[key] = _einsum.instance(einsum, a.dtype, + a.stride(), b.stride(), output.stride(), + a.shape, b.shape, arrays, + mask, output.shape) instance = cache[key] - speed = instance.run(a, b, c, bench) + speed = instance.run(a, b, output, bench) # save information in context ctx.flops = instance.flops ctx.sym_a = instance.sym_a @@ -641,7 +638,7 @@ __global__ void {name}( ctx.forward_ms = speed ctx.mask = mask ctx.save_for_backward(a, b) - return c + return output ############################ ## Backward