fixes related to new gym and new flake8
This commit is contained in:
@@ -75,7 +75,8 @@ class CategoricalPdType(PdType):
|
|||||||
|
|
||||||
class MultiCategoricalPdType(PdType):
|
class MultiCategoricalPdType(PdType):
|
||||||
def __init__(self, nvec):
|
def __init__(self, nvec):
|
||||||
self.ncats = nvec
|
self.ncats = nvec.astype('int32')
|
||||||
|
assert (self.ncats > 0).all()
|
||||||
def pdclass(self):
|
def pdclass(self):
|
||||||
return MultiCategoricalPd
|
return MultiCategoricalPd
|
||||||
def pdfromflat(self, flat):
|
def pdfromflat(self, flat):
|
||||||
|
@@ -410,7 +410,7 @@ class DDPG(object):
|
|||||||
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
|
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
|
||||||
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
|
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
|
||||||
|
|
||||||
if prefix is not '' and not prefix.endswith('/'):
|
if prefix != '' and not prefix.endswith('/'):
|
||||||
return [(prefix + '/' + key, val) for key, val in logs]
|
return [(prefix + '/' + key, val) for key, val in logs]
|
||||||
else:
|
else:
|
||||||
return logs
|
return logs
|
||||||
|
@@ -163,7 +163,7 @@ class RolloutWorker:
|
|||||||
logs += [('mean_Q', np.mean(self.Q_history))]
|
logs += [('mean_Q', np.mean(self.Q_history))]
|
||||||
logs += [('episode', self.n_episodes)]
|
logs += [('episode', self.n_episodes)]
|
||||||
|
|
||||||
if prefix is not '' and not prefix.endswith('/'):
|
if prefix != '' and not prefix.endswith('/'):
|
||||||
return [(prefix + '/' + key, val) for key, val in logs]
|
return [(prefix + '/' + key, val) for key, val in logs]
|
||||||
else:
|
else:
|
||||||
return logs
|
return logs
|
||||||
|
Reference in New Issue
Block a user