From ab02fae71d8fa3260d8efe3eae861a4ed01fbd5d Mon Sep 17 00:00:00 2001 From: Peter Zhokhov Date: Wed, 30 Jan 2019 16:21:57 -0800 Subject: [PATCH] fixes related to new gym and new flake8 --- baselines/common/distributions.py | 3 ++- baselines/her/ddpg.py | 2 +- baselines/her/rollout.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index 554a2f1..8966ee3 100644 --- a/baselines/common/distributions.py +++ b/baselines/common/distributions.py @@ -75,7 +75,8 @@ class CategoricalPdType(PdType): class MultiCategoricalPdType(PdType): def __init__(self, nvec): - self.ncats = nvec + self.ncats = nvec.astype('int32') + assert (self.ncats > 0).all() def pdclass(self): return MultiCategoricalPd def pdfromflat(self, flat): diff --git a/baselines/her/ddpg.py b/baselines/her/ddpg.py index 07317e5..988f14b 100644 --- a/baselines/her/ddpg.py +++ b/baselines/her/ddpg.py @@ -410,7 +410,7 @@ class DDPG(object): logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))] logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))] - if prefix is not '' and not prefix.endswith('/'): + if prefix != '' and not prefix.endswith('/'): return [(prefix + '/' + key, val) for key, val in logs] else: return logs diff --git a/baselines/her/rollout.py b/baselines/her/rollout.py index 4ffeee5..3235ab7 100644 --- a/baselines/her/rollout.py +++ b/baselines/her/rollout.py @@ -163,7 +163,7 @@ class RolloutWorker: logs += [('mean_Q', np.mean(self.Q_history))] logs += [('episode', self.n_episodes)] - if prefix is not '' and not prefix.endswith('/'): + if prefix != '' and not prefix.endswith('/'): return [(prefix + '/' + key, val) for key, val in logs] else: return logs