syntax and flake8
This commit is contained in:
@@ -403,7 +403,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
|
||||
logger.logkv(lossname, lossval)
|
||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||
logger.dumpkvs()
|
||||
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
|
||||
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
|
||||
checkdir = osp.join(logger.get_dir(), 'checkpoints')
|
||||
os.makedirs(checkdir, exist_ok=True)
|
||||
savepath = osp.join(checkdir, '%.5i'%update)
|
||||
|
@@ -365,7 +365,7 @@ def learn(*,
|
||||
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
|
||||
if MPI is not None:
|
||||
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
|
||||
else
|
||||
else:
|
||||
listoflrpairs = [lrlocal]
|
||||
|
||||
lens, rews = map(flatten_lists, zip(*listoflrpairs))
|
||||
|
Reference in New Issue
Block a user