add score calculator wrapper, forward property lookups on vecenv wrapper, misc cleanup (#300)

* add score calculator wrapper, forward property lookups on vecenv wrapper, misc cleanup

* tests

* pylint
This commit is contained in:
Christopher Hesse
2019-03-27 14:36:28 -07:00
committed by Jacob Hilton
parent a08af5d07d
commit ea20c8a034
2 changed files with 20 additions and 10 deletions

View File

@@ -145,8 +145,7 @@ class VecEnvWrapper(VecEnv):
def __init__(self, venv, observation_space=None, action_space=None): def __init__(self, venv, observation_space=None, action_space=None):
self.venv = venv self.venv = venv
VecEnv.__init__(self, super().__init__(num_envs=venv.num_envs,
num_envs=venv.num_envs,
observation_space=observation_space or venv.observation_space, observation_space=observation_space or venv.observation_space,
action_space=action_space or venv.action_space) action_space=action_space or venv.action_space)
@@ -170,6 +169,11 @@ class VecEnvWrapper(VecEnv):
def get_images(self): def get_images(self):
return self.venv.get_images() return self.venv.get_images()
def __getattr__(self, name):
    """Forward lookups of missing public attributes to the wrapped venv.

    Python calls this hook only after normal attribute lookup on the wrapper
    has failed, so attributes defined on the wrapper itself take precedence.

    Args:
        name: name of the attribute that was not found on the wrapper.

    Returns:
        The value of ``getattr(self.venv, name)``.

    Raises:
        AttributeError: if ``name`` starts with an underscore (private
            attributes are never forwarded), or if ``name`` is ``'venv'``,
            meaning the wrapped venv has not been set on this wrapper yet.
    """
    if name.startswith('_'):
        raise AttributeError("attempted to get missing private attribute '{}'".format(name))
    if name == 'venv':
        # Landing here means self.venv is not set on the instance. Evaluating
        # self.venv below would re-enter this method without end, so fail
        # with a clear AttributeError instead of a RecursionError.
        raise AttributeError("attempted to get attribute 'venv' before it was set on the wrapper")
    return getattr(self.venv, name)
class VecEnvObservationWrapper(VecEnvWrapper): class VecEnvObservationWrapper(VecEnvWrapper):
@abstractmethod @abstractmethod
def process(self, obs): def process(self, obs):

View File

@@ -5,16 +5,18 @@ import time
from collections import deque from collections import deque
class VecMonitor(VecEnvWrapper): class VecMonitor(VecEnvWrapper):
def __init__(self, venv, filename=None, keep_buf=0): def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
VecEnvWrapper.__init__(self, venv) VecEnvWrapper.__init__(self, venv)
self.eprets = None self.eprets = None
self.eplens = None self.eplens = None
self.epcount = 0 self.epcount = 0
self.tstart = time.time() self.tstart = time.time()
if filename: if filename:
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart}) self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart},
extra_keys=info_keywords)
else: else:
self.results_writer = None self.results_writer = None
self.info_keywords = info_keywords
self.keep_buf = keep_buf self.keep_buf = keep_buf
if self.keep_buf: if self.keep_buf:
self.epret_buf = deque([], maxlen=keep_buf) self.epret_buf = deque([], maxlen=keep_buf)
@@ -30,11 +32,16 @@ class VecMonitor(VecEnvWrapper):
obs, rews, dones, infos = self.venv.step_wait() obs, rews, dones, infos = self.venv.step_wait()
self.eprets += rews self.eprets += rews
self.eplens += 1 self.eplens += 1
newinfos = []
for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)): newinfos = infos[:]
info = info.copy() for i in range(len(dones)):
if done: if dones[i]:
info = infos[i].copy()
ret = self.eprets[i]
eplen = self.eplens[i]
epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
for k in self.info_keywords:
epinfo[k] = info[k]
info['episode'] = epinfo info['episode'] = epinfo
if self.keep_buf: if self.keep_buf:
self.epret_buf.append(ret) self.epret_buf.append(ret)
@@ -44,6 +51,5 @@ class VecMonitor(VecEnvWrapper):
self.eplens[i] = 0 self.eplens[i] = 0
if self.results_writer: if self.results_writer:
self.results_writer.write_row(epinfo) self.results_writer.write_row(epinfo)
newinfos.append(info) newinfos[i] = info
return obs, rews, dones, newinfos return obs, rews, dones, newinfos