Support only new step API (while retaining compatibility functions) (#3019)

This commit is contained in:
Arjun KG
2022-08-30 19:41:59 +05:30
committed by GitHub
parent 884ba08f19
commit 54b406b799
58 changed files with 378 additions and 559 deletions

View File

@@ -17,7 +17,6 @@ from gym.error import (
CustomSpaceError,
NoAsyncCallError,
)
from gym.utils.step_api_compatibility import step_api_compatibility
from gym.vector.utils import (
CloudpickleWrapper,
clear_mpi_env_vars,
@@ -67,7 +66,6 @@ class AsyncVectorEnv(VectorEnv):
context: Optional[str] = None,
daemon: bool = True,
worker: Optional[callable] = None,
new_step_api: bool = False,
):
"""Vectorized environment that runs multiple environments in parallel.
@@ -87,7 +85,6 @@ class AsyncVectorEnv(VectorEnv):
so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
@@ -115,7 +112,6 @@ class AsyncVectorEnv(VectorEnv):
num_envs=len(env_fns),
observation_space=observation_space,
action_space=action_space,
new_step_api=new_step_api,
)
if self.shared_memory:
@@ -291,14 +287,14 @@ class AsyncVectorEnv(VectorEnv):
def step_wait(
self, timeout: Optional[Union[int, float]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
Args:
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
Returns:
The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api
The batched environment step information, (obs, reward, terminated, truncated, info)
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
@@ -322,7 +318,7 @@ class AsyncVectorEnv(VectorEnv):
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, terminated, truncated, info = step_api_compatibility(result, True)
obs, rew, terminated, truncated, info = result
successes.append(success)
observations_list.append(obs)
@@ -341,16 +337,12 @@ class AsyncVectorEnv(VectorEnv):
self.observations,
)
return step_api_compatibility(
(
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
),
self.new_step_api,
True,
return (
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
)
def call_async(self, name: str, *args, **kwargs):
@@ -572,7 +564,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
terminated,
truncated,
info,
) = step_api_compatibility(env.step(data), True)
) = env.step(data)
if terminated or truncated:
old_observation = observation
observation, info = env.reset()
@@ -642,7 +634,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
terminated,
truncated,
info,
) = step_api_compatibility(env.step(data), True)
) = env.step(data)
if terminated or truncated:
old_observation = observation
observation, info = env.reset()