diff --git a/python/triton/function.py b/python/triton/function.py
index 213f9f16e..43c8aa7b9 100644
--- a/python/triton/function.py
+++ b/python/triton/function.py
@@ -62,11 +62,16 @@ class function(metaclass = function_meta):
     # Find a mapping between ::forward arguments and tensorflow op arguments
     op = result[0].op
-    remap = dict()
+    remap_in = dict()
     for i, ix in enumerate(op.inputs):
       for j, jx in enumerate(args):
         if ix is jx:
-          remap[j] = i
+          remap_in[j] = i
+    remap_out = []
+    for i, ix in enumerate(result):
+      for j, jx in enumerate(op.outputs):
+        if ix is jx:
+          remap_out.append(j)
     # Register backward pass
     ctx_registry[op] = ctx
@@ -75,14 +80,16 @@ class function(metaclass = function_meta):
       @fw.tensorflow.RegisterGradient(name)
       def gradient(op, *dy):
         dy = dy if len(dy) > 1 else dy[0]
-        grad = cls.backward(ctx_registry[op], dy)
+        # Remap gradient inputs in the right order
+        grads = [dy[i] for i in remap_out]
+        # Execute gradient function
+        grad = cls.backward(ctx_registry[op], grads)
         grad = function.extract_tf_tensors(grad, 'backward')
-        # Remap gradient in the right order
         ret = [None] * len(op.inputs)
         for i in range(len(grad)):
-          if i in remap:
-            ret[remap[i]] = grad[i]
+          if i in remap_in:
+            ret[remap_in[i]] = grad[i]
         # Return
         return ret
       cls.registered = True
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 59a797ec8..ea90e339b 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -225,8 +225,10 @@ class kernel:
     bench_id = libtriton.make_scalar_id() if bench > 0 else -1
     # call framework function
     if fw.has_tensorflow():
+      empty = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
+      if len(empty) != len(self.outputs):
+        raise ValueError('Number of empty arguments does not match number of outputs provided')
       # operands
-      outputs = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
       operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]]
       # output data types
       kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
@@ -235,10 +237,12 @@ class kernel:
         kwargs['T' + str(i)] = x.dtype
       # launch
       ret = self.fw_op(*operands, **kwargs)
-      assert len(ret) == len(outputs)
-      # record results
-      for i in range(len(outputs)):
-        outputs[i].tensor = ret[i]
+      # fill empty tensors with corresponding values
+      for j, y in enumerate(ret[0].op.op_def.output_arg):
+        for i, x in enumerate(ret[0].op.op_def.input_arg):
+          if y.name + '_shape' == x.name:
+            empty[i].tensor = ret[j]
+      # store timing information
       if bench > 0:
         bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
     elif fw.has_torch():
diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py
index 82a2b9fe2..9bb3450c0 100644
--- a/python/triton/ops/batchnorm.py
+++ b/python/triton/ops/batchnorm.py
@@ -109,8 +109,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
     return y
 
   @staticmethod
-  def backward(ctx, grads):
-    dy, dmean, dvar = grads
+  def backward(ctx, dy):
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
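
The function.py change above is easier to follow in isolation: TensorFlow hands a registered gradient function one upstream gradient per op output, in op-output order, whereas ::backward expects gradients in the order ::forward returned its results. remap_out reorders the incoming gradients accordingly, and remap_in scatters the input-gradients returned by ::backward back into op-input order (op inputs that never came from ::forward arguments get None). The standalone sketch below mirrors that logic with plain Python objects standing in for TF tensors; the helper name build_remaps and the toy example are illustrative only, not part of the patch.

def build_remaps(op_inputs, op_outputs, fwd_args, fwd_results):
  # remap_in: position of a ::forward argument -> position of the matching op input.
  # Comparisons use `is`, as in the patch, so matching is by tensor identity.
  remap_in = dict()
  for i, ix in enumerate(op_inputs):
    for j, jx in enumerate(fwd_args):
      if ix is jx:
        remap_in[j] = i
  # remap_out: for each tensor returned by ::forward, its index among the op outputs
  remap_out = []
  for ix in fwd_results:
    for j, jx in enumerate(op_outputs):
      if ix is jx:
        remap_out.append(j)
  return remap_in, remap_out

# Toy example: ::forward takes (x, k), where k is a constant that never becomes
# an op input, and returns the op outputs (c, a) in a different order.
x, k, a, b, c = (object() for _ in range(5))
remap_in, remap_out = build_remaps([x], [a, b, c], [x, k], [c, a])
assert remap_out == [2, 0]  # gradients passed to ::backward are [dy[2], dy[0]]
assert remap_in == {0: 0}   # gradient w.r.t. x fills slot 0 of op.inputs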