mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 12:57:38 +00:00
Add call, get_attr and set_attr methods to VectorEnv (#1600)
* Add call, get_attr and set_attr methods * Use f-strings and remove assert * Allow tuples in set_attr and move docstrings * Replace CubeCrash by CartPole in tests
This commit is contained in:
@@ -34,6 +34,7 @@ class AsyncState(Enum):
|
||||
DEFAULT = "default"
|
||||
WAITING_RESET = "reset"
|
||||
WAITING_STEP = "step"
|
||||
WAITING_CALL = "call"
|
||||
|
||||
|
||||
class AsyncVectorEnv(VectorEnv):
|
||||
@@ -292,7 +293,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError(
|
||||
f"The call to `reset_wait` has timed out after {timeout} second{'s' if timeout > 1 else ''}."
|
||||
f"The call to `reset_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
@@ -382,7 +383,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError(
|
||||
f"The call to `step_wait` has timed out after {timeout} second{'s' if timeout > 1 else ''}."
|
||||
f"The call to `step_wait` has timed out after {timeout} second(s)."
|
||||
)
|
||||
|
||||
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
@@ -404,6 +405,98 @@ class AsyncVectorEnv(VectorEnv):
|
||||
infos,
|
||||
)
|
||||
|
||||
def call_async(self, name, *args, **kwargs):
    """Send a method call (or property read) request to every worker.

    The request is only dispatched here; the state is switched to
    `AsyncState.WAITING_CALL` so that the results can be collected
    afterwards with `call_wait`.

    Parameters
    ----------
    name : string
        Name of the method or property to call.

    *args
        Arguments to apply to the method call.

    **kwargs
        Keyword arguments to apply to the method call.

    Raises
    ------
    AlreadyPendingCallError
        If the current state is not `AsyncState.DEFAULT`, i.e. another
        asynchronous operation has not been completed yet.
    """
    self._assert_is_running()

    pending = self._state
    if pending != AsyncState.DEFAULT:
        raise AlreadyPendingCallError(
            "Calling `call_async` while waiting "
            f"for a pending call to `{pending.value}` to complete.",
            pending.value,
        )

    # Every worker receives the same request, so build the message once.
    request = ("_call", (name, args, kwargs))
    for pipe in self.parent_pipes:
        pipe.send(request)

    self._state = AsyncState.WAITING_CALL
|
||||
|
||||
def call_wait(self, timeout=None):
    """Collect the results of a pending `call_async` request.

    Parameters
    ----------
    timeout : int or float, optional
        Number of seconds before the call to `call_wait` times out. If
        `None` (default), the call to `call_wait` never times out.

    Returns
    -------
    results : tuple
        Results of the individual calls to the method or property, one
        entry per pipe in `self.parent_pipes`, in the same order.

    Raises
    ------
    NoAsyncCallError
        If the current state is not `AsyncState.WAITING_CALL`.
    mp.TimeoutError
        If `self._poll(timeout)` reports that the results are not ready.
        The state is reset to `AsyncState.DEFAULT` before raising.
    """
    self._assert_is_running()

    if self._state != AsyncState.WAITING_CALL:
        raise NoAsyncCallError(
            "Calling `call_wait` without any prior call to `call_async`.",
            AsyncState.WAITING_CALL.value,
        )

    ready = self._poll(timeout)
    if not ready:
        self._state = AsyncState.DEFAULT
        raise mp.TimeoutError(
            f"The call to `call_wait` has timed out after {timeout} second(s)."
        )

    # Each worker replies with a `(result, success)` pair.
    replies = [pipe.recv() for pipe in self.parent_pipes]
    results, successes = zip(*replies)
    self._raise_if_errors(successes)
    self._state = AsyncState.DEFAULT

    return results
|
||||
|
||||
def set_attr(self, name, values):
    """Set a property in each sub-environment through its worker pipe.

    Parameters
    ----------
    name : string
        Name of the property to be set in each individual environment.

    values : list, tuple, or object
        Values of the property to be set to. If `values` is a list or
        tuple, then it corresponds to the values for each individual
        environment, otherwise a single value is set for all environments.

    Raises
    ------
    ValueError
        If the number of values differs from `self.num_envs`.
    AlreadyPendingCallError
        If the current state is not `AsyncState.DEFAULT`.
    """
    self._assert_is_running()

    # Broadcast a single value so that every environment gets its own entry.
    if not isinstance(values, (list, tuple)):
        values = [values] * self.num_envs

    count = len(values)
    if count != self.num_envs:
        raise ValueError(
            "Values must be a list or tuple with length equal to the "
            f"number of environments. Got `{count}` values for "
            f"{self.num_envs} environments."
        )

    pending = self._state
    if pending != AsyncState.DEFAULT:
        raise AlreadyPendingCallError(
            "Calling `set_attr` while waiting "
            f"for a pending call to `{pending.value}` to complete.",
            pending.value,
        )

    for pipe, value in zip(self.parent_pipes, values):
        pipe.send(("_setattr", (name, value)))

    # Wait for every worker to acknowledge; only the success flags matter.
    acknowledgements = [pipe.recv() for pipe in self.parent_pipes]
    _, successes = zip(*acknowledgements)
    self._raise_if_errors(successes)
|
||||
|
||||
def close_extras(self, timeout=None, terminate=False):
|
||||
"""Close the environments & clean up the extra resources
|
||||
(processes and pipes).
|
||||
@@ -539,6 +632,22 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
elif command == "close":
|
||||
pipe.send((None, True))
|
||||
break
|
||||
elif command == "_call":
|
||||
name, args, kwargs = data
|
||||
if name in ["reset", "step", "seed", "close"]:
|
||||
raise ValueError(
|
||||
f"Trying to call function `{name}` with "
|
||||
f"`_call`. Use `{name}` directly instead."
|
||||
)
|
||||
function = getattr(env, name)
|
||||
if callable(function):
|
||||
pipe.send((function(*args, **kwargs), True))
|
||||
else:
|
||||
pipe.send((function, True))
|
||||
elif command == "_setattr":
|
||||
name, value = data
|
||||
setattr(env, name, value)
|
||||
pipe.send((None, True))
|
||||
elif command == "_check_spaces":
|
||||
pipe.send(
|
||||
(
|
||||
@@ -548,9 +657,9 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Received unknown command `{0}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, "
|
||||
"`_check_spaces`}.".format(command)
|
||||
f"Received unknown command `{command}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
|
||||
"`_setattr`, `_check_spaces`}."
|
||||
)
|
||||
except (KeyboardInterrupt, Exception):
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
@@ -588,15 +697,31 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
elif command == "close":
|
||||
pipe.send((None, True))
|
||||
break
|
||||
elif command == "_call":
|
||||
name, args, kwargs = data
|
||||
if name in ["reset", "step", "seed", "close"]:
|
||||
raise ValueError(
|
||||
f"Trying to call function `{name}` with "
|
||||
f"`_call`. Use `{name}` directly instead."
|
||||
)
|
||||
function = getattr(env, name)
|
||||
if callable(function):
|
||||
pipe.send((function(*args, **kwargs), True))
|
||||
else:
|
||||
pipe.send((function, True))
|
||||
elif command == "_setattr":
|
||||
name, value = data
|
||||
setattr(env, name, value)
|
||||
pipe.send((None, True))
|
||||
elif command == "_check_spaces":
|
||||
pipe.send(
|
||||
((data[0] == observation_space, data[1] == env.action_space), True)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Received unknown command `{0}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, "
|
||||
"`_check_spaces`}.".format(command)
|
||||
f"Received unknown command `{command}`. Must "
|
||||
"be one of {`reset`, `step`, `seed`, `close`, `_call`, "
|
||||
"`_setattr`, `_check_spaces`}."
|
||||
)
|
||||
except (KeyboardInterrupt, Exception):
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
|
@@ -136,6 +136,30 @@ class SyncVectorEnv(VectorEnv):
|
||||
infos,
|
||||
)
|
||||
|
||||
def call(self, name, *args, **kwargs):
    """Call a method, or get a property, from each sub-environment.

    Parameters
    ----------
    name : string
        Name of the method or property to call.

    *args
        Arguments to apply to the method call.

    **kwargs
        Keyword arguments to apply to the method call.

    Returns
    -------
    results : tuple
        One entry per environment in `self.envs`, in order: the return
        value of the call when the attribute is callable, otherwise the
        attribute itself.
    """

    def _resolve(env):
        # Callables are invoked with the given arguments; anything else is
        # returned as-is (the arguments are ignored in that case).
        attribute = getattr(env, name)
        if callable(attribute):
            return attribute(*args, **kwargs)
        return attribute

    return tuple(_resolve(env) for env in self.envs)
|
||||
|
||||
def set_attr(self, name, values):
    """Set a property in each sub-environment.

    Parameters
    ----------
    name : string
        Name of the property to be set in each individual environment.

    values : list, tuple, or object
        Values of the property to be set to. If `values` is a list or
        tuple, then it corresponds to the values for each individual
        environment, otherwise a single value is set for all environments.

    Raises
    ------
    ValueError
        If the number of values differs from `self.num_envs`.
    """
    # Broadcast a single value so that every environment gets its own entry.
    if not isinstance(values, (list, tuple)):
        values = [values] * self.num_envs

    count = len(values)
    if count != self.num_envs:
        raise ValueError(
            "Values must be a list or tuple with length equal to the "
            f"number of environments. Got `{count}` values for "
            f"{self.num_envs} environments."
        )

    for env, value in zip(self.envs, values):
        setattr(env, name, value)
|
||||
|
||||
def close_extras(self, **kwargs):
|
||||
"""Close the environments."""
|
||||
[env.close() for env in self.envs]
|
||||
|
@@ -108,6 +108,60 @@ class VectorEnv(gym.Env):
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def call_async(self, name, *args, **kwargs):
    """Start a call on each sub-environment; a no-op in this base class.

    Parameters
    ----------
    name : string
        Name of the method or property to call.

    *args
        Arguments to apply to the method call.

    **kwargs
        Keyword arguments to apply to the method call.
    """
    return None
|
||||
|
||||
def call_wait(self, **kwargs):
    """Return the results of a pending call; not implemented in this base class.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()
|
||||
|
||||
def call(self, name, *args, **kwargs):
    """Call a method, or get a property, from each sub-environment.

    This dispatches the request with `call_async` and then returns
    whatever `call_wait` produces.

    Parameters
    ----------
    name : string
        Name of the method or property to call.

    *args
        Arguments to apply to the method call.

    **kwargs
        Keyword arguments to apply to the method call.

    Returns
    -------
    results
        The value returned by `call_wait`: the results of the individual
        calls to the method or property for each environment.
    """
    self.call_async(name, *args, **kwargs)
    results = self.call_wait()
    return results
|
||||
|
||||
def get_attr(self, name):
    """Get a property from each sub-environment.

    Parameters
    ----------
    name : string
        Name of the property to get from each individual environment.

    Returns
    -------
    values
        The value returned by `self.call(name)`.
    """
    values = self.call(name)
    return values
|
||||
|
||||
def set_attr(self, name, values):
    """Set a property in each sub-environment; not implemented in this base class.

    Parameters
    ----------
    name : string
        Name of the property to be set in each individual environment.

    values : list, tuple, or object
        Values of the property to be set to. If `values` is a list or
        tuple, then it corresponds to the values for each individual
        environment, otherwise a single value is set for all environments.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()
|
||||
|
||||
def close_extras(self, **kwargs):
|
||||
r"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||
pass
|
||||
|
@@ -77,6 +77,41 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
||||
assert dones.size == 8
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory):
    """`call` returns one result per sub-environment, for methods and properties."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    # Create the env before entering `try`: if construction fails, `env` is
    # unbound and `env.close()` in `finally` would mask the real error.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        images = env.call("render", mode="rgb_array")
        gravity = env.call("gravity")
    finally:
        env.close()

    assert isinstance(images, tuple)
    assert len(images) == 4
    for image in images:
        assert isinstance(image, np.ndarray)

    assert isinstance(gravity, tuple)
    assert len(gravity) == 4
    for value in gravity:
        assert isinstance(value, float)
        assert value == 9.8
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
def test_set_attr_async_vector_env(shared_memory):
    """Per-environment values written with `set_attr` are read back by `get_attr`."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    # Create the env before entering `try`: if construction fails, `env` is
    # unbound and `env.close()` in `finally` would mask the real error.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
        gravity = env.get_attr("gravity")
        assert gravity == (9.81, 3.72, 8.87, 1.62)
    finally:
        env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_copy_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
|
@@ -67,6 +67,39 @@ def test_step_sync_vector_env(use_single_action_space):
|
||||
assert dones.size == 8
|
||||
|
||||
|
||||
def test_call_sync_vector_env():
    """`call` returns one result per sub-environment, for methods and properties."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    # Create the env before entering `try`: if construction fails, `env` is
    # unbound and `env.close()` in `finally` would mask the real error.
    env = SyncVectorEnv(env_fns)
    try:
        env.reset()
        images = env.call("render", mode="rgb_array")
        gravity = env.call("gravity")
    finally:
        env.close()

    assert isinstance(images, tuple)
    assert len(images) == 4
    for image in images:
        assert isinstance(image, np.ndarray)

    assert isinstance(gravity, tuple)
    assert len(gravity) == 4
    for value in gravity:
        assert isinstance(value, float)
        assert value == 9.8
|
||||
|
||||
|
||||
def test_set_attr_sync_vector_env():
    """Per-environment values written with `set_attr` are read back by `get_attr`."""
    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    # Create the env before entering `try`: if construction fails, `env` is
    # unbound and `env.close()` in `finally` would mask the real error.
    env = SyncVectorEnv(env_fns)
    try:
        env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
        gravity = env.get_attr("gravity")
        assert gravity == (9.81, 3.72, 8.87, 1.62)
    finally:
        env.close()
|
||||
|
||||
|
||||
def test_check_spaces_sync_vector_env():
|
||||
# CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2)
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
|
Reference in New Issue
Block a user