Jax environment return jax data rather than numpy data (#817)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2024-04-05 18:21:10 +02:00
committed by GitHub
parent f0202ae350
commit d43037920f
12 changed files with 48 additions and 81 deletions

View File

@@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return
assert isinstance(env.action_space, spaces.Discrete)
upper_bound = env.action_space.start + env.action_space.n - 1
@@ -102,6 +105,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return
env.reset(seed=42)
assert env.spec is not None