2019-11-01 22:27:39 +01:00
|
|
|
import time
|
|
|
|
from collections import deque
|
|
|
|
|
|
|
|
import gym
|
|
|
|
|
|
|
|
|
|
|
|
class RecordEpisodeStatistics(gym.Wrapper):
    """Wrapper that records the return and length of every finished episode.

    While an episode is running, the rewards returned by ``step`` are summed
    into ``episode_return`` and the number of steps is counted in
    ``episode_length``. When ``step`` reports ``done``, a summary is stored
    under ``info["episode"]`` with the keys:

    * ``"r"`` -- the accumulated episode return,
    * ``"l"`` -- the number of steps in the episode,
    * ``"t"`` -- seconds elapsed since this wrapper was constructed,
      rounded to 6 decimal places.

    The return and length are also appended to ``return_queue`` and
    ``length_queue``, which are deques bounded by ``deque_size``.
    """

    def __init__(self, env, deque_size=100):
        """Wrap ``env`` and initialise the statistics.

        Args:
            env: The environment to wrap; passed to the parent constructor.
            deque_size: ``maxlen`` of ``return_queue`` and ``length_queue``,
                i.e. how many of the most recent episodes are retained.
        """
        super(RecordEpisodeStatistics, self).__init__(env)
        # Reference timestamp for the "t" field of the episode summary.
        # TODO: use perf_counter when gym removes Python 2 support
        self.t0 = time.time()
        # Running totals for the episode currently in progress.
        self.episode_return, self.episode_length = 0.0, 0
        # Histories of finished episodes; oldest entries fall off when full.
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)

    def reset(self, **kwargs):
        """Reset the wrapped environment and clear the running totals.

        Args:
            **kwargs: Forwarded unchanged to the parent ``reset``.

        Returns:
            Whatever the parent ``reset`` returned.
        """
        first_observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
        self.episode_return, self.episode_length = 0.0, 0
        return first_observation

    def step(self, action):
        """Step the wrapped environment and update the episode statistics.

        Args:
            action: Forwarded unchanged to the parent ``step``.

        Returns:
            The tuple ``(observation, reward, done, info)``. When ``done`` is
            truthy, ``info`` additionally carries the ``"episode"`` summary
            described in the class docstring.
        """
        obs, reward, done, info = super(RecordEpisodeStatistics, self).step(action)
        self.episode_return += reward
        self.episode_length += 1

        if not done:
            return obs, reward, done, info

        # Episode finished: publish the summary, archive the totals, and
        # start counting afresh for the next episode.
        elapsed = round(time.time() - self.t0, 6)
        summary = {
            "r": self.episode_return,
            "l": self.episode_length,
            "t": elapsed,
        }
        info["episode"] = summary
        self.return_queue.append(self.episode_return)
        self.length_queue.append(self.episode_length)
        self.episode_return, self.episode_length = 0.0, 0
        return obs, reward, done, info
|