* add score calculator wrapper, forward property lookups on vecenv wrapper, misc cleanup * tests * pylint
56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
from . import VecEnvWrapper
|
|
from baselines.bench.monitor import ResultsWriter
|
|
import numpy as np
|
|
import time
|
|
from collections import deque
|
|
|
|
class VecMonitor(VecEnvWrapper):
    """Vectorized-env wrapper that tracks per-environment episode statistics.

    For every sub-environment it accumulates the reward and the number of
    steps of the episode in progress. When a sub-environment reports
    ``done``, an ``'episode'`` entry is attached to a copy of that
    environment's info dict with the keys:

    * ``'r'`` -- accumulated episode reward,
    * ``'l'`` -- episode length in steps,
    * ``'t'`` -- seconds elapsed since this wrapper was constructed,
      rounded to 6 decimals,
    * one entry per name in ``info_keywords``, copied from the info dict.

    Args:
        venv: the vectorized environment to wrap.
        filename: if truthy, finished-episode records are also written through
            a ``ResultsWriter`` created with this filename; otherwise nothing
            is written.
        keep_buf: if non-zero, the last ``keep_buf`` episode returns and
            lengths are kept in ``epret_buf`` / ``eplen_buf`` (bounded deques).
            These attributes are only created when ``keep_buf`` is truthy.
        info_keywords: keys that are copied from the step's info dict into the
            episode record. A missing key raises ``KeyError`` at episode end.
    """

    def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
        VecEnvWrapper.__init__(self, venv)
        # Per-env running return / length; allocated by reset().
        self.eprets = None
        self.eplens = None
        # Total number of episodes finished across all sub-environments.
        self.epcount = 0
        self.tstart = time.time()
        if filename:
            self.results_writer = ResultsWriter(
                filename,
                header={'t_start': self.tstart},
                extra_keys=info_keywords,
            )
        else:
            self.results_writer = None
        self.info_keywords = info_keywords
        self.keep_buf = keep_buf
        if self.keep_buf:
            self.epret_buf = deque([], maxlen=keep_buf)
            self.eplen_buf = deque([], maxlen=keep_buf)

    def reset(self):
        """Reset the wrapped env and zero the per-env return/length counters.

        Returns:
            Whatever ``self.venv.reset()`` returns.
        """
        obs = self.venv.reset()
        # One float32 return and one int32 step counter per sub-environment.
        # NOTE(review): num_envs is not set in this class; presumably it is
        # provided by VecEnvWrapper -- confirm.
        self.eprets = np.zeros(self.num_envs, 'f')
        self.eplens = np.zeros(self.num_envs, 'i')
        return obs

    def step_wait(self):
        """Collect the step result and annotate infos of finished episodes.

        ``reset()`` must have been called first, since it allocates the
        counters that are updated here.

        Returns:
            ``(obs, rews, dones, newinfos)`` where ``newinfos`` is a list copy
            of the infos from the wrapped env; entries whose episode ended are
            replaced by a shallow copy carrying the extra ``'episode'`` record.
        """
        obs, rews, dones, infos = self.venv.step_wait()
        self.eprets += rews
        self.eplens += 1

        # Always build a list: infos may be an immutable sequence (e.g. a
        # tuple), in which case a plain slice would stay immutable and the
        # item assignment below would raise TypeError.
        newinfos = list(infos)
        for i, done in enumerate(dones):
            if not done:
                continue
            # Copy so the dict owned by the wrapped env is not mutated.
            info = infos[i].copy()
            # Integer indexing yields scalars, so these values survive the
            # counter reset further down.
            ret = self.eprets[i]
            eplen = self.eplens[i]
            epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
            for k in self.info_keywords:
                epinfo[k] = info[k]
            info['episode'] = epinfo
            if self.keep_buf:
                self.epret_buf.append(ret)
                self.eplen_buf.append(eplen)
            self.epcount += 1
            # Start accumulating the next episode of this sub-environment.
            self.eprets[i] = 0
            self.eplens[i] = 0
            if self.results_writer:
                self.results_writer.write_row(epinfo)
            newinfos[i] = info
        return obs, rews, dones, newinfos
|