diff --git a/baselines/common/vec_env/vec_remove_dict_obs.py b/baselines/common/vec_env/vec_remove_dict_obs.py index 602b949..a6c4656 100644 --- a/baselines/common/vec_env/vec_remove_dict_obs.py +++ b/baselines/common/vec_env/vec_remove_dict_obs.py @@ -1,6 +1,5 @@ from .vec_env import VecEnvObservationWrapper - class VecExtractDictObs(VecEnvObservationWrapper): def __init__(self, venv, key): self.key = key @@ -8,4 +7,4 @@ class VecExtractDictObs(VecEnvObservationWrapper): observation_space=venv.observation_space.spaces[self.key]) def process(self, obs): - return obs[self.key] \ No newline at end of file + return obs[self.key] diff --git a/baselines/logger.py b/baselines/logger.py index 6c08ca0..a0e75ab 100644 --- a/baselines/logger.py +++ b/baselines/logger.py @@ -361,7 +361,7 @@ class Logger(object): if isinstance(fmt, SeqWriter): fmt.writeseq(map(str, args)) -def configure(dir=None, format_strs=None, comm=None): +def configure(dir=None, format_strs=None, comm=None, log_suffix=''): """ If comm is provided, average all numerical stats across that comm """ @@ -373,7 +373,6 @@ def configure(dir=None, format_strs=None, comm=None): assert isinstance(dir, str) os.makedirs(dir, exist_ok=True) - log_suffix = '' rank = 0 # check environment variables here instead of importing mpi4py # to avoid calling MPI_Init() when this module is imported @@ -381,7 +380,7 @@ def configure(dir=None, format_strs=None, comm=None): if varname in os.environ: rank = int(os.environ[varname]) if rank > 0: - log_suffix = "-rank%03i" % rank + log_suffix = log_suffix + "-rank%03i" % rank if format_strs is None: if rank == 0: