2019-07-23 21:34:39 +02:00
|
|
|
try:
|
|
|
|
from collections.abc import Iterable
|
|
|
|
except ImportError:
|
|
|
|
Iterable = (tuple, list)
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
from gym.vector.async_vector_env import AsyncVectorEnv
|
|
|
|
from gym.vector.sync_vector_env import SyncVectorEnv
|
2020-08-14 14:20:56 -07:00
|
|
|
from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
# Names exported by `from gym.vector import *`.
__all__ = [
    "AsyncVectorEnv",
    "SyncVectorEnv",
    "VectorEnv",
    "VectorEnvWrapper",
    "make",
]
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
|
2019-07-23 21:34:39 +02:00
|
|
|
def make(id, num_envs=1, asynchronous=True, wrappers=None, **kwargs):
    """Create a vectorized environment from multiple copies of an environment,
    from its id

    Parameters
    ----------
    id : str
        The environment ID. This must be a valid ID from the registry.

    num_envs : int
        Number of copies of the environment.

    asynchronous : bool (default: `True`)
        If `True`, wraps the environments in an `AsyncVectorEnv` (which uses
        `multiprocessing` to run the environments in parallel). If `False`,
        wraps the environments in a `SyncVectorEnv`.

    wrappers : Callable or Iterable of Callables (default: `None`)
        If not `None`, then apply the wrappers to each internal
        environment during creation. A single callable is applied as-is; an
        iterable of callables is applied in iteration order, each wrapper
        receiving the result of the previous one (so the last wrapper is the
        outermost). The iterable is consumed exactly once, so one-shot
        iterables such as generators are supported, and the same wrappers are
        applied to every copy of the environment.

    **kwargs
        Extra keyword arguments forwarded to `gym.envs.make` for every copy.

    Returns
    -------
    env : `gym.vector.VectorEnv` instance
        The vectorized environment.

    Raises
    ------
    NotImplementedError
        If `wrappers` is neither `None`, a callable, nor an iterable whose
        items are all callable. This is checked before any environment is
        created.

    Example
    -------
    >>> import gym
    >>> env = gym.vector.make('CartPole-v1', 3)
    >>> env.reset()
    array([[-0.04456399,  0.04653909,  0.01326909, -0.02099827],
           [ 0.03073904,  0.00145001, -0.03088818, -0.03131252],
           [ 0.03468829,  0.01500225,  0.01230312,  0.01825218]],
          dtype=float32)
    """
    # Validate and normalize ``wrappers`` exactly once, before any environment
    # is built. Materializing an iterable into a tuple is what lets one-shot
    # iterators (e.g. generators) be applied to every copy: iterating them
    # directly inside the per-env factory would exhaust them on the first
    # validation pass, leaving nothing to apply.
    if wrappers is None:
        wrapper_fns = ()
    elif callable(wrappers):
        # Checked before the Iterable branch so a callable that happens to be
        # iterable is still treated as a single wrapper.
        wrapper_fns = (wrappers,)
    elif isinstance(wrappers, Iterable):
        wrapper_fns = tuple(wrappers)
        if not all(callable(w) for w in wrapper_fns):
            raise NotImplementedError(
                "`wrappers` must be None, a callable, or an iterable of "
                "callables; got an iterable containing a non-callable item."
            )
    else:
        raise NotImplementedError(
            "`wrappers` must be None, a callable, or an iterable of "
            "callables; got {}.".format(type(wrappers).__name__)
        )

    from gym.envs import make as make_

    def _make_env():
        # Factory invoked once per copy by the vector env implementation.
        env = make_(id, **kwargs)
        # Apply wrappers in order; each one wraps the result of the previous.
        for wrapper in wrapper_fns:
            env = wrapper(env)
        return env

    env_fns = [_make_env for _ in range(num_envs)]
    return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
|