mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
Flattened space/point dtype mismatch (#2070)
* add test showing mismatch in flattened space dtype and flattened point dtype * fix mismatch in flattened space dtype and flattened point dtype * fix typo * enhance test to detect when flattened dtype is incorrect * fix incorrect flattened dtype * remove inaccurate comment * change flatten to always use space.dtype * added testing for unflattened dtypes * fix unflatten dtypes * switch flatten_space to use space.dtype for hardcoded space dtypes * fix failure in python 3.5
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gym.spaces import utils
|
||||
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
|
||||
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["space", "flatdim"], [
|
||||
@@ -118,3 +118,55 @@ def compare_nested(left, right):
|
||||
return res
|
||||
else:
|
||||
return left == right
|
||||
|
||||
'''
|
||||
Expected flattened types are based on:
|
||||
1. The type that the space is hardcoded as (i.e. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
|
||||
2. The type that the space is instantiated with (e.g. box=np.float32 by default unless instantiated with a different type)
|
||||
3. The smallest type that the composite space (tuple, dict) can be represented as. In flatten, this is determined
|
||||
internally by numpy when np.concatenate is called.
|
||||
'''
|
||||
@pytest.mark.parametrize(
    "original_space, expected_flattened_dtype",
    [
        (Discrete(3), np.int64),
        (Box(low=0., high=np.inf, shape=(2, 2)), np.float32),
        (Box(low=0., high=np.inf, shape=(2, 2), dtype=np.float16), np.float16),
        (Tuple([Discrete(5), Discrete(10)]), np.int64),
        (
            Tuple([
                Discrete(5),
                Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
            ]),
            np.float64,
        ),
        (Tuple((Discrete(5), Discrete(2), Discrete(2))), np.int64),
        (MultiDiscrete([2, 2, 100]), np.int64),
        (MultiBinary(10), np.int8),
        (
            Dict({
                "position": Discrete(5),
                "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16),
            }),
            np.float64,
        ),
    ],
)
def test_dtypes(original_space, expected_flattened_dtype):
    """Check dtype consistency across flatten_space, flatten and unflatten.

    For each parametrized space this verifies, in order, that:

    * the flattened space contains the flattened sample,
    * the flattened space has the expected dtype,
    * the flattened sample has the same dtype as the flattened space,
    * the unflattened sample has the same types as the original sample
      (delegated to ``compare_sample_types``).
    """
    # Build the flattened space first, then push one sample through a
    # flatten -> unflatten round trip.
    flat_space = utils.flatten_space(original_space)

    sample = original_space.sample()
    flat_sample = utils.flatten(original_space, sample)
    round_tripped = utils.unflatten(original_space, flat_sample)

    assert flat_space.contains(flat_sample), \
        "Expected flattened_space to contain flattened_sample"

    assert flat_space.dtype == expected_flattened_dtype, \
        "Expected flattened_space's dtype to equal {}".format(expected_flattened_dtype)

    assert flat_sample.dtype == flat_space.dtype, \
        "Expected flattened_space's dtype to equal flattened_sample's dtype "

    compare_sample_types(original_space, sample, round_tripped)
|
||||
|
||||
|
||||
def compare_sample_types(original_space, original_sample, unflattened_sample):
    """Recursively assert that an unflattened sample matches the original's types.

    Args:
        original_space: The space the samples belong to. ``Tuple`` and ``Dict``
            spaces are walked recursively through their ``spaces`` attribute.
        original_sample: A sample of ``original_space``.
        unflattened_sample: The value obtained after flattening and then
            unflattening ``original_sample``.

    Raises:
        AssertionError: If a ``Discrete`` component is not an ``int``, if a
            ``Tuple`` sample does not have one entry per sub-space, or if the
            dtype of any other component differs from the original's dtype.
    """
    if isinstance(original_space, Discrete):
        assert isinstance(unflattened_sample, int), (
            "Expected unflattened_sample to be an int. unflattened_sample: "
            "{} original_sample: {}".format(unflattened_sample, original_sample)
        )
    elif isinstance(original_space, Tuple):
        subspaces = original_space.spaces
        # zip() stops at the shortest input, so check the lengths explicitly;
        # otherwise a truncated unflattened tuple would pass unnoticed.
        assert len(original_sample) == len(unflattened_sample) == len(subspaces), (
            "Expected one entry per sub-space. unflattened_sample: "
            "{} original_sample: {}".format(unflattened_sample, original_sample)
        )
        for subspace, original_item, unflattened_item in zip(
                subspaces, original_sample, unflattened_sample):
            compare_sample_types(subspace, original_item, unflattened_item)
    elif isinstance(original_space, Dict):
        for key, subspace in original_space.spaces.items():
            compare_sample_types(subspace, original_sample[key], unflattened_sample[key])
    else:
        # Every remaining space type is compared by the dtype of its sample.
        assert unflattened_sample.dtype == original_sample.dtype, (
            "Expected unflattened_sample's dtype to equal "
            "original_sample's dtype. unflattened_sample: "
            "{} original_sample: {}".format(unflattened_sample, original_sample)
        )
|
||||
|
@@ -43,10 +43,10 @@ def flatten(space, x):
|
||||
``gym.spaces``.
|
||||
"""
|
||||
if isinstance(space, Box):
|
||||
return np.asarray(x, dtype=np.float32).flatten()
|
||||
return np.asarray(x, dtype=space.dtype).flatten()
|
||||
elif isinstance(space, Discrete):
|
||||
onehot = np.zeros(space.n, dtype=np.float32)
|
||||
onehot[x] = 1.0
|
||||
onehot = np.zeros(space.n, dtype=space.dtype)
|
||||
onehot[x] = 1
|
||||
return onehot
|
||||
elif isinstance(space, Tuple):
|
||||
return np.concatenate(
|
||||
@@ -55,9 +55,9 @@ def flatten(space, x):
|
||||
return np.concatenate(
|
||||
[flatten(s, x[key]) for key, s in space.spaces.items()])
|
||||
elif isinstance(space, MultiBinary):
|
||||
return np.asarray(x).flatten()
|
||||
return np.asarray(x, dtype=space.dtype).flatten()
|
||||
elif isinstance(space, MultiDiscrete):
|
||||
return np.asarray(x).flatten()
|
||||
return np.asarray(x, dtype=space.dtype).flatten()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -73,7 +73,7 @@ def unflatten(space, x):
|
||||
defined in ``gym.spaces``.
|
||||
"""
|
||||
if isinstance(space, Box):
|
||||
return np.asarray(x, dtype=np.float32).reshape(space.shape)
|
||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||
elif isinstance(space, Discrete):
|
||||
return int(np.nonzero(x)[0][0])
|
||||
elif isinstance(space, Tuple):
|
||||
@@ -94,9 +94,9 @@ def unflatten(space, x):
|
||||
]
|
||||
return OrderedDict(list_unflattened)
|
||||
elif isinstance(space, MultiBinary):
|
||||
return np.asarray(x).reshape(space.shape)
|
||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||
elif isinstance(space, MultiDiscrete):
|
||||
return np.asarray(x).reshape(space.shape)
|
||||
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -140,26 +140,33 @@ def flatten_space(space):
|
||||
True
|
||||
"""
|
||||
if isinstance(space, Box):
|
||||
return Box(space.low.flatten(), space.high.flatten())
|
||||
return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
|
||||
if isinstance(space, Discrete):
|
||||
return Box(low=0, high=1, shape=(space.n, ))
|
||||
return Box(low=0, high=1, shape=(space.n, ), dtype=space.dtype)
|
||||
if isinstance(space, Tuple):
|
||||
space = [flatten_space(s) for s in space.spaces]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space]),
|
||||
high=np.concatenate([s.high for s in space]),
|
||||
dtype=np.result_type(*[s.dtype for s in space])
|
||||
)
|
||||
if isinstance(space, Dict):
|
||||
space = [flatten_space(s) for s in space.spaces.values()]
|
||||
return Box(
|
||||
low=np.concatenate([s.low for s in space]),
|
||||
high=np.concatenate([s.high for s in space]),
|
||||
dtype=np.result_type(*[s.dtype for s in space])
|
||||
)
|
||||
if isinstance(space, MultiBinary):
|
||||
return Box(low=0, high=1, shape=(space.n, ))
|
||||
return Box(low=0,
|
||||
high=1,
|
||||
shape=(space.n, ),
|
||||
dtype=space.dtype
|
||||
)
|
||||
if isinstance(space, MultiDiscrete):
|
||||
return Box(
|
||||
low=np.zeros_like(space.nvec),
|
||||
high=space.nvec,
|
||||
dtype=space.dtype
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
Reference in New Issue
Block a user