Rl19 metalearning (#261)

* rl19 metalearning and dict obs

* master merge arch fix

* lint fixes

* view fixes

* load vars tweaks

* user config cleanup

* documentation and revisions

* pass train comm to rl19

* cleanup
This commit is contained in:
Karl Cobbe
2019-03-01 16:41:17 -08:00
committed by Jacob Hilton
parent d9702e7ccb
commit dadc2c2eb6
2 changed files with 3 additions and 5 deletions

View File

@@ -1,6 +1,5 @@
from .vec_env import VecEnvObservationWrapper
class VecExtractDictObs(VecEnvObservationWrapper):
def __init__(self, venv, key):
self.key = key
@@ -8,4 +7,4 @@ class VecExtractDictObs(VecEnvObservationWrapper):
observation_space=venv.observation_space.spaces[self.key])
def process(self, obs):
return obs[self.key]
return obs[self.key]

View File

@@ -361,7 +361,7 @@ class Logger(object):
if isinstance(fmt, SeqWriter):
fmt.writeseq(map(str, args))
def configure(dir=None, format_strs=None, comm=None):
def configure(dir=None, format_strs=None, comm=None, log_suffix=''):
"""
If comm is provided, average all numerical stats across that comm
"""
@@ -373,7 +373,6 @@ def configure(dir=None, format_strs=None, comm=None):
assert isinstance(dir, str)
os.makedirs(dir, exist_ok=True)
log_suffix = ''
rank = 0
# check environment variables here instead of importing mpi4py
# to avoid calling MPI_Init() when this module is imported
@@ -381,7 +380,7 @@ def configure(dir=None, format_strs=None, comm=None):
if varname in os.environ:
rank = int(os.environ[varname])
if rank > 0:
log_suffix = "-rank%03i" % rank
log_suffix = log_suffix + "-rank%03i" % rank
if format_strs is None:
if rank == 0: