[PYTHON][FUNCTION] Now using common grad output format for both tensorflow and pytorch
Philippe Tillet
2019-10-29 14:09:40 -04:00
parent 76651a065f
commit d9eacf937c
2 changed files with 18 additions and 5 deletions


@@ -48,16 +48,29 @@ class function(metaclass = function_meta):
         # are handled properly
         mutex = fw.gen_resource_variable_ops.mutex_v2()
         lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
-        with fw.tensorflow.python.ops.control_dependencies([lock]):
+        with fw.tensorflow.control_dependencies([lock]):
             result = cls.forward(ctx, *args, **kwargs)
-            ctx_registry[result] = ctx
+        # Find a mapping between ::forward arguments and tensorflow op arguments
+        remap = dict()
+        for i, ix in enumerate(result.op.inputs):
+            for j, jx in enumerate(args):
+                if ix is jx:
+                    remap[j] = i
         # register backward
+        ctx_registry[result] = ctx
         name = result.op.op_def.name
         if not cls.registered:
             @fw.tensorflow.RegisterGradient(name)
             def gradient(op, dy):
-                with fw.tensorflow.control_dependencies([op]):
-                    return cls.backward(ctx_registry[op.outputs[0]], dy)
+                y = op.outputs[0]
+                grad = cls.backward(ctx_registry[y], dy)
+                # Remap gradient in the right order
+                ret = [None] * len(op.inputs)
+                for i in range(len(grad)):
+                    if i in remap:
+                        ret[remap[i]] = grad[i]
+                # Return
+                return ret
             cls.registered = True
         # return result tensor
         return result
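The gradient function registered above must return one value per op input, in op-input order, while cls.backward produces gradients in the order of the original forward arguments; the remap dictionary bridges the two orderings. Below is a minimal, TensorFlow-free sketch of that reordering step; the helper names build_remap and reorder_grads are illustrative only and not part of this repository.

# Sketch: align gradients produced in forward-argument order into op-input
# order, leaving None for op inputs that received no gradient.

def build_remap(op_inputs, forward_args):
    # remap[j] = i  means forward argument j became op input i
    remap = {}
    for i, ix in enumerate(op_inputs):
        for j, jx in enumerate(forward_args):
            if ix is jx:
                remap[j] = i
    return remap

def reorder_grads(grads, remap, num_op_inputs):
    # One slot per op input; unmatched slots stay None.
    ret = [None] * num_op_inputs
    for j, g in enumerate(grads):
        if j in remap:
            ret[remap[j]] = g
    return ret

if __name__ == "__main__":
    a, b, flag = object(), object(), object()   # stand-ins for tensors/attributes
    op_inputs = [b, a]                          # the op sees inputs in a different order
    grads = ["da", "db", None]                  # backward() output, in forward-arg order
    remap = build_remap(op_inputs, [a, b, flag])
    print(reorder_grads(grads, remap, len(op_inputs)))  # ['db', 'da']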


@@ -226,6 +226,6 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
         da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
         db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
-        return da, db, None, None, None, None, None, None, None, None, None, None
+        return None, da, db, None
 einsum = _einsum.apply
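On the PyTorch side this "one gradient (or None) per forward argument" convention is the native torch.autograd.Function contract, which is what the shortened return None, da, db, None follows. A self-contained sketch of the same pattern; the _Scale op below is purely illustrative and not part of this repository.

import torch

class _Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # `scale` is a plain Python number, so it is not differentiable
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, dy):
        # One slot per forward argument: (x, scale) -> (dx, None)
        return dy * ctx.scale, None

scale = _Scale.apply
x = torch.randn(4, requires_grad=True)
scale(x, 2.0).sum().backward()
print(x.grad)  # every entry equals 2.0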