[MuJoCo] add the abiltiy for MujocoEnv.reset() to return info (#540)

This commit is contained in:
Kallinteris Andreas
2023-06-07 12:58:36 +03:00
committed by GitHub
parent 4096c53b6b
commit deb50802fa
2 changed files with 18 additions and 2 deletions

View File

@@ -1,5 +1,5 @@
from os import path
from typing import Optional, Union
from typing import Dict, Optional, Union
import numpy as np
@@ -132,6 +132,9 @@ class BaseMujocoEnv(gym.Env):
raise NotImplementedError
# -----------------------------
def _get_reset_info(self) -> Dict:
"""Function that generates the `info` that is returned during a `reset()`."""
return {}
def reset(
self,
@@ -144,9 +147,11 @@ class BaseMujocoEnv(gym.Env):
self._reset_simulation()
ob = self.reset_model()
info = self._get_reset_info()
if self.render_mode == "human":
self.render()
return ob, {}
return ob, info
def set_state(self, qpos, qvel):
"""

View File

@@ -80,6 +80,9 @@ class PointEnv(MujocoEnv, utils.EzPickle):
return observation
def _get_reset_info(self):
return {"works": True}
CHECK_ENV_IGNORE_WARNINGS = [
f"\x1b[33mWARN: {message}\x1b[0m"
@@ -116,3 +119,11 @@ def test_xml_file():
assert env.unwrapped.data.qpos.size == 9
# note can not test user home path (with '~') because github CI does not have a home folder
def test_reset_info():
"""Verify that the environment returns info at `reset()`"""
env = PointEnv()
_, info = env.reset()
assert info["works"] is True