[PYTHON][TENSORFLOW] Signature of function.forward() does not have to match signature of kernel anymore
@@ -62,11 +62,16 @@ class function(metaclass = function_meta):
     # Find a mapping between ::forward arguments and tensorflow op arguments
     op = result[0].op
-    remap = dict()
+    remap_in = dict()
     for i, ix in enumerate(op.inputs):
       for j, jx in enumerate(args):
         if ix is jx:
-          remap[j] = i
+          remap_in[j] = i
+    remap_out = []
+    for i, ix in enumerate(result):
+      for j, jx in enumerate(op.outputs):
+        if ix is jx:
+          remap_out.append(j)
 
     # Register backward pass
     ctx_registry[op] = ctx
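The hunk above matches tensors by object identity rather than by position, which is what frees forward() from mirroring the kernel's argument order. A minimal runnable sketch of the same idea, with a hypothetical build_remaps helper and plain sentinel objects standing in for tf.Tensor values:

def build_remaps(op_inputs, args, op_outputs, result):
    # forward-argument index -> op-input index, matched by identity
    remap_in = dict()
    for i, ix in enumerate(op_inputs):
        for j, jx in enumerate(args):
            if ix is jx:
                remap_in[j] = i
    # op-output index for each tensor forward() actually returned
    remap_out = []
    for i, ix in enumerate(result):
        for j, jx in enumerate(op_outputs):
            if ix is jx:
                remap_out.append(j)
    return remap_in, remap_out

# Toy check: forward() took (a, b) but the op consumed them as (b, a),
# and forward() returned only the op's second output.
a, b, y0, y1 = object(), object(), object(), object()
print(build_remaps([b, a], (a, b), [y0, y1], [y1]))  # ({1: 0, 0: 1}, [1])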
@@ -75,14 +80,16 @@ class function(metaclass = function_meta):
       @fw.tensorflow.RegisterGradient(name)
       def gradient(op, *dy):
-        dy = dy if len(dy) > 1 else dy[0]
-        grad = cls.backward(ctx_registry[op], dy)
+        # Remap gradient inputs in the right order
+        grads = [dy[i] for i in remap_out]
+        # Execute gradient function
+        grad = cls.backward(ctx_registry[op], grads)
         grad = function.extract_tf_tensors(grad, 'backward')
 
         # Remap gradient in the right order
         ret = [None] * len(op.inputs)
         for i in range(len(grad)):
-          if i in remap:
-            ret[remap[i]] = grad[i]
+          if i in remap_in:
+            ret[remap_in[i]] = grad[i]
         # Return
         return ret
       cls.registered = True
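On the gradient side, TensorFlow hands the registered function one dy per op output, while the user's backward() only expects gradients for the outputs that forward() returned. A hedged sketch of that plumbing, with a hypothetical gradient_shim standing in for the real closure:

def gradient_shim(dy, backward, remap_in, remap_out, n_op_inputs):
    # keep only the dys for outputs forward() returned, in forward()'s order
    grads = [dy[i] for i in remap_out]
    # user-defined backward: one gradient per forward() argument
    grad = backward(grads)
    # scatter back into op-input order; inputs forward() never saw get None
    ret = [None] * n_op_inputs
    for i in range(len(grad)):
        if i in remap_in:
            ret[remap_in[i]] = grad[i]
    return ret

# Toy usage with the remaps from the previous sketch:
out = gradient_shim(dy=['dy0', 'dy1'],
                    backward=lambda grads: ['ga', 'gb'],
                    remap_in={1: 0, 0: 1}, remap_out=[1], n_op_inputs=2)
print(out)  # ['gb', 'ga']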
@@ -225,8 +225,10 @@ class kernel:
     bench_id = libtriton.make_scalar_id() if bench > 0 else -1
     # call framework function
     if fw.has_tensorflow():
+      empty = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
+      if len(empty) != len(self.outputs):
+        raise ValueError('Number of empty arguments does not match number of outputs provided')
       # operands
-      outputs = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
       operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args[:-1]]
       # output data types
       kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
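The launch path now expects one placeholder per kernel output among the arguments. A sketch of what a tf_empty_proxy-style placeholder could look like (the real triton.utils.tf_empty_proxy may differ; EmptyProxy and the argument layout here are illustrative):

class EmptyProxy:
    # carries only a shape until the launch fills .tensor in
    def __init__(self, shape):
        self.shape = shape
        self.tensor = None

x = object()                                 # stands in for a real input tf.Tensor
args = (x, EmptyProxy((128, 128)), 'grid')   # last argument is the launch grid
n_outputs = 1

empty = [a for a in args[:-1] if isinstance(a, EmptyProxy)]
if len(empty) != n_outputs:
    raise ValueError('Number of empty arguments does not match number of outputs provided')
# operands: placeholders contribute their shape, real tensors pass through
operands = [a.shape if isinstance(a, EmptyProxy) else a for a in args[:-1]]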
@@ -235,10 +237,12 @@ class kernel:
         kwargs['T' + str(i)] = x.dtype
       # launch
       ret = self.fw_op(*operands, **kwargs)
-      assert len(ret) == len(outputs)
-      # record results
-      for i in range(len(outputs)):
-        outputs[i].tensor = ret[i]
+      # fill empty tensors with corresponding values
+      for j, y in enumerate(ret[0].op.op_def.output_arg):
+        for i, x in enumerate(ret[0].op.op_def.input_arg):
+          if y.name + '_shape' == x.name:
+            empty[i].tensor = ret[j]
+      # store timing information
       if bench > 0:
         bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
     elif fw.has_torch():
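Filling the placeholders relies on a naming convention in the generated op definition: the input that carried an output's shape is matched to that output by a _shape suffix on its name. A small runnable sketch of the matching loop, with namedtuples standing in for ret[0].op.op_def.output_arg and input_arg:

from collections import namedtuple

Arg = namedtuple('Arg', 'name')
output_arg = [Arg('y')]                 # op outputs, as in op_def.output_arg
input_arg = [Arg('y_shape'), Arg('x')]  # op inputs, as in op_def.input_arg

for j, y in enumerate(output_arg):
    for i, x in enumerate(input_arg):
        if y.name + '_shape' == x.name:
            # in the real code: empty[i].tensor = ret[j]
            print('output %d fills the proxy passed at input slot %d' % (j, i))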
@@ -109,8 +109,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
     return y
 
   @staticmethod
-  def backward(ctx, grads):
-    dy, dmean, dvar = grads
+  def backward(ctx, dy):
     # retrieve info
     x, gamma, beta, mean, var = ctx.saved_tensors
     eps = ctx.eps
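What the relaxed contract buys is visible in this batchnorm example: the kernel still produces mean and variance, but forward() no longer has to expose them, so backward() receives a single gradient instead of unpacking three. A self-contained illustration of that contract (only the two signatures come from the diff; Ctx and the scalar-list math are illustrative scaffolding, not the Triton kernel):

class Ctx:
    pass

def forward(ctx, x, gamma, beta, eps):
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    y = [gamma * (v - mean) / (var + eps) ** 0.5 + beta for v in x]
    ctx.saved_tensors = (x, gamma, beta, mean, var)
    ctx.eps = eps
    return y                 # mean and var stay internal

def backward(ctx, dy):       # previously backward(ctx, grads) with
    x, gamma, beta, mean, var = ctx.saved_tensors   # dy, dmean, dvar = grads
    eps = ctx.eps
    dbeta = sum(dy)          # just one of the gradients, for illustration
    return dbeta

ctx = Ctx()
y = forward(ctx, [1.0, 2.0, 3.0], gamma=1.0, beta=0.0, eps=1e-5)
print(backward(ctx, [0.1, 0.2, 0.3]))   # ~0.6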