* fix #795: Making tf_util._Function consistent The fix involves using the placeholder name to crossreference passed kwargs values, just like the tf_util.function expects. Also, the givens are updated before the parameters to make it behave like it's supposed to. * test: Adding test for issue #795
This commit is contained in:
@@ -18,7 +18,9 @@ def test_function():
|
|||||||
initialize()
|
initialize()
|
||||||
|
|
||||||
assert lin(2) == 6
|
assert lin(2) == 6
|
||||||
|
assert lin(x=3) == 9
|
||||||
assert lin(2, 2) == 10
|
assert lin(2, 2) == 10
|
||||||
|
assert lin(x=2, y=3) == 12
|
||||||
|
|
||||||
|
|
||||||
def test_multikwargs():
|
def test_multikwargs():
|
||||||
|
@@ -186,6 +186,7 @@ class _Function(object):
|
|||||||
if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
|
if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
|
||||||
assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
|
assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
|
||||||
self.inputs = inputs
|
self.inputs = inputs
|
||||||
|
self.input_names = {inp.name.split("/")[-1].split(":")[0]: inp for inp in inputs}
|
||||||
updates = updates or []
|
updates = updates or []
|
||||||
self.update_group = tf.group(*updates)
|
self.update_group = tf.group(*updates)
|
||||||
self.outputs_update = list(outputs) + [self.update_group]
|
self.outputs_update = list(outputs) + [self.update_group]
|
||||||
@@ -197,15 +198,17 @@ class _Function(object):
|
|||||||
else:
|
else:
|
||||||
feed_dict[inpt] = adjust_shape(inpt, value)
|
feed_dict[inpt] = adjust_shape(inpt, value)
|
||||||
|
|
||||||
def __call__(self, *args):
|
def __call__(self, *args, **kwargs):
|
||||||
assert len(args) <= len(self.inputs), "Too many arguments provided"
|
assert len(args) + len(kwargs) <= len(self.inputs), "Too many arguments provided"
|
||||||
feed_dict = {}
|
feed_dict = {}
|
||||||
# Update the args
|
|
||||||
for inpt, value in zip(self.inputs, args):
|
|
||||||
self._feed_input(feed_dict, inpt, value)
|
|
||||||
# Update feed dict with givens.
|
# Update feed dict with givens.
|
||||||
for inpt in self.givens:
|
for inpt in self.givens:
|
||||||
feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
|
feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
|
||||||
|
# Update the args
|
||||||
|
for inpt, value in zip(self.inputs, args):
|
||||||
|
self._feed_input(feed_dict, inpt, value)
|
||||||
|
for inpt_name, value in kwargs.items():
|
||||||
|
self._feed_input(feed_dict, self.input_names[inpt_name], value)
|
||||||
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
|
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user