import time
from collections import deque

import numpy as np

import gym
class RecordEpisodeStatistics(gym.Wrapper):
    """Track cumulative rewards and episode lengths for the wrapped env.

    Works with both single and vectorized environments. When an episode
    finishes, its statistics are written into that step's ``info`` dict under
    the ``"episode"`` key as ``{"r": return, "l": length, "t": elapsed}`` and
    appended to the rolling ``return_queue`` / ``length_queue``.
    """

    def __init__(self, env, deque_size=100):
        """Wrap *env* and set up the statistics bookkeeping.

        Args:
            env: Environment to wrap (single or vectorized).
            deque_size: Maximum number of finished episodes retained in the
                rolling return/length queues.
        """
        super().__init__(env)
        # Vectorized envs expose num_envs / is_vector_env; default to a
        # single scalar environment when the attributes are absent.
        self.num_envs = getattr(env, "num_envs", 1)
        self.t0 = time.perf_counter()  # reference point for the "t" statistic
        self.episode_count = 0
        # Per-env accumulators; allocated lazily in reset().
        self.episode_returns = None
        self.episode_lengths = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        """Reset the environment and zero the per-env accumulators."""
        obs = super().reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return obs

    def step(self, action):
        """Step the environment and record stats for any finished episodes."""
        obs, rewards, dones, infos = super().step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1

        # Normalize to list form so scalar and vector envs share one code path.
        if self.is_vector_env:
            infos = list(infos)  # Convert infos to mutable type
        else:
            infos = [infos]
            dones = [dones]

        for idx, done in enumerate(dones):
            if not done:
                continue
            # Copy before mutating so the underlying env's dict is untouched.
            infos[idx] = infos[idx].copy()
            ep_return = self.episode_returns[idx]
            ep_length = self.episode_lengths[idx]
            infos[idx]["episode"] = {
                "r": ep_return,
                "l": ep_length,
                "t": round(time.perf_counter() - self.t0, 6),
            }
            self.return_queue.append(ep_return)
            self.length_queue.append(ep_length)
            self.episode_count += 1
            self.episode_returns[idx] = 0
            self.episode_lengths[idx] = 0

        # Restore the tuple form expected from vectorized envs.
        if self.is_vector_env:
            infos = tuple(infos)

        return (
            obs,
            rewards,
            dones if self.is_vector_env else dones[0],
            infos if self.is_vector_env else infos[0],
        )