mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-15 03:08:43 +00:00
Jax environment return jax data rather than numpy data (#817)
Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f0202ae350
commit
d43037920f
@@ -1,8 +1,6 @@
|
||||
"""Finds all the specs that we can test with"""
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import logger
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
@@ -55,28 +53,3 @@ gym_testing_env_specs: List[EnvSpec] = [
|
||||
for ep in ["box2d", "classic_control", "toy_text"]
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def assert_equals(a, b, prefix=None):
|
||||
"""Assert equality of data structures `a` and `b`.
|
||||
|
||||
Args:
|
||||
a: first data structure
|
||||
b: second data structure
|
||||
prefix: prefix for failed assertion message for types and dicts
|
||||
"""
|
||||
assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
|
||||
if isinstance(a, dict):
|
||||
assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
|
||||
|
||||
for k in a.keys():
|
||||
v_a = a[k]
|
||||
v_b = b[k]
|
||||
assert_equals(v_a, v_b)
|
||||
elif isinstance(a, np.ndarray):
|
||||
np.testing.assert_array_equal(a, b)
|
||||
elif isinstance(a, tuple):
|
||||
for elem_from_a, elem_from_b in zip(a, b):
|
||||
assert_equals(elem_from_a, elem_from_b)
|
||||
else:
|
||||
assert a == b
|
||||
|
Reference in New Issue
Block a user