add score calculator wrapper, forward property lookups on vecenv wrap… (#300)
* add score calculator wrapper, forward property lookups on vecenv wrapper, misc cleanup * tests * pylint
This commit is contained in:
committed by
Jacob Hilton
parent
a08af5d07d
commit
ea20c8a034
@@ -145,8 +145,7 @@ class VecEnvWrapper(VecEnv):
|
|||||||
|
|
||||||
def __init__(self, venv, observation_space=None, action_space=None):
|
def __init__(self, venv, observation_space=None, action_space=None):
|
||||||
self.venv = venv
|
self.venv = venv
|
||||||
VecEnv.__init__(self,
|
super().__init__(num_envs=venv.num_envs,
|
||||||
num_envs=venv.num_envs,
|
|
||||||
observation_space=observation_space or venv.observation_space,
|
observation_space=observation_space or venv.observation_space,
|
||||||
action_space=action_space or venv.action_space)
|
action_space=action_space or venv.action_space)
|
||||||
|
|
||||||
@@ -170,6 +169,11 @@ class VecEnvWrapper(VecEnv):
|
|||||||
def get_images(self):
|
def get_images(self):
|
||||||
return self.venv.get_images()
|
return self.venv.get_images()
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name.startswith('_'):
|
||||||
|
raise AttributeError("attempted to get missing private attribute '{}'".format(name))
|
||||||
|
return getattr(self.venv, name)
|
||||||
|
|
||||||
class VecEnvObservationWrapper(VecEnvWrapper):
|
class VecEnvObservationWrapper(VecEnvWrapper):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(self, obs):
|
def process(self, obs):
|
||||||
|
@@ -5,16 +5,18 @@ import time
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
class VecMonitor(VecEnvWrapper):
|
class VecMonitor(VecEnvWrapper):
|
||||||
def __init__(self, venv, filename=None, keep_buf=0):
|
def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
|
||||||
VecEnvWrapper.__init__(self, venv)
|
VecEnvWrapper.__init__(self, venv)
|
||||||
self.eprets = None
|
self.eprets = None
|
||||||
self.eplens = None
|
self.eplens = None
|
||||||
self.epcount = 0
|
self.epcount = 0
|
||||||
self.tstart = time.time()
|
self.tstart = time.time()
|
||||||
if filename:
|
if filename:
|
||||||
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
|
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart},
|
||||||
|
extra_keys=info_keywords)
|
||||||
else:
|
else:
|
||||||
self.results_writer = None
|
self.results_writer = None
|
||||||
|
self.info_keywords = info_keywords
|
||||||
self.keep_buf = keep_buf
|
self.keep_buf = keep_buf
|
||||||
if self.keep_buf:
|
if self.keep_buf:
|
||||||
self.epret_buf = deque([], maxlen=keep_buf)
|
self.epret_buf = deque([], maxlen=keep_buf)
|
||||||
@@ -30,11 +32,16 @@ class VecMonitor(VecEnvWrapper):
|
|||||||
obs, rews, dones, infos = self.venv.step_wait()
|
obs, rews, dones, infos = self.venv.step_wait()
|
||||||
self.eprets += rews
|
self.eprets += rews
|
||||||
self.eplens += 1
|
self.eplens += 1
|
||||||
newinfos = []
|
|
||||||
for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)):
|
newinfos = infos[:]
|
||||||
info = info.copy()
|
for i in range(len(dones)):
|
||||||
if done:
|
if dones[i]:
|
||||||
|
info = infos[i].copy()
|
||||||
|
ret = self.eprets[i]
|
||||||
|
eplen = self.eplens[i]
|
||||||
epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
|
epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
|
||||||
|
for k in self.info_keywords:
|
||||||
|
epinfo[k] = info[k]
|
||||||
info['episode'] = epinfo
|
info['episode'] = epinfo
|
||||||
if self.keep_buf:
|
if self.keep_buf:
|
||||||
self.epret_buf.append(ret)
|
self.epret_buf.append(ret)
|
||||||
@@ -44,6 +51,5 @@ class VecMonitor(VecEnvWrapper):
|
|||||||
self.eplens[i] = 0
|
self.eplens[i] = 0
|
||||||
if self.results_writer:
|
if self.results_writer:
|
||||||
self.results_writer.write_row(epinfo)
|
self.results_writer.write_row(epinfo)
|
||||||
newinfos.append(info)
|
newinfos[i] = info
|
||||||
|
|
||||||
return obs, rews, dones, newinfos
|
return obs, rews, dones, newinfos
|
||||||
|
Reference in New Issue
Block a user