Rl19 metalearning (#261)
* rl19 metalearning and dict obs * master merge arch fix * lint fixes * view fixes * load vars tweaks * user config cleanup * documentation and revisions * pass train comm to rl19 * cleanup
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
from .vec_env import VecEnvObservationWrapper
|
from .vec_env import VecEnvObservationWrapper
|
||||||
|
|
||||||
|
|
||||||
class VecExtractDictObs(VecEnvObservationWrapper):
|
class VecExtractDictObs(VecEnvObservationWrapper):
|
||||||
def __init__(self, venv, key):
|
def __init__(self, venv, key):
|
||||||
self.key = key
|
self.key = key
|
||||||
@@ -8,4 +7,4 @@ class VecExtractDictObs(VecEnvObservationWrapper):
|
|||||||
observation_space=venv.observation_space.spaces[self.key])
|
observation_space=venv.observation_space.spaces[self.key])
|
||||||
|
|
||||||
def process(self, obs):
|
def process(self, obs):
|
||||||
return obs[self.key]
|
return obs[self.key]
|
||||||
|
@@ -361,7 +361,7 @@ class Logger(object):
|
|||||||
if isinstance(fmt, SeqWriter):
|
if isinstance(fmt, SeqWriter):
|
||||||
fmt.writeseq(map(str, args))
|
fmt.writeseq(map(str, args))
|
||||||
|
|
||||||
def configure(dir=None, format_strs=None, comm=None):
|
def configure(dir=None, format_strs=None, comm=None, log_suffix=''):
|
||||||
"""
|
"""
|
||||||
If comm is provided, average all numerical stats across that comm
|
If comm is provided, average all numerical stats across that comm
|
||||||
"""
|
"""
|
||||||
@@ -373,7 +373,6 @@ def configure(dir=None, format_strs=None, comm=None):
|
|||||||
assert isinstance(dir, str)
|
assert isinstance(dir, str)
|
||||||
os.makedirs(dir, exist_ok=True)
|
os.makedirs(dir, exist_ok=True)
|
||||||
|
|
||||||
log_suffix = ''
|
|
||||||
rank = 0
|
rank = 0
|
||||||
# check environment variables here instead of importing mpi4py
|
# check environment variables here instead of importing mpi4py
|
||||||
# to avoid calling MPI_Init() when this module is imported
|
# to avoid calling MPI_Init() when this module is imported
|
||||||
@@ -381,7 +380,7 @@ def configure(dir=None, format_strs=None, comm=None):
|
|||||||
if varname in os.environ:
|
if varname in os.environ:
|
||||||
rank = int(os.environ[varname])
|
rank = int(os.environ[varname])
|
||||||
if rank > 0:
|
if rank > 0:
|
||||||
log_suffix = "-rank%03i" % rank
|
log_suffix = log_suffix + "-rank%03i" % rank
|
||||||
|
|
||||||
if format_strs is None:
|
if format_strs is None:
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
Reference in New Issue
Block a user