[PYTHON][TENSORFLOW] Signature of function.forward() no longer has to match the signature of the kernel
Philippe Tillet
2019-10-30 20:29:23 -04:00
parent e0fe8d9058
commit 93a86d4fc6
3 changed files with 23 additions and 13 deletions


@@ -62,11 +62,16 @@ class function(metaclass = function_meta):
     # Find a mapping between ::forward arguments and tensorflow op arguments
     op = result[0].op
-    remap = dict()
+    remap_in = dict()
     for i, ix in enumerate(op.inputs):
       for j, jx in enumerate(args):
         if ix is jx:
-          remap[j] = i
+          remap_in[j] = i
+    remap_out = []
+    for i, ix in enumerate(result):
+      for j, jx in enumerate(op.outputs):
+        if ix is jx:
+          remap_out.append(j)
     # Register backward pass
     ctx_registry[op] = ctx
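
The mapping above matches by object identity, because the generated TensorFlow op may reorder (or drop) the tensors that ::forward received. A minimal self-contained sketch of the same matching, with plain Python objects standing in for TF tensors (all names here are illustrative, not Triton's API):

    class T: pass                    # stand-in for a TF tensor; only identity matters

    a, b, c = T(), T(), T()
    op_inputs = [b, a]               # order the generated TF op ended up with
    forward_args = [a, b, c]         # order of ::forward; c never became an op input

    remap_in = dict()
    for i, ix in enumerate(op_inputs):
        for j, jx in enumerate(forward_args):
            if ix is jx:
                remap_in[j] = i

    print(remap_in)                  # {1: 0, 0: 1}: forward arg j feeds op input remap_in[j]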
@@ -75,14 +80,16 @@ class function(metaclass = function_meta):
     @fw.tensorflow.RegisterGradient(name)
     def gradient(op, *dy):
       dy = dy if len(dy) > 1 else dy[0]
-      grad = cls.backward(ctx_registry[op], dy)
+      # Remap gradient inputs in the right order
+      grads = [dy[i] for i in remap_out]
+      # Execute gradient function
+      grad = cls.backward(ctx_registry[op], grads)
       grad = function.extract_tf_tensors(grad, 'backward')
       # Remap gradient in the right order
       ret = [None] * len(op.inputs)
       for i in range(len(grad)):
-        if i in remap:
-          ret[remap[i]] = grad[i]
+        if i in remap_in:
+          ret[remap_in[i]] = grad[i]
       # Return
       return ret
     cls.registered = True
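
The payoff of remap_out is in the gradient callback: TensorFlow hands gradient one upstream gradient per op output, but cls.backward should only see gradients for the tensors ::forward actually returned, in the order it returned them. A toy sketch under the same identity-matching assumption (plain objects, illustrative names):

    class T: pass                        # stand-in for a TF tensor

    y, mean, var = T(), T(), T()
    op_outputs = [y, mean, var]          # everything the generated op produces
    result = [y]                         # ::forward only returned y

    remap_out = []
    for ix in result:
        for j, jx in enumerate(op_outputs):
            if ix is jx:
                remap_out.append(j)

    dy = ['dy_y', 'dy_mean', 'dy_var']   # one upstream gradient per op output
    grads = [dy[i] for i in remap_out]
    print(grads)                         # ['dy_y']: only what backward consumes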


@@ -225,8 +225,10 @@ class kernel:
     bench_id = libtriton.make_scalar_id() if bench > 0 else -1
     # call framework function
     if fw.has_tensorflow():
+      empty = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
+      if len(empty) != len(self.outputs):
+        raise ValueError('Number of empty arguments does not match number of outputs provided')
       # operands
-      outputs = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
       operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]]
       # output data types
       kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
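
The new check makes the contract explicit: every output declared by the kernel must correspond to exactly one tf_empty_proxy among the call arguments. A self-contained sketch of that validation (the proxy class below is a simplified stand-in for triton.utils.tf_empty_proxy, not its real definition):

    class tf_empty_proxy:                # stand-in; assumed fields only
        def __init__(self, shape):
            self.shape = shape
            self.tensor = None

    args = [tf_empty_proxy((1024,)), 'x_input', 'grid']   # last arg dropped, as in args[:-1]
    declared_outputs = ['y']                              # what the kernel promises

    empty = [x for x in args[:-1] if isinstance(x, tf_empty_proxy)]
    if len(empty) != len(declared_outputs):
        raise ValueError('Number of empty arguments does not match number of outputs provided')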
@@ -235,10 +237,12 @@ class kernel:
         kwargs['T' + str(i)] = x.dtype
       # launch
       ret = self.fw_op(*operands, **kwargs)
-      assert len(ret) == len(outputs)
-      # record results
-      for i in range(len(outputs)):
-        outputs[i].tensor = ret[i]
+      # fill empty tensors with corresponding values
+      for j, y in enumerate(ret[0].op.op_def.output_arg):
+        for i, x in enumerate(ret[0].op.op_def.input_arg):
+          if y.name + '_shape' == x.name:
+            empty[i].tensor = ret[j]
+      # store timing information
       if bench > 0:
         bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
     elif fw.has_torch():
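
The pairing loop leans on a naming convention the generated op evidently follows: each output argument y has a companion input named y.name + '_shape' that carried its shape, which is how a proxy passed as a shape gets filled with the materialized tensor. A sketch of just the name matching, with namedtuples standing in for the op_def protobuf entries (illustrative values):

    from collections import namedtuple

    Arg = namedtuple('Arg', 'name')
    input_arg = [Arg('x'), Arg('y_shape')]   # stand-in for ret[0].op.op_def.input_arg
    output_arg = [Arg('y')]                  # stand-in for ret[0].op.op_def.output_arg

    pairs = [(i, j)
             for j, y in enumerate(output_arg)
             for i, x in enumerate(input_arg)
             if y.name + '_shape' == x.name]
    print(pairs)                             # [(1, 0)]: input 1 (y_shape) describes output 0 (y)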


@@ -109,8 +109,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
     return y
   @staticmethod
-  def backward(ctx, grads):
-    dy, dmean, dvar = grads
+  def backward(ctx, dy):
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
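
With the remapping in place, the example's backward no longer unpacks placeholder gradients for mean and var; it receives only the gradient of y. A toy skeleton of the relaxed contract (pure Python, illustrative names; fake_kernel is a hypothetical stand-in for the Triton batchnorm kernel, which also produces mean and var):

    class Ctx:                                # minimal stand-in for the autograd context
        def save_for_backward(self, *tensors):
            self.saved_tensors = tensors

    def fake_kernel(x):                       # hypothetical kernel: returns y, mean, var
        mean = sum(x) / len(x)
        var = sum((v - mean) ** 2 for v in x) / len(x)
        y = [(v - mean) / (var + 1e-5) ** 0.5 for v in x]
        return y, mean, var

    class batchnorm:
        @staticmethod
        def forward(ctx, x):
            y, mean, var = fake_kernel(x)
            ctx.save_for_backward(x, mean, var)
            return y                          # mean and var never reach the caller

        @staticmethod
        def backward(ctx, dy):                # one gradient now, not (dy, dmean, dvar)
            x, mean, var = ctx.saved_tensors
            return dy                         # real gradient math elided in this sketch

    ctx = Ctx()
    y = batchnorm.forward(ctx, [1.0, 2.0, 3.0])
    dx = batchnorm.backward(ctx, [0.1, 0.1, 0.1])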