Add call, get_attr and set_attr methods to VectorEnv (#1600)

* Add call, get_attr and set_attr methods

* Use f-strings and remove assert

* Allow tuples in set_attr and move docstrings

* Replace CubeCrash by CartPole in tests
This commit is contained in:
Tristan Deleu
2022-01-29 12:32:35 -05:00
committed by GitHub
parent b9e8b6c587
commit 081c5c1e80
5 changed files with 279 additions and 8 deletions

View File

@@ -136,6 +136,30 @@ class SyncVectorEnv(VectorEnv):
infos,
)
def call(self, name, *args, **kwargs):
results = []
for env in self.envs:
function = getattr(env, name)
if callable(function):
results.append(function(*args, **kwargs))
else:
results.append(function)
return tuple(results)
def set_attr(self, name, values):
if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)]
if len(values) != self.num_envs:
raise ValueError(
"Values must be a list or tuple with length equal to the "
f"number of environments. Got `{len(values)}` values for "
f"{self.num_envs} environments."
)
for env, value in zip(self.envs, values):
setattr(env, name, value)
def close_extras(self, **kwargs):
"""Close the environments."""
[env.close() for env in self.envs]