diff --git a/baselines/common/tests/test_tf_util.py b/baselines/common/tests/test_tf_util.py index daad9d0..929f654 100644 --- a/baselines/common/tests/test_tf_util.py +++ b/baselines/common/tests/test_tf_util.py @@ -18,7 +18,9 @@ def test_function(): initialize() assert lin(2) == 6 + assert lin(x=3) == 9 assert lin(2, 2) == 10 + assert lin(x=2, y=3) == 12 def test_multikwargs(): diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index 717b7dc..a2d5df4 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -186,6 +186,7 @@ class _Function(object): if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0): assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method" self.inputs = inputs + self.input_names = {inp.name.split("/")[-1].split(":")[0]: inp for inp in inputs} updates = updates or [] self.update_group = tf.group(*updates) self.outputs_update = list(outputs) + [self.update_group] @@ -197,15 +198,17 @@ class _Function(object): else: feed_dict[inpt] = adjust_shape(inpt, value) - def __call__(self, *args): - assert len(args) <= len(self.inputs), "Too many arguments provided" + def __call__(self, *args, **kwargs): + assert len(args) + len(kwargs) <= len(self.inputs), "Too many arguments provided" feed_dict = {} - # Update the args - for inpt, value in zip(self.inputs, args): - self._feed_input(feed_dict, inpt, value) # Update feed dict with givens. for inpt in self.givens: feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt])) + # Update the args + for inpt, value in zip(self.inputs, args): + self._feed_input(feed_dict, inpt, value) + for inpt_name, value in kwargs.items(): + self._feed_input(feed_dict, self.input_names[inpt_name], value) results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1] return results