2022-07-10 02:18:06 +05:30
|
|
|
"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
|
2022-09-16 23:41:27 +01:00
|
|
|
import gymnasium as gym
|
2022-09-08 10:10:07 +01:00
|
|
|
from gymnasium.logger import deprecation
|
|
|
|
from gymnasium.utils.step_api_compatibility import (
|
2022-08-30 19:41:59 +05:30
|
|
|
convert_to_done_step_api,
|
|
|
|
convert_to_terminated_truncated_step_api,
|
|
|
|
)
|
2022-07-10 02:18:06 +05:30
|
|
|
|
|
|
|
|
2022-09-16 23:41:27 +01:00
|
|
|
class StepAPICompatibility(gym.Wrapper):
    r"""A wrapper which can transform an environment from new step API to old and vice-versa.

    Old step API refers to step() method returning (observation, reward, done, info)
    New step API refers to step() method returning (observation, reward, terminated, truncated, info)
    (Refer to docs for details on the API change)

    Args:
        env (gym.Env): the env to wrap. Can be in old or new API
        output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API)
            or one boolean (old API). (True by default)

    Examples:
        >>> env = gym.make("CartPole-v1")
        >>> env # wrapper not applied by default, set to new API
        <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
        >>> env = StepAPICompatibility(env, output_truncation_bool=False) # set to old API
        >>> env
        <StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
        >>> env = StepAPICompatibility(CustomEnv(), output_truncation_bool=True) # manually using wrapper on unregistered envs
    """

    def __init__(self, env: gym.Env, output_truncation_bool: bool = True):
        """A wrapper which can transform an environment from new step API to old and vice-versa.

        Args:
            env (gym.Env): the env to wrap. Can be in old or new API
            output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
        """
        super().__init__(env)
        self.output_truncation_bool = output_truncation_bool
        # The single-bool (done) API is the legacy form, so opting into it
        # emits a deprecation notice at construction time.
        if not self.output_truncation_bool:
            deprecation(
                "Initializing environment in (old) done step API which returns one bool instead of two."
            )

    def step(self, action) -> tuple:
        """Steps through the environment, returning 5 or 4 items depending on `output_truncation_bool`.

        Args:
            action: action to step through the environment with

        Returns:
            (observation, reward, terminated, truncated, info) if `output_truncation_bool` is True,
            otherwise (observation, reward, done, info)
        """
        step_returns = self.env.step(action)
        # The wrapped env's result is handed to a converter for the requested
        # output form; the env's own step API is not inspected here.
        if self.output_truncation_bool:
            return convert_to_terminated_truncated_step_api(step_returns)
        return convert_to_done_step_api(step_returns)
|