fixes related to new gym and new flake8

This commit is contained in:
Peter Zhokhov
2019-01-30 16:21:57 -08:00
parent b55eda1dde
commit ab02fae71d
3 changed files with 4 additions and 3 deletions

View File

@@ -75,7 +75,8 @@ class CategoricalPdType(PdType):
class MultiCategoricalPdType(PdType): class MultiCategoricalPdType(PdType):
def __init__(self, nvec): def __init__(self, nvec):
self.ncats = nvec self.ncats = nvec.astype('int32')
assert (self.ncats > 0).all()
def pdclass(self): def pdclass(self):
return MultiCategoricalPd return MultiCategoricalPd
def pdfromflat(self, flat): def pdfromflat(self, flat):

View File

@@ -410,7 +410,7 @@ class DDPG(object):
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))] logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))] logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
if prefix is not '' and not prefix.endswith('/'): if prefix != '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs] return [(prefix + '/' + key, val) for key, val in logs]
else: else:
return logs return logs

View File

@@ -163,7 +163,7 @@ class RolloutWorker:
logs += [('mean_Q', np.mean(self.Q_history))] logs += [('mean_Q', np.mean(self.Q_history))]
logs += [('episode', self.n_episodes)] logs += [('episode', self.n_episodes)]
if prefix is not '' and not prefix.endswith('/'): if prefix != '' and not prefix.endswith('/'):
return [(prefix + '/' + key, val) for key, val in logs] return [(prefix + '/' + key, val) for key, val in logs]
else: else:
return logs return logs