Jax environment return jax data rather than numpy data (#817)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2024-04-05 18:21:10 +02:00
committed by GitHub
parent f0202ae350
commit d43037920f
12 changed files with 48 additions and 81 deletions

View File

@@ -1,8 +1,6 @@
"""Finds all the specs that we can test with"""
from typing import List, Optional
import numpy as np
import gymnasium as gym
from gymnasium import logger
from gymnasium.envs.registration import EnvSpec
@@ -55,28 +53,3 @@ gym_testing_env_specs: List[EnvSpec] = [
for ep in ["box2d", "classic_control", "toy_text"]
)
]
def assert_equals(a, b, prefix=None):
"""Assert equality of data structures `a` and `b`.
Args:
a: first data structure
b: second data structure
prefix: prefix for failed assertion message for types and dicts
"""
assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
if isinstance(a, dict):
assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
for k in a.keys():
v_a = a[k]
v_b = b[k]
assert_equals(v_a, v_b)
elif isinstance(a, np.ndarray):
np.testing.assert_array_equal(a, b)
elif isinstance(a, tuple):
for elem_from_a, elem_from_b in zip(a, b):
assert_equals(elem_from_a, elem_from_b)
else:
assert a == b