mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
Jax environments return jax data rather than numpy data (#817)
Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f0202ae350
commit
d43037920f
@@ -6,14 +6,12 @@ from typing import Any
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.vector.utils import batch_space
|
||||
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
|
||||
|
||||
|
||||
class FunctionalJaxEnv(gym.Env):
|
||||
@@ -32,7 +30,8 @@ class FunctionalJaxEnv(gym.Env):
|
||||
):
|
||||
"""Initialize the environment from a FuncEnv."""
|
||||
if metadata is None:
|
||||
metadata = {"render_mode": []}
|
||||
# Downstream code can check metadata.get("jax", False) to detect that the environment returns jax arrays
|
||||
metadata = {"render_mode": [], "jax": True}
|
||||
|
||||
self.func_env = func_env
|
||||
|
||||
@@ -45,8 +44,6 @@ class FunctionalJaxEnv(gym.Env):
|
||||
|
||||
self.spec = spec
|
||||
|
||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
||||
|
||||
if self.render_mode == "rgb_array":
|
||||
self.render_state = self.func_env.render_init()
|
||||
else:
|
||||
@@ -69,20 +66,10 @@ class FunctionalJaxEnv(gym.Env):
|
||||
obs = self.func_env.observation(self.state)
|
||||
info = self.func_env.state_info(self.state)
|
||||
|
||||
obs = jax_to_numpy(obs)
|
||||
|
||||
return obs, info
|
||||
|
||||
def step(self, action: ActType):
|
||||
"""Steps through the environment using the action."""
|
||||
if self._is_box_action_space:
|
||||
assert isinstance(self.action_space, gym.spaces.Box) # For typing
|
||||
action = np.clip(action, self.action_space.low, self.action_space.high)
|
||||
else: # Discrete
|
||||
# For now we assume jax envs don't use complex spaces
|
||||
err_msg = f"{action!r} ({type(action)}) invalid"
|
||||
assert self.action_space.contains(action), err_msg
|
||||
|
||||
rng, self.rng = jrng.split(self.rng)
|
||||
|
||||
next_state = self.func_env.transition(self.state, action, rng)
|
||||
@@ -92,8 +79,6 @@ class FunctionalJaxEnv(gym.Env):
|
||||
info = self.func_env.transition_info(self.state, action, next_state)
|
||||
self.state = next_state
|
||||
|
||||
observation = jax_to_numpy(observation)
|
||||
|
||||
return observation, float(reward), bool(terminated), False, info
|
||||
|
||||
def render(self):
|
||||
@@ -153,8 +138,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
|
||||
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
|
||||
|
||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
||||
|
||||
if self.render_mode == "rgb_array":
|
||||
self.render_state = self.func_env.render_init()
|
||||
else:
|
||||
@@ -183,20 +166,10 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
|
||||
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
|
||||
|
||||
obs = jax_to_numpy(obs)
|
||||
|
||||
return obs, info
|
||||
|
||||
def step(self, action: ActType):
|
||||
"""Steps through the environment using the action."""
|
||||
if self._is_box_action_space:
|
||||
assert isinstance(self.action_space, gym.spaces.Box) # For typing
|
||||
action = np.clip(action, self.action_space.low, self.action_space.high)
|
||||
else: # Discrete
|
||||
# For now we assume jax envs don't use complex spaces
|
||||
assert self.action_space.contains(
|
||||
action
|
||||
), f"{action!r} ({type(action)}) invalid"
|
||||
self.steps += 1
|
||||
|
||||
rng, self.rng = jrng.split(self.rng)
|
||||
@@ -232,12 +205,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
self.autoreset_envs = done
|
||||
|
||||
observation = self.func_env.observation(next_state)
|
||||
observation = jax_to_numpy(observation)
|
||||
|
||||
reward = jax_to_numpy(reward)
|
||||
|
||||
terminated = jax_to_numpy(terminated)
|
||||
truncated = jax_to_numpy(truncated)
|
||||
|
||||
self.state = next_state
|
||||
|
||||
|
Reference in New Issue
Block a user