fixes related to new gym and new flake8

This commit is contained in:
Peter Zhokhov
2019-01-30 16:21:57 -08:00
parent b55eda1dde
commit ab02fae71d
3 changed files with 4 additions and 3 deletions

View File

@@ -75,7 +75,8 @@ class CategoricalPdType(PdType):
class MultiCategoricalPdType(PdType):
def __init__(self, nvec):
self.ncats = nvec
self.ncats = nvec.astype('int32')
assert (self.ncats > 0).all()
def pdclass(self):
return MultiCategoricalPd
def pdfromflat(self, flat):

View File

@@ -410,7 +410,7 @@ class DDPG(object):
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
if prefix is not '' and not prefix.endswith('/'):
if prefix != '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs]
else:
return logs

View File

@@ -163,7 +163,7 @@ class RolloutWorker:
logs += [('mean_Q', np.mean(self.Q_history))]
logs += [('episode', self.n_episodes)]
if prefix is not '' and not prefix.endswith('/'):
if prefix != '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs]
else:
return logs