Add call, get_attr and set_attr methods to VectorEnv (#1600)

* Add call, get_attr and set_attr methods

* Use f-strings and remove assert

* Allow tuples in set_attr and move docstrings

* Replace CubeCrash by CartPole in tests
This commit is contained in:
Tristan Deleu
2022-01-29 12:32:35 -05:00
committed by GitHub
parent b9e8b6c587
commit 081c5c1e80
5 changed files with 279 additions and 8 deletions

View File

@@ -34,6 +34,7 @@ class AsyncState(Enum):
DEFAULT = "default" DEFAULT = "default"
WAITING_RESET = "reset" WAITING_RESET = "reset"
WAITING_STEP = "step" WAITING_STEP = "step"
WAITING_CALL = "call"
class AsyncVectorEnv(VectorEnv): class AsyncVectorEnv(VectorEnv):
@@ -292,7 +293,7 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout): if not self._poll(timeout):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
raise mp.TimeoutError( raise mp.TimeoutError(
f"The call to `reset_wait` has timed out after {timeout} second{'s' if timeout > 1 else ''}." f"The call to `reset_wait` has timed out after {timeout} second(s)."
) )
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -382,7 +383,7 @@ class AsyncVectorEnv(VectorEnv):
if not self._poll(timeout): if not self._poll(timeout):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
raise mp.TimeoutError( raise mp.TimeoutError(
f"The call to `step_wait` has timed out after {timeout} second{'s' if timeout > 1 else ''}." f"The call to `step_wait` has timed out after {timeout} second(s)."
) )
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
@@ -404,6 +405,98 @@ class AsyncVectorEnv(VectorEnv):
infos, infos,
) )
def call_async(self, name, *args, **kwargs):
"""
Parameters
----------
name : string
Name of the method or property to call.
*args
Arguments to apply to the method call.
**kwargs
Keywoard arguments to apply to the method call.
"""
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `call_async` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
self._state.value,
)
for pipe in self.parent_pipes:
pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout=None):
"""
Parameters
----------
timeout : int or float, optional
Number of seconds before the call to `step_wait` times out. If
`None` (default), the call to `step_wait` never times out.
Returns
-------
results : list
List of the results of the individual calls to the method or
property for each environment.
"""
self._assert_is_running()
if self._state != AsyncState.WAITING_CALL:
raise NoAsyncCallError(
"Calling `call_wait` without any prior call to `call_async`.",
AsyncState.WAITING_CALL.value,
)
if not self._poll(timeout):
self._state = AsyncState.DEFAULT
raise mp.TimeoutError(
f"The call to `call_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
return results
def set_attr(self, name, values):
"""
Parameters
----------
name : string
Name of the property to be set in each individual environment.
values : list, tuple, or object
Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
"""
self._assert_is_running()
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
self._state.value,
)
for pipe, value in zip(self.parent_pipes, values):
pipe.send(("_setattr", (name, value)))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def close_extras(self, timeout=None, terminate=False): def close_extras(self, timeout=None, terminate=False):
"""Close the environments & clean up the extra resources """Close the environments & clean up the extra resources
(processes and pipes). (processes and pipes).
@@ -539,6 +632,22 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
elif command == "close": elif command == "close":
pipe.send((None, True)) pipe.send((None, True))
break break
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
raise ValueError(
f"Trying to call function `{name}` with "
f"`_call`. Use `{name}` directly instead."
)
function = getattr(env, name)
if callable(function):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))
elif command == "_setattr":
name, value = data
setattr(env, name, value)
pipe.send((None, True))
elif command == "_check_spaces": elif command == "_check_spaces":
pipe.send( pipe.send(
( (
@@ -548,9 +657,9 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
) )
else: else:
raise RuntimeError( raise RuntimeError(
"Received unknown command `{0}`. Must " f"Received unknown command `{command}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, " "be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_check_spaces`}.".format(command) "`_setattr`, `_check_spaces`}."
) )
except (KeyboardInterrupt, Exception): except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2]) error_queue.put((index,) + sys.exc_info()[:2])
@@ -588,15 +697,31 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
elif command == "close": elif command == "close":
pipe.send((None, True)) pipe.send((None, True))
break break
elif command == "_call":
name, args, kwargs = data
if name in ["reset", "step", "seed", "close"]:
raise ValueError(
f"Trying to call function `{name}` with "
f"`_call`. Use `{name}` directly instead."
)
function = getattr(env, name)
if callable(function):
pipe.send((function(*args, **kwargs), True))
else:
pipe.send((function, True))
elif command == "_setattr":
name, value = data
setattr(env, name, value)
pipe.send((None, True))
elif command == "_check_spaces": elif command == "_check_spaces":
pipe.send( pipe.send(
((data[0] == observation_space, data[1] == env.action_space), True) ((data[0] == observation_space, data[1] == env.action_space), True)
) )
else: else:
raise RuntimeError( raise RuntimeError(
"Received unknown command `{0}`. Must " f"Received unknown command `{command}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, " "be one of {`reset`, `step`, `seed`, `close`, `_call`, "
"`_check_spaces`}.".format(command) "`_setattr`, `_check_spaces`}."
) )
except (KeyboardInterrupt, Exception): except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2]) error_queue.put((index,) + sys.exc_info()[:2])

