diff --git a/gym/spaces/tests/test_utils.py b/gym/spaces/tests/test_utils.py index 69e60ee70..130f3af04 100644 --- a/gym/spaces/tests/test_utils.py +++ b/gym/spaces/tests/test_utils.py @@ -1,9 +1,9 @@ from collections import OrderedDict + import numpy as np import pytest -from gym.spaces import utils -from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict +from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, utils @pytest.mark.parametrize(["space", "flatdim"], [ @@ -118,3 +118,55 @@ def compare_nested(left, right): return res else: return left == right + +''' +Expecteded flattened types are based off: +1. The type that the space is hardcoded as(ie. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8) +2. The type that the space is instantiated with(ie. box=np.float32 by default unless instantiated with a different type) +3. The smallest type that the composite space(tuple, dict) can be represented as. In flatten, this is determined + internally by numpy when np.concatenate is called. +''' +@pytest.mark.parametrize(["original_space", "expected_flattened_dtype"], [ + (Discrete(3), np.int64), + (Box(low=0., high=np.inf, shape=(2, 2)), np.float32), + (Box(low=0., high=np.inf, shape=(2, 2), dtype=np.float16), np.float16), + (Tuple([Discrete(5), Discrete(10)]), np.int64), + (Tuple([Discrete(5), Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32)]), np.float64), + (Tuple((Discrete(5), Discrete(2), Discrete(2))), np.int64), + (MultiDiscrete([2, 2, 100]), np.int64), + (MultiBinary(10), np.int8), + (Dict({"position": Discrete(5), + "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16)}), np.float64), +]) +def test_dtypes(original_space, expected_flattened_dtype): + flattened_space = utils.flatten_space(original_space) + + original_sample = original_space.sample() + flattened_sample = utils.flatten(original_space, original_sample) + unflattened_sample = utils.unflatten(original_space, flattened_sample) + + assert flattened_space.contains(flattened_sample), "Expected flattened_space to contain flattened_sample" + assert flattened_space.dtype == expected_flattened_dtype, "Expected flattened_space's dtype to equal " \ + "{}".format(expected_flattened_dtype) + + assert flattened_sample.dtype == flattened_space.dtype, "Expected flattened_space's dtype to equal " \ + "flattened_sample's dtype " + + compare_sample_types(original_space, original_sample, unflattened_sample) + + +def compare_sample_types(original_space, original_sample, unflattened_sample): + if isinstance(original_space, Discrete): + assert isinstance(unflattened_sample, int), "Expected unflattened_sample to be an int. unflattened_sample: " \ + "{} original_sample: {}".format(unflattened_sample, original_sample) + elif isinstance(original_space, Tuple): + for index in range(len(original_space)): + compare_sample_types(original_space.spaces[index], original_sample[index], unflattened_sample[index]) + elif isinstance(original_space, Dict): + for key, space in original_space.spaces.items(): + compare_sample_types(space, original_sample[key], unflattened_sample[key]) + else: + assert unflattened_sample.dtype == original_sample.dtype, "Expected unflattened_sample's dtype to equal " \ + "original_sample's dtype. unflattened_sample: " \ + "{} original_sample: {}".format(unflattened_sample, + original_sample) diff --git a/gym/spaces/utils.py b/gym/spaces/utils.py index a63840c33..b61aa2e10 100644 --- a/gym/spaces/utils.py +++ b/gym/spaces/utils.py @@ -43,10 +43,10 @@ def flatten(space, x): ``gym.spaces``. """ if isinstance(space, Box): - return np.asarray(x, dtype=np.float32).flatten() + return np.asarray(x, dtype=space.dtype).flatten() elif isinstance(space, Discrete): - onehot = np.zeros(space.n, dtype=np.float32) - onehot[x] = 1.0 + onehot = np.zeros(space.n, dtype=space.dtype) + onehot[x] = 1 return onehot elif isinstance(space, Tuple): return np.concatenate( @@ -55,9 +55,9 @@ def flatten(space, x): return np.concatenate( [flatten(s, x[key]) for key, s in space.spaces.items()]) elif isinstance(space, MultiBinary): - return np.asarray(x).flatten() + return np.asarray(x, dtype=space.dtype).flatten() elif isinstance(space, MultiDiscrete): - return np.asarray(x).flatten() + return np.asarray(x, dtype=space.dtype).flatten() else: raise NotImplementedError @@ -73,7 +73,7 @@ def unflatten(space, x): defined in ``gym.spaces``. """ if isinstance(space, Box): - return np.asarray(x, dtype=np.float32).reshape(space.shape) + return np.asarray(x, dtype=space.dtype).reshape(space.shape) elif isinstance(space, Discrete): return int(np.nonzero(x)[0][0]) elif isinstance(space, Tuple): @@ -94,9 +94,9 @@ def unflatten(space, x): ] return OrderedDict(list_unflattened) elif isinstance(space, MultiBinary): - return np.asarray(x).reshape(space.shape) + return np.asarray(x, dtype=space.dtype).reshape(space.shape) elif isinstance(space, MultiDiscrete): - return np.asarray(x).reshape(space.shape) + return np.asarray(x, dtype=space.dtype).reshape(space.shape) else: raise NotImplementedError @@ -140,26 +140,33 @@ def flatten_space(space): True """ if isinstance(space, Box): - return Box(space.low.flatten(), space.high.flatten()) + return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype) if isinstance(space, Discrete): - return Box(low=0, high=1, shape=(space.n, )) + return Box(low=0, high=1, shape=(space.n, ), dtype=space.dtype) if isinstance(space, Tuple): space = [flatten_space(s) for s in space.spaces] return Box( low=np.concatenate([s.low for s in space]), high=np.concatenate([s.high for s in space]), + dtype=np.result_type(*[s.dtype for s in space]) ) if isinstance(space, Dict): space = [flatten_space(s) for s in space.spaces.values()] return Box( low=np.concatenate([s.low for s in space]), high=np.concatenate([s.high for s in space]), + dtype=np.result_type(*[s.dtype for s in space]) ) if isinstance(space, MultiBinary): - return Box(low=0, high=1, shape=(space.n, )) + return Box(low=0, + high=1, + shape=(space.n, ), + dtype=space.dtype + ) if isinstance(space, MultiDiscrete): return Box( low=np.zeros_like(space.nvec), high=space.nvec, + dtype=space.dtype ) raise NotImplementedError