Jax environments return jax data rather than numpy data (#817)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2024-04-05 18:21:10 +02:00
committed by GitHub
parent f0202ae350
commit d43037920f
12 changed files with 48 additions and 81 deletions

View File

@@ -6,14 +6,12 @@ from typing import Any
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
class FunctionalJaxEnv(gym.Env):
@@ -32,7 +30,8 @@ class FunctionalJaxEnv(gym.Env):
):
"""Initialize the environment from a FuncEnv."""
if metadata is None:
metadata = {"render_mode": []}
# metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays
metadata = {"render_mode": [], "jax": True}
self.func_env = func_env
@@ -45,8 +44,6 @@ class FunctionalJaxEnv(gym.Env):
self.spec = spec
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
@@ -69,20 +66,10 @@ class FunctionalJaxEnv(gym.Env):
obs = self.func_env.observation(self.state)
info = self.func_env.state_info(self.state)
obs = jax_to_numpy(obs)
return obs, info
def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg
rng, self.rng = jrng.split(self.rng)
next_state = self.func_env.transition(self.state, action, rng)
@@ -92,8 +79,6 @@ class FunctionalJaxEnv(gym.Env):
info = self.func_env.transition_info(self.state, action, next_state)
self.state = next_state
observation = jax_to_numpy(observation)
return observation, float(reward), bool(terminated), False, info
def render(self):
@@ -153,8 +138,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
@@ -183,20 +166,10 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
obs = jax_to_numpy(obs)
return obs, info
def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
assert self.action_space.contains(
action
), f"{action!r} ({type(action)}) invalid"
self.steps += 1
rng, self.rng = jrng.split(self.rng)
@@ -232,12 +205,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.autoreset_envs = done
observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation)
reward = jax_to_numpy(reward)
terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)
self.state = next_state