Rl19 metalearning (#261)

* rl19 metalearning and dict obs

* master merge arch fix

* lint fixes

* view fixes

* load vars tweaks

* user config cleanup

* documentation and revisions

* pass train comm to rl19

* cleanup
This commit is contained in:
Karl Cobbe
2019-03-01 16:41:17 -08:00
committed by Jacob Hilton
parent d9702e7ccb
commit dadc2c2eb6
2 changed files with 3 additions and 5 deletions

View File

@@ -1,6 +1,5 @@
from .vec_env import VecEnvObservationWrapper from .vec_env import VecEnvObservationWrapper
class VecExtractDictObs(VecEnvObservationWrapper): class VecExtractDictObs(VecEnvObservationWrapper):
def __init__(self, venv, key): def __init__(self, venv, key):
self.key = key self.key = key
@@ -8,4 +7,4 @@ class VecExtractDictObs(VecEnvObservationWrapper):
observation_space=venv.observation_space.spaces[self.key]) observation_space=venv.observation_space.spaces[self.key])
def process(self, obs): def process(self, obs):
return obs[self.key] return obs[self.key]

View File

@@ -361,7 +361,7 @@ class Logger(object):
if isinstance(fmt, SeqWriter): if isinstance(fmt, SeqWriter):
fmt.writeseq(map(str, args)) fmt.writeseq(map(str, args))
def configure(dir=None, format_strs=None, comm=None): def configure(dir=None, format_strs=None, comm=None, log_suffix=''):
""" """
If comm is provided, average all numerical stats across that comm If comm is provided, average all numerical stats across that comm
""" """
@@ -373,7 +373,6 @@ def configure(dir=None, format_strs=None, comm=None):
assert isinstance(dir, str) assert isinstance(dir, str)
os.makedirs(dir, exist_ok=True) os.makedirs(dir, exist_ok=True)
log_suffix = ''
rank = 0 rank = 0
# check environment variables here instead of importing mpi4py # check environment variables here instead of importing mpi4py
# to avoid calling MPI_Init() when this module is imported # to avoid calling MPI_Init() when this module is imported
@@ -381,7 +380,7 @@ def configure(dir=None, format_strs=None, comm=None):
if varname in os.environ: if varname in os.environ:
rank = int(os.environ[varname]) rank = int(os.environ[varname])
if rank > 0: if rank > 0:
log_suffix = "-rank%03i" % rank log_suffix = log_suffix + "-rank%03i" % rank
if format_strs is None: if format_strs is None:
if rank == 0: if rank == 0: