Patch notes (free text before the first "diff --git" header; patch tools skip it):
  * vec_env.py: VecEnvWrapper.__init__ calls super().__init__, and a new
    __getattr__ forwards lookups of missing non-underscore attributes to the
    wrapped venv; names starting with '_' raise AttributeError instead.
  * vec_monitor.py: VecMonitor accepts info_keywords; for each finished episode
    the listed keys are copied from that env's info dict into the episode
    record, and the keys are passed to ResultsWriter as extra_keys.
  * Review fix: newinfos is built with list(infos) rather than infos[:], so the
    per-env assignment newinfos[i] = info also works when infos is a tuple.
  * Line breaks and indentation were reconstructed from a whitespace-collapsed
    copy using the hunk line counts; confirm context whitespace against the
    target tree before applying. The vec_monitor.py "index" post-image hash
    predates the review fix.

diff --git a/baselines/common/vec_env/vec_env.py b/baselines/common/vec_env/vec_env.py
index 7aa7878..fc6098e 100644
--- a/baselines/common/vec_env/vec_env.py
+++ b/baselines/common/vec_env/vec_env.py
@@ -145,8 +145,7 @@ class VecEnvWrapper(VecEnv):
 
     def __init__(self, venv, observation_space=None, action_space=None):
         self.venv = venv
-        VecEnv.__init__(self,
-                        num_envs=venv.num_envs,
+        super().__init__(num_envs=venv.num_envs,
                         observation_space=observation_space or venv.observation_space,
                         action_space=action_space or venv.action_space)
 
@@ -170,6 +169,11 @@ class VecEnvWrapper(VecEnv):
     def get_images(self):
         return self.venv.get_images()
 
+    def __getattr__(self, name):
+        if name.startswith('_'):
+            raise AttributeError("attempted to get missing private attribute '{}'".format(name))
+        return getattr(self.venv, name)
+
 class VecEnvObservationWrapper(VecEnvWrapper):
     @abstractmethod
     def process(self, obs):
diff --git a/baselines/common/vec_env/vec_monitor.py b/baselines/common/vec_env/vec_monitor.py
index 6e67378..a7b1ce4 100644
--- a/baselines/common/vec_env/vec_monitor.py
+++ b/baselines/common/vec_env/vec_monitor.py
@@ -5,16 +5,18 @@ import time
 from collections import deque
 
 class VecMonitor(VecEnvWrapper):
-    def __init__(self, venv, filename=None, keep_buf=0):
+    def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
         VecEnvWrapper.__init__(self, venv)
         self.eprets = None
         self.eplens = None
         self.epcount = 0
         self.tstart = time.time()
         if filename:
-            self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
+            self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart},
+                                                extra_keys=info_keywords)
         else:
             self.results_writer = None
+        self.info_keywords = info_keywords
         self.keep_buf = keep_buf
         if self.keep_buf:
             self.epret_buf = deque([], maxlen=keep_buf)
@@ -30,11 +32,17 @@ class VecMonitor(VecEnvWrapper):
         obs, rews, dones, infos = self.venv.step_wait()
         self.eprets += rews
         self.eplens += 1
-        newinfos = []
-        for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)):
-            info = info.copy()
-            if done:
+
+        # list() gives a mutable shallow copy even when infos is a tuple;
+        # infos[:] would keep a tuple a tuple and break newinfos[i] = info below.
+        newinfos = list(infos)
+        for i in range(len(dones)):
+            if dones[i]:
+                info = infos[i].copy()
+                ret = self.eprets[i]
+                eplen = self.eplens[i]
                 epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
+                for k in self.info_keywords:
+                    epinfo[k] = info[k]
                 info['episode'] = epinfo
                 if self.keep_buf:
                     self.epret_buf.append(ret)
@@ -44,6 +52,5 @@ class VecMonitor(VecEnvWrapper):
                 self.eplens[i] = 0
                 if self.results_writer:
                     self.results_writer.write_row(epinfo)
-            newinfos.append(info)
-
+                newinfos[i] = info
         return obs, rews, dones, newinfos