From ab02fae71d8fa3260d8efe3eae861a4ed01fbd5d Mon Sep 17 00:00:00 2001
From: Peter Zhokhov
Date: Wed, 30 Jan 2019 16:21:57 -0800
Subject: [PATCH 1/4] fixes related to new gym and new flake8

---
 baselines/common/distributions.py | 3 ++-
 baselines/her/ddpg.py             | 2 +-
 baselines/her/rollout.py          | 2 +-
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py
index 554a2f1..8966ee3 100644
--- a/baselines/common/distributions.py
+++ b/baselines/common/distributions.py
@@ -75,7 +75,8 @@ class CategoricalPdType(PdType):
 
 class MultiCategoricalPdType(PdType):
     def __init__(self, nvec):
-        self.ncats = nvec
+        self.ncats = nvec.astype('int32')
+        assert (self.ncats > 0).all()
     def pdclass(self):
         return MultiCategoricalPd
     def pdfromflat(self, flat):
diff --git a/baselines/her/ddpg.py b/baselines/her/ddpg.py
index 07317e5..988f14b 100644
--- a/baselines/her/ddpg.py
+++ b/baselines/her/ddpg.py
@@ -410,7 +410,7 @@ class DDPG(object):
         logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
         logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
 
-        if prefix is not '' and not prefix.endswith('/'):
+        if prefix != '' and not prefix.endswith('/'):
             return [(prefix + '/' + key, val) for key, val in logs]
         else:
             return logs
diff --git a/baselines/her/rollout.py b/baselines/her/rollout.py
index 4ffeee5..3235ab7 100644
--- a/baselines/her/rollout.py
+++ b/baselines/her/rollout.py
@@ -163,7 +163,7 @@ class RolloutWorker:
         logs += [('mean_Q', np.mean(self.Q_history))]
         logs += [('episode', self.n_episodes)]
 
-        if prefix is not '' and not prefix.endswith('/'):
+        if prefix != '' and not prefix.endswith('/'):
             return [(prefix + '/' + key, val) for key, val in logs]
         else:
             return logs

From 5b41c926c7a852df3f0928afdf2429f96a3965cb Mon Sep 17 00:00:00 2001
From: Rishav1
Date: Thu, 31 Jan 2019 19:23:38 +0100
Subject: [PATCH 2/4] fix #795: Making tf_util._Function consistent (#796)

* fix #795: Making tf_util._Function consistent

The fix involves using the placeholder name to crossreference passed
kwargs values, just like the tf_util.function expects. Also, the givens
are updated before the parameters to make it behave like it's supposed
to.
* test: Adding test for issue #795
---
 baselines/common/tests/test_tf_util.py |  2 ++
 baselines/common/tf_util.py            | 13 ++++++++-----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/baselines/common/tests/test_tf_util.py b/baselines/common/tests/test_tf_util.py
index daad9d0..929f654 100644
--- a/baselines/common/tests/test_tf_util.py
+++ b/baselines/common/tests/test_tf_util.py
@@ -18,7 +18,9 @@ def test_function():
             initialize()
 
             assert lin(2) == 6
+            assert lin(x=3) == 9
             assert lin(2, 2) == 10
+            assert lin(x=2, y=3) == 12
 
 
 def test_multikwargs():
diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py
index 717b7dc..a2d5df4 100644
--- a/baselines/common/tf_util.py
+++ b/baselines/common/tf_util.py
@@ -186,6 +186,7 @@ class _Function(object):
             if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
                 assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
         self.inputs = inputs
+        self.input_names = {inp.name.split("/")[-1].split(":")[0]: inp for inp in inputs}
         updates = updates or []
         self.update_group = tf.group(*updates)
         self.outputs_update = list(outputs) + [self.update_group]
@@ -197,15 +198,17 @@ class _Function(object):
         else:
             feed_dict[inpt] = adjust_shape(inpt, value)
 
-    def __call__(self, *args):
-        assert len(args) <= len(self.inputs), "Too many arguments provided"
+    def __call__(self, *args, **kwargs):
+        assert len(args) + len(kwargs) <= len(self.inputs), "Too many arguments provided"
         feed_dict = {}
-        # Update the args
-        for inpt, value in zip(self.inputs, args):
-            self._feed_input(feed_dict, inpt, value)
         # Update feed dict with givens.
         for inpt in self.givens:
             feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
+        # Update the args
+        for inpt, value in zip(self.inputs, args):
+            self._feed_input(feed_dict, inpt, value)
+        for inpt_name, value in kwargs.items():
+            self._feed_input(feed_dict, self.input_names[inpt_name], value)
         results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
         return results

From adc4388f6b95dc0e34e418b088d2a82e3c0096e6 Mon Sep 17 00:00:00 2001
From: Peter Zhokhov
Date: Wed, 27 Feb 2019 12:49:40 -0800
Subject: [PATCH 3/4] fixes to catch changes in gym

---
 baselines/common/tests/util.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/baselines/common/tests/util.py b/baselines/common/tests/util.py
index 51b9d0f..ce44d50 100644
--- a/baselines/common/tests/util.py
+++ b/baselines/common/tests/util.py
@@ -1,6 +1,5 @@
 import tensorflow as tf
 import numpy as np
-from gym.spaces import np_random
 from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
 
 N_TRIALS = 10000
@@ -8,8 +7,6 @@ N_EPISODES = 100
 
 def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
     np.random.seed(0)
-    np_random.seed(0)
-
     env = DummyVecEnv([env_fn])

From 675b100190f7cb198166abf4f2cba7da93d6d4e4 Mon Sep 17 00:00:00 2001
From: Peter Zhokhov
Date: Wed, 27 Feb 2019 14:22:24 -0800
Subject: [PATCH 4/4] raised the tolerance on the test_microbatches test

---
 baselines/ppo2/test_microbatches.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/baselines/ppo2/test_microbatches.py b/baselines/ppo2/test_microbatches.py
index 291c2d2..829e0a9 100644
--- a/baselines/ppo2/test_microbatches.py
+++ b/baselines/ppo2/test_microbatches.py
@@ -25,10 +25,11 @@ def test_microbatches():
     env_test = DummyVecEnv([env_fn])
     sess_test = make_session(make_default=True, graph=tf.Graph())
     learn_fn(env=env_test, model_fn=partial(MicrobatchedModel, microbatch_size=2))
+    # learn_fn(env=env_test)
     vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()}
 
     for v in vars_ref:
-        np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3)
+        np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=3e-3)
 
 if __name__ == '__main__':
     test_microbatches()
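
For reference, a minimal sketch of the keyword-argument call style that PATCH 2/4 adds to
tf_util.function, mirroring the assertions added in test_tf_util.py above. It assumes
TensorFlow 1.x with baselines importable, and the placeholder/givens setup (z = 3*x + 2*y
with givens={y: 0}) is assumed to match the existing test_function() fixture rather than
being shown in this diff; make_session and initialize are the baselines.common.tf_util
helpers that already appear in these patches.

    import tensorflow as tf
    from baselines.common.tf_util import function, initialize, make_session

    # Fresh default graph assumed, so the placeholders are literally named "x" and "y";
    # _Function keys **kwargs by placeholder name via self.input_names.
    x = tf.placeholder(tf.int32, (), name="x")
    y = tf.placeholder(tf.int32, (), name="y")
    z = 3 * x + 2 * y
    lin = function([x, y], z, givens={y: 0})

    sess = make_session(make_default=True)
    initialize()

    assert lin(2) == 6           # positional only; y falls back to its given (0)
    assert lin(x=3) == 9         # keyword argument matched to the placeholder named "x"
    assert lin(2, 2) == 10       # both inputs positional
    assert lin(x=2, y=3) == 12   # keyword arguments override the given for y

    sess.close()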