Added singledispatch utility to vector.utils & changed order of space argument. (#2536)

* Fixed ordering of space. Added singledispatch utility.

* Added singledispatch utility to vector.utils & changed order of space argument

* Fixed Error from _BaseGymSpaces

* Minor adjustment for Discrete Spaces

* Fixed Tests/ to reflect changes

* Fixed precommit error - custom namespaces

* Concrete Implementations start with _
This commit is contained in:
Rushiv Arora
2022-01-21 11:28:34 -05:00
committed by GitHub
parent 925823661d
commit fcbff7de12
8 changed files with 139 additions and 136 deletions

View File

@@ -144,7 +144,7 @@ class AsyncVectorEnv(VectorEnv):
self.single_observation_space, n=self.num_envs, ctx=ctx
)
self.observations = read_from_shared_memory(
_obs_buffer, self.single_observation_space, n=self.num_envs
self.single_observation_space, _obs_buffer, n=self.num_envs
)
except CustomSpaceError:
raise ValueError(
@@ -301,7 +301,7 @@ class AsyncVectorEnv(VectorEnv):
if not self.shared_memory:
self.observations = concatenate(
results, self.observations, self.single_observation_space
self.single_observation_space, results, self.observations
)
return deepcopy(self.observations) if self.copy else self.observations
@@ -392,7 +392,9 @@ class AsyncVectorEnv(VectorEnv):
if not self.shared_memory:
self.observations = concatenate(
observations_list, self.observations, self.single_observation_space
self.single_observation_space,
observations_list,
self.observations,
)
return (
@@ -568,7 +570,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
if command == "reset":
observation = env.reset(**data)
write_to_shared_memory(
index, observation, shared_memory, observation_space
observation_space, index, observation, shared_memory
)
pipe.send((None, True))
elif command == "step":
@@ -577,7 +579,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
info["terminal_observation"] = observation
observation = env.reset()
write_to_shared_memory(
index, observation, shared_memory, observation_space
observation_space, index, observation, shared_memory
)
pipe.send(((None, reward, done, info), True))
elif command == "seed":