View File

@@ -136,6 +136,30 @@ class SyncVectorEnv(VectorEnv):
infos, infos,
) )
def call(self, name, *args, **kwargs):
results = []
for env in self.envs:
function = getattr(env, name)
if callable(function):
results.append(function(*args, **kwargs))
else:
results.append(function)
return tuple(results)
def set_attr(self, name, values):
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
for env, value in zip(self.envs, values):
setattr(env, name, value)
def close_extras(self, **kwargs): def close_extras(self, **kwargs):
"""Close the environments.""" """Close the environments."""
[env.close() for env in self.envs] [env.close() for env in self.envs]

View File

@@ -108,6 +108,60 @@ class VectorEnv(gym.Env):
self.step_async(actions) self.step_async(actions)
return self.step_wait() return self.step_wait()
def call_async(self, name, *args, **kwargs):
pass
def call_wait(self, **kwargs):
raise NotImplementedError()
def call(self, name, *args, **kwargs):
"""Call a method, or get a property, from each sub-environment.
Parameters
----------
name : string
Name of the method or property to call.
*args
Arguments to apply to the method call.
**kwargs
Keywoard arguments to apply to the method call.
Returns
-------
results : list
List of the results of the individual calls to the method or
property for each environment.
"""
self.call_async(name, *args, **kwargs)
return self.call_wait()
def get_attr(self, name):
"""Get a property from each sub-environment.
Parameters
----------
name : string
Name of the property to be get from each individual environment.
"""
return self.call(name)
def set_attr(self, name, values):
"""Set a property in each sub-environment.
Parameters
----------
name : string
Name of the property to be set in each individual environment.
values : list, tuple, or object
Values of the property to be set to. If `values` is a list or
tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments.
"""
raise NotImplementedError()
def close_extras(self, **kwargs): def close_extras(self, **kwargs):
r"""Clean up the extra resources e.g. beyond what's in this base class.""" r"""Clean up the extra resources e.g. beyond what's in this base class."""
pass pass

View File

@@ -77,6 +77,41 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
assert dones.size == 8 assert dones.size == 8
@pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
_ = env.reset()
images = env.call("render", mode="rgb_array")
gravity = env.call("gravity")
finally:
env.close()
assert isinstance(images, tuple)
assert len(images) == 4
for i in range(4):
assert isinstance(images[i], np.ndarray)
assert isinstance(gravity, tuple)
assert len(gravity) == 4
for i in range(4):
assert isinstance(gravity[i], float)
assert gravity[i] == 9.8
@pytest.mark.parametrize("shared_memory", [True, False])
def test_set_attr_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
gravity = env.get_attr("gravity")
assert gravity == (9.81, 3.72, 8.87, 1.62)
finally:
env.close()
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_copy_async_vector_env(shared_memory): def test_copy_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]

View File

@@ -67,6 +67,39 @@ def test_step_sync_vector_env(use_single_action_space):
assert dones.size == 8 assert dones.size == 8
def test_call_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = SyncVectorEnv(env_fns)
_ = env.reset()
images = env.call("render", mode="rgb_array")
gravity = env.call("gravity")
finally:
env.close()
assert isinstance(images, tuple)
assert len(images) == 4
for i in range(4):
assert isinstance(images[i], np.ndarray)
assert isinstance(gravity, tuple)
assert len(gravity) == 4
for i in range(4):
assert isinstance(gravity[i], float)
assert gravity[i] == 9.8
def test_set_attr_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = SyncVectorEnv(env_fns)
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
gravity = env.get_attr("gravity")
assert gravity == (9.81, 3.72, 8.87, 1.62)
finally:
env.close()
def test_check_spaces_sync_vector_env(): def test_check_spaces_sync_vector_env():
# CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2) # CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2)
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]