Files
Gymnasium/gym/wrappers/record_episode_statistics.py
Andrea PIERRÉ e913bc81b8 Improve pre-commit workflow (#2602)
* feat: add `isort` to `pre-commit`

* ci: skip `__init__.py` file for `isort`

* ci: make `isort` mandatory in lint pipeline

* docs: add a section on Git hooks

* ci: check isort diff

* fix: isort from master branch

* docs: add pre-commit badge

* ci: update black + bandit versions

* feat: add PR template

* refactor: PR template

* ci: remove bandit

* docs: add Black badge

* ci: try to remove all `|| true` statements

* ci: remove lint_python job

- Remove `lint_python` CI job
- Move `pyupgrade` job to `pre-commit` workflow

* fix: avoid messing with typing

* docs: add a note on running `pre-commit` manually

* ci: apply `pre-commit` to the whole codebase
2022-03-31 15:50:38 -04:00

61 lines
2.1 KiB
Python

import time
from collections import deque
from typing import Optional
import numpy as np
import gym
class RecordEpisodeStatistics(gym.Wrapper):
    """Wrapper that tracks the cumulative reward and length of each episode.

    Whenever a (sub-)environment reports ``done``, a copy of its info dict
    gains an ``"episode"`` entry ``{"r": return, "l": length, "t": elapsed}``,
    where ``elapsed`` is the time in seconds since this wrapper was
    constructed, rounded to 6 decimals. The finished return and length are
    also appended to ``return_queue`` and ``length_queue``, each of which
    keeps at most ``deque_size`` entries.

    The wrapped env is handled as a vector env when it exposes a truthy
    ``is_vector_env`` attribute; otherwise it is handled as a single env.
    """

    def __init__(self, env, deque_size=100):
        """Wrap ``env`` and set up empty statistics buffers.

        Args:
            env: The environment to wrap.
            deque_size: Maximum number of finished episodes remembered in
                ``return_queue`` and ``length_queue``.
        """
        super().__init__(env)
        # Fall back to a single, non-vectorized env when the attributes are
        # absent on the wrapped env.
        self.num_envs = getattr(env, "num_envs", 1)
        # Reference point for the "t" field of the episode info.
        self.t0 = time.perf_counter()
        self.episode_count = 0
        # The running accumulators are only allocated by reset().
        self.episode_returns = None
        self.episode_lengths = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        """Reset the wrapped env and zero the per-env running statistics.

        Args:
            **kwargs: Forwarded unchanged to the parent ``reset``.

        Returns:
            Whatever the parent ``reset`` returns.
        """
        result = super().reset(**kwargs)
        env_count = self.num_envs
        self.episode_returns = np.zeros(env_count, dtype=np.float32)
        self.episode_lengths = np.zeros(env_count, dtype=np.int32)
        return result

    def step(self, action):
        """Step the wrapped env and record statistics for finished episodes.

        Args:
            action: Forwarded unchanged to the parent ``step``.

        Returns:
            The tuple ``(observations, rewards, dones, infos)`` from the
            parent ``step``. For every finished episode the matching info is
            replaced by a copy carrying an extra ``"episode"`` entry. For a
            vector env ``infos`` is returned as a tuple; otherwise the single
            info dict is returned.
        """
        obs, rewards, dones, infos = super().step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1

        vectorized = self.is_vector_env
        # Work on lists so finished entries can be replaced by index; a
        # single env is handled as a batch of one.
        if vectorized:
            done_flags, info_list = dones, list(infos)
        else:
            done_flags, info_list = [dones], [infos]

        for idx, finished in enumerate(done_flags):
            if not finished:
                continue
            # Copy so the dict handed back by the wrapped env is not mutated.
            annotated = info_list[idx].copy()
            ep_return = self.episode_returns[idx]
            ep_length = self.episode_lengths[idx]
            annotated["episode"] = {
                "r": ep_return,
                "l": ep_length,
                "t": round(time.perf_counter() - self.t0, 6),
            }
            info_list[idx] = annotated
            self.return_queue.append(ep_return)
            self.length_queue.append(ep_length)
            self.episode_count += 1
            # Start accumulating afresh for the next episode in this slot.
            self.episode_returns[idx] = 0
            self.episode_lengths[idx] = 0

        if vectorized:
            return obs, rewards, dones, tuple(info_list)
        return obs, rewards, dones, info_list[0]