[PYTHON][FUNCTION] Now using common grad output format for both
tensorflow and pytorch
This commit is contained in:
@@ -48,16 +48,29 @@ class function(metaclass = function_meta):
|
|||||||
# are handled properly
|
# are handled properly
|
||||||
mutex = fw.gen_resource_variable_ops.mutex_v2()
|
mutex = fw.gen_resource_variable_ops.mutex_v2()
|
||||||
lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
|
lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
|
||||||
with fw.tensorflow.python.ops.control_dependencies([lock]):
|
with fw.tensorflow.control_dependencies([lock]):
|
||||||
result = cls.forward(ctx, *args, **kwargs)
|
result = cls.forward(ctx, *args, **kwargs)
|
||||||
ctx_registry[result] = ctx
|
# Find a mapping between ::forward arguments and tensorflow op arguments
|
||||||
|
remap = dict()
|
||||||
|
for i, ix in enumerate(result.op.inputs):
|
||||||
|
for j, jx in enumerate(args):
|
||||||
|
if ix is jx:
|
||||||
|
remap[j] = i
|
||||||
# register backward
|
# register backward
|
||||||
|
ctx_registry[result] = ctx
|
||||||
name = result.op.op_def.name
|
name = result.op.op_def.name
|
||||||
if not cls.registered:
|
if not cls.registered:
|
||||||
@fw.tensorflow.RegisterGradient(name)
|
@fw.tensorflow.RegisterGradient(name)
|
||||||
def gradient(op, dy):
|
def gradient(op, dy):
|
||||||
with fw.tensorflow.control_dependencies([op]):
|
y = op.outputs[0]
|
||||||
return cls.backward(ctx_registry[op.outputs[0]], dy)
|
grad = cls.backward(ctx_registry[y], dy)
|
||||||
|
# Remap gradient in the right order
|
||||||
|
ret = [None] * len(op.inputs)
|
||||||
|
for i in range(len(grad)):
|
||||||
|
if i in remap:
|
||||||
|
ret[remap[i]] = grad[i]
|
||||||
|
# Return
|
||||||
|
return ret
|
||||||
cls.registered = True
|
cls.registered = True
|
||||||
# return result tensor
|
# return result tensor
|
||||||
return result
|
return result
|
||||||
|
@@ -226,6 +226,6 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
|
|||||||
da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
|
da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
|
||||||
db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
|
db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
|
||||||
|
|
||||||
return da, db, None, None, None, None, None, None, None, None, None, None
|
return None, da, db, None
|
||||||
|
|
||||||
einsum = _einsum.apply
|
einsum = _einsum.apply
|
Reference in New Issue
Block a user