mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 01:27:29 +00:00
Support only new step API (while retaining compatibility functions) (#3019)
This commit is contained in:
@@ -17,7 +17,6 @@ from gym.error import (
|
||||
CustomSpaceError,
|
||||
NoAsyncCallError,
|
||||
)
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
from gym.vector.utils import (
|
||||
CloudpickleWrapper,
|
||||
clear_mpi_env_vars,
|
||||
@@ -67,7 +66,6 @@ class AsyncVectorEnv(VectorEnv):
|
||||
context: Optional[str] = None,
|
||||
daemon: bool = True,
|
||||
worker: Optional[callable] = None,
|
||||
new_step_api: bool = False,
|
||||
):
|
||||
"""Vectorized environment that runs multiple environments in parallel.
|
||||
|
||||
@@ -87,7 +85,6 @@ class AsyncVectorEnv(VectorEnv):
|
||||
so for some environments you may want to have it set to ``False``.
|
||||
worker: If set, then use that worker in a subprocess instead of a default one.
|
||||
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
||||
new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done
|
||||
|
||||
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
||||
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
||||
@@ -115,7 +112,6 @@ class AsyncVectorEnv(VectorEnv):
|
||||
num_envs=len(env_fns),
|
||||
observation_space=observation_space,
|
||||
action_space=action_space,
|
||||
new_step_api=new_step_api,
|
||||
)
|
||||
|
||||
if self.shared_memory:
|
||||
@@ -291,14 +287,14 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
def step_wait(
|
||||
self, timeout: Optional[Union[int, float]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]:
|
||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||
|
||||
Args:
|
||||
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
|
||||
|
||||
Returns:
|
||||
The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api
|
||||
The batched environment step information, (obs, reward, terminated, truncated, info)
|
||||
|
||||
Raises:
|
||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||
@@ -322,7 +318,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
successes = []
|
||||
for i, pipe in enumerate(self.parent_pipes):
|
||||
result, success = pipe.recv()
|
||||
obs, rew, terminated, truncated, info = step_api_compatibility(result, True)
|
||||
obs, rew, terminated, truncated, info = result
|
||||
|
||||
successes.append(success)
|
||||
observations_list.append(obs)
|
||||
@@ -341,16 +337,12 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.observations,
|
||||
)
|
||||
|
||||
return step_api_compatibility(
|
||||
(
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.array(rewards),
|
||||
np.array(terminateds, dtype=np.bool_),
|
||||
np.array(truncateds, dtype=np.bool_),
|
||||
infos,
|
||||
),
|
||||
self.new_step_api,
|
||||
True,
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.array(rewards),
|
||||
np.array(terminateds, dtype=np.bool_),
|
||||
np.array(truncateds, dtype=np.bool_),
|
||||
infos,
|
||||
)
|
||||
|
||||
def call_async(self, name: str, *args, **kwargs):
|
||||
@@ -572,7 +564,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
terminated,
|
||||
truncated,
|
||||
info,
|
||||
) = step_api_compatibility(env.step(data), True)
|
||||
) = env.step(data)
|
||||
if terminated or truncated:
|
||||
old_observation = observation
|
||||
observation, info = env.reset()
|
||||
@@ -642,7 +634,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
terminated,
|
||||
truncated,
|
||||
info,
|
||||
) = step_api_compatibility(env.step(data), True)
|
||||
) = env.step(data)
|
||||
if terminated or truncated:
|
||||
old_observation = observation
|
||||
observation, info = env.reset()
|
||||
|
Reference in New Issue
Block a user