[PYTHON][FUNCTION] Now using common grad output format for both tensorflow and pytorch
@@ -48,16 +48,29 @@ class function(metaclass = function_meta):
         # are handled properly
         mutex = fw.gen_resource_variable_ops.mutex_v2()
         lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
-        with fw.tensorflow.python.ops.control_dependencies([lock]):
+        with fw.tensorflow.control_dependencies([lock]):
             result = cls.forward(ctx, *args, **kwargs)
-        ctx_registry[result] = ctx
+        # Find a mapping between ::forward arguments and tensorflow op arguments
+        remap = dict()
+        for i, ix in enumerate(result.op.inputs):
+            for j, jx in enumerate(args):
+                if ix is jx:
+                    remap[j] = i
         # register backward
+        ctx_registry[result] = ctx
         name = result.op.op_def.name
         if not cls.registered:
             @fw.tensorflow.RegisterGradient(name)
             def gradient(op, dy):
-                with fw.tensorflow.control_dependencies([op]):
-                    return cls.backward(ctx_registry[op.outputs[0]], dy)
+                y = op.outputs[0]
+                grad = cls.backward(ctx_registry[y], dy)
+                # Remap gradient in the right order
+                ret = [None] * len(op.inputs)
+                for i in range(len(grad)):
+                    if i in remap:
+                        ret[remap[i]] = grad[i]
+                # Return
+                return ret
             cls.registered = True
         # return result tensor
         return result

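With this change, ::backward returns one gradient per ::forward argument for both frameworks (the PyTorch convention), and the TensorFlow wrapper uses the `remap` dictionary to reorder those gradients into the registered op's input order. A minimal standalone sketch of that reordering, with made-up argument names and a made-up input order (not part of this commit):

# Suppose ::forward(ctx, x, w, flag) and ::backward returns (dx, dw, None),
# but the registered TensorFlow op received its inputs in the order (w, x).
grad = ['dx', 'dw', None]   # common grad output format: one entry per ::forward argument
remap = {0: 1, 1: 0}        # forward-argument index -> op-input index, built from result.op.inputs
ret = [None] * 2            # one slot per op input
for i in range(len(grad)):
    if i in remap:
        ret[remap[i]] = grad[i]
print(ret)                  # ['dw', 'dx'] -- gradients now follow the op-input order
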
@@ -226,6 +226,6 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
         da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
         db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
 
-        return da, db, None, None, None, None, None, None, None, None, None, None
+        return None, da, db, None
 
 einsum = _einsum.apply
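The new return statement follows the same common format: one entry per ::forward argument rather than one per TensorFlow op input. Assuming the forward signature is roughly forward(ctx, einsum_expr, a, b, bench) (it is not shown in this diff), the gradient identities computed above can be checked with numpy standing in for triton's einsum:

import numpy as np

# Matmul case 'ik,kj->ij': c = a @ b
a  = np.random.rand(3, 4)
b  = np.random.rand(4, 5)
dc = np.random.rand(3, 5)            # upstream gradient w.r.t. c
da = np.einsum('kj,ij->ik', b, dc)   # mirrors einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
db = np.einsum('ij,ik->kj', dc, a)   # mirrors einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
assert np.allclose(da, dc @ b.T) and np.allclose(db, a.T @ dc)
# One gradient per assumed ::forward argument (einsum_expr, a, b, bench):
grads = (None, da, db, None)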