diff --git a/gymnasium/envs/mujoco/mujoco_env.py b/gymnasium/envs/mujoco/mujoco_env.py index 68faf51ee..c3490bd1f 100644 --- a/gymnasium/envs/mujoco/mujoco_env.py +++ b/gymnasium/envs/mujoco/mujoco_env.py @@ -1,5 +1,5 @@ from os import path -from typing import Optional, Union +from typing import Dict, Optional, Union import numpy as np @@ -132,6 +132,9 @@ class BaseMujocoEnv(gym.Env): raise NotImplementedError # ----------------------------- + def _get_reset_info(self) -> Dict: + """Function that generates the `info` that is returned during a `reset()`.""" + return {} def reset( self, @@ -144,9 +147,11 @@ class BaseMujocoEnv(gym.Env): self._reset_simulation() ob = self.reset_model() + info = self._get_reset_info() + if self.render_mode == "human": self.render() - return ob, {} + return ob, info def set_state(self, qpos, qvel): """ diff --git a/tests/envs/mujoco/test_mujoco_custom_env.py b/tests/envs/mujoco/test_mujoco_custom_env.py index b198e45e1..b0cfd1bbb 100644 --- a/tests/envs/mujoco/test_mujoco_custom_env.py +++ b/tests/envs/mujoco/test_mujoco_custom_env.py @@ -80,6 +80,9 @@ class PointEnv(MujocoEnv, utils.EzPickle): return observation + def _get_reset_info(self): + return {"works": True} + CHECK_ENV_IGNORE_WARNINGS = [ f"\x1b[33mWARN: {message}\x1b[0m" @@ -116,3 +119,11 @@ def test_xml_file(): assert env.unwrapped.data.qpos.size == 9 # note can not test user home path (with '~') because github CI does not have a home folder + + +def test_reset_info(): + """Verify that the environment returns info at `reset()`""" + env = PointEnv() + + _, info = env.reset() + assert info["works"] is True