Rl19 metalearning (#261)
* rl19 metalearning and dict obs * master merge arch fix * lint fixes * view fixes * load vars tweaks * user config cleanup * documentation and revisions * pass train comm to rl19 * cleanup
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from .vec_env import VecEnvObservationWrapper
|
||||
|
||||
|
||||
class VecExtractDictObs(VecEnvObservationWrapper):
|
||||
def __init__(self, venv, key):
|
||||
self.key = key
|
||||
@@ -8,4 +7,4 @@ class VecExtractDictObs(VecEnvObservationWrapper):
|
||||
observation_space=venv.observation_space.spaces[self.key])
|
||||
|
||||
def process(self, obs):
|
||||
return obs[self.key]
|
||||
return obs[self.key]
|
||||
|
@@ -361,7 +361,7 @@ class Logger(object):
|
||||
if isinstance(fmt, SeqWriter):
|
||||
fmt.writeseq(map(str, args))
|
||||
|
||||
def configure(dir=None, format_strs=None, comm=None):
|
||||
def configure(dir=None, format_strs=None, comm=None, log_suffix=''):
|
||||
"""
|
||||
If comm is provided, average all numerical stats across that comm
|
||||
"""
|
||||
@@ -373,7 +373,6 @@ def configure(dir=None, format_strs=None, comm=None):
|
||||
assert isinstance(dir, str)
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
log_suffix = ''
|
||||
rank = 0
|
||||
# check environment variables here instead of importing mpi4py
|
||||
# to avoid calling MPI_Init() when this module is imported
|
||||
@@ -381,7 +380,7 @@ def configure(dir=None, format_strs=None, comm=None):
|
||||
if varname in os.environ:
|
||||
rank = int(os.environ[varname])
|
||||
if rank > 0:
|
||||
log_suffix = "-rank%03i" % rank
|
||||
log_suffix = log_suffix + "-rank%03i" % rank
|
||||
|
||||
if format_strs is None:
|
||||
if rank == 0:
|
||||
|
Reference in New Issue
Block a user