2022-08-30 19:41:59 +05:30
|
|
|
"""Contains methods for step compatibility, from old-to-new and new-to-old API."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2023-11-07 13:27:25 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-03-14 16:31:13 +01:00
|
|
|
from typing import SupportsFloat, Tuple, Union
|
2022-07-10 02:18:06 +05:30
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2022-09-08 10:10:07 +01:00
|
|
|
from gymnasium.core import ObsType
|
2022-07-10 02:18:06 +05:30
|
|
|
|
2022-12-04 22:24:02 +08:00
|
|
|
|
2022-08-30 19:41:59 +05:30
|
|
|
# Old ("done") step API return: (observation, reward, done, info).
# Each element may also be a numpy array when the step comes from a vector environment.
DoneStepType = Tuple[
    Union[ObsType, np.ndarray],  # observation(s)
    Union[SupportsFloat, np.ndarray],  # reward(s)
    Union[bool, np.ndarray],  # done flag(s)
    Union[dict, list],  # info: a dict for a single env; a list or dict for vector envs
]
|
|
|
|
|
2022-08-30 19:41:59 +05:30
|
|
|
# New step API return: (observation, reward, terminated, truncated, info).
# Each element may also be a numpy array when the step comes from a vector environment.
TerminatedTruncatedStepType = Tuple[
    Union[ObsType, np.ndarray],  # observation(s)
    Union[SupportsFloat, np.ndarray],  # reward(s)
    Union[bool, np.ndarray],  # terminated flag(s)
    Union[bool, np.ndarray],  # truncated flag(s)
    Union[dict, list],  # info: a dict for a single env; a list or dict for vector envs
]
|
|
|
|
|
|
|
|
|
2022-08-30 19:41:59 +05:30
|
|
|
def convert_to_terminated_truncated_step_api(
    step_returns: DoneStepType | TerminatedTruncatedStepType, is_vector_env=False
) -> TerminatedTruncatedStepType:
    """Function to transform step returns to new step API irrespective of input API.

    .. py:currentmodule:: gymnasium.Env

    A 5-tuple is assumed to already follow the new API and is returned unchanged.
    For a 4-tuple, the ``"TimeLimit.truncated"`` entry is popped from the info(s)
    (defaulting to ``False``) and used to split ``done`` into
    ``terminated = done and not truncated`` and ``truncated = done and truncated``.

    Args:
        step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
        is_vector_env (bool): Whether the ``step_returns`` are from a vector environment

    Returns:
        The ``(obs, rew, terminated, truncated, info)`` tuple. The
        ``"TimeLimit.truncated"`` key is removed in place from the input info(s).

    Raises:
        AssertionError: If ``step_returns`` has neither 4 nor 5 elements.
        TypeError: If ``is_vector_env`` is truthy and the infos are neither a list nor a dict.
    """
    if len(step_returns) == 5:
        return step_returns

    assert (
        len(step_returns) == 4
    ), f"Expected step returns of length 4 or 5, actual length: {len(step_returns)}"
    observations, rewards, dones, infos = step_returns

    # Cases to handle - info single env / info vector env (list) / info vector env (dict)
    # Test truthiness rather than `is False` so that falsy non-bool values
    # (e.g. `np.False_` or `0`) still select the single-environment path.
    if not is_vector_env:
        truncated = infos.pop("TimeLimit.truncated", False)
        return (
            observations,
            rewards,
            dones and not truncated,
            dones and truncated,
            infos,
        )

    if isinstance(infos, list):
        # One info dict per sub-environment.
        truncated = np.array(
            [info.pop("TimeLimit.truncated", False) for info in infos]
        )
    elif isinstance(infos, dict):
        # A single batched info dict; a missing key means no sub-env was truncated.
        num_envs = len(dones)
        truncated = infos.pop("TimeLimit.truncated", np.zeros(num_envs, dtype=bool))
    else:
        raise TypeError(
            f"Unexpected value of infos, as is_vector_env=True, expects `info` to be a list or dict, actual type: {type(infos)}"
        )

    return (
        observations,
        rewards,
        np.logical_and(dones, np.logical_not(truncated)),
        np.logical_and(dones, truncated),
        infos,
    )
|
2022-07-10 02:18:06 +05:30
|
|
|
|
|
|
|
|
2022-08-30 19:41:59 +05:30
|
|
|
def convert_to_done_step_api(
    step_returns: TerminatedTruncatedStepType | DoneStepType,
    is_vector_env: bool = False,
) -> DoneStepType:
    """Function to transform step returns to old step API irrespective of input API.

    .. py:currentmodule:: gymnasium.Env

    A 4-tuple is assumed to already follow the old API and is returned unchanged.
    For a 5-tuple, ``done = terminated or truncated``. Whenever an episode ended,
    ``"TimeLimit.truncated"`` is written into the info with the value
    ``truncated and not terminated``. For a batched info dict, the key is written
    as a boolean array, and only if at least one sub-environment ended.

    Args:
        step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
        is_vector_env (bool): Whether the ``step_returns`` are from a vector environment

    Returns:
        The ``(obs, rew, done, info)`` tuple. The input info(s) are modified in place.

    Raises:
        AssertionError: If ``step_returns`` has neither 4 nor 5 elements.
        TypeError: If ``is_vector_env`` is truthy and the infos are neither a list nor a dict.
    """
    if len(step_returns) == 4:
        return step_returns

    assert (
        len(step_returns) == 5
    ), f"Expected step returns of length 4 or 5, actual length: {len(step_returns)}"
    observations, rewards, terminated, truncated, infos = step_returns

    # Cases to handle - info single env / info vector env (list) / info vector env (dict)
    # Test truthiness rather than `is False` so that falsy non-bool values
    # (e.g. `np.False_` or `0`) still select the single-environment path.
    if not is_vector_env:
        if truncated or terminated:
            infos["TimeLimit.truncated"] = truncated and not terminated
        return (
            observations,
            rewards,
            terminated or truncated,
            infos,
        )

    if isinstance(infos, list):
        # One info dict per sub-environment; only annotate the ones that ended.
        for info, env_truncated, env_terminated in zip(infos, truncated, terminated):
            if env_truncated or env_terminated:
                info["TimeLimit.truncated"] = env_truncated and not env_terminated
    elif isinstance(infos, dict):
        # A single batched info dict; annotate only if some sub-env ended.
        if np.any(truncated) or np.any(terminated):
            infos["TimeLimit.truncated"] = np.logical_and(
                truncated, np.logical_not(terminated)
            )
    else:
        raise TypeError(
            f"Unexpected value of infos, as is_vector_env=True, expects `info` to be a list or dict, actual type: {type(infos)}"
        )

    return (
        observations,
        rewards,
        np.logical_or(terminated, truncated),
        infos,
    )
|
2022-07-10 02:18:06 +05:30
|
|
|
|
|
|
|
|
|
|
|
def step_api_compatibility(
    step_returns: TerminatedTruncatedStepType | DoneStepType,
    output_truncation_bool: bool = True,
    is_vector_env: bool = False,
) -> TerminatedTruncatedStepType | DoneStepType:
    """Convert step returns into whichever step API ``output_truncation_bool`` selects.

    .. py:currentmodule:: gymnasium.Env

    The done (old) step API means :meth:`step` returns ``(observation, reward, done, info)``.
    The terminated/truncated (new) step API means :meth:`step` returns ``(observation, reward, terminated, truncated, info)``.
    (See the docs for details on the API change.)

    Args:
        step_returns (tuple): Items returned by :meth:`step`, either ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)``
        output_truncation_bool (bool): If ``True`` (the default), output the new API with two booleans; otherwise output the old API with one
        is_vector_env (bool): Whether the ``step_returns`` come from a vector environment

    Returns:
        step_returns (tuple): ``(obs, rew, terminated, truncated, info)`` when ``output_truncation_bool`` is truthy, else ``(obs, rew, done, info)``

    Example:
        Useful when step interfaces disagree, e.g. an environment written against the old API,
        a wrapper written against the new API, and a caller that wants the old API back.

        >>> import gymnasium as gym
        >>> env = gym.make("CartPole-v0")
        >>> _, _ = env.reset()
        >>> obs, reward, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False)
        >>> obs, reward, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True)

        >>> vec_env = gym.make_vec("CartPole-v0", vectorization_mode="sync")
        >>> _, _ = vec_env.reset()
        >>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False)
        >>> obs, rewards, terminations, truncations, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True)
    """
    # Pick the converter for the requested output API, then delegate to it.
    converter = (
        convert_to_terminated_truncated_step_api
        if output_truncation_bool
        else convert_to_done_step_api
    )
    return converter(step_returns, is_vector_env)
|