Flattened space/point dtype mismatch (#2070)

* add test showing mismatch in flattened space dtype and flattened point dtype

* fix mismatch in flattened space dtype and flattened point dtype

* fix typo

* enhance test to detect when flattened dtype is incorrect

* fix incorrect flattened dtype

* remove inaccurate comment

* change flatten to always use space.dtype

* added testing for unflattened dtypes

* fix unflatten dtypes

* switch flatten_space to use space.dtype for hardcoded space dtypes

* fix failure in python 3.5
This commit is contained in:
Melvin Wang
2020-11-06 15:06:29 -05:00
committed by GitHub
parent 28c42b63c8
commit eee9b28882
2 changed files with 72 additions and 13 deletions

View File

@@ -1,9 +1,9 @@
from collections import OrderedDict
import numpy as np
import pytest
from gym.spaces import utils
from gym.spaces import Tuple, Box, Discrete, MultiDiscrete, MultiBinary, Dict
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple, utils
@pytest.mark.parametrize(["space", "flatdim"], [
@@ -118,3 +118,55 @@ def compare_nested(left, right):
return res
else:
return left == right
'''
Expected flattened types are based on:
1. The type that the space is hardcoded as (i.e. multi_discrete=np.int64, discrete=np.int64, multi_binary=np.int8)
2. The type that the space is instantiated with (i.e. box=np.float32 by default unless instantiated with a different type)
3. The smallest type that the composite space (tuple, dict) can be represented as. In flatten, this is determined
internally by numpy when np.concatenate is called.
'''
@pytest.mark.parametrize(["original_space", "expected_flattened_dtype"], [
    (Discrete(3), np.int64),
    (Box(low=0., high=np.inf, shape=(2, 2)), np.float32),
    (Box(low=0., high=np.inf, shape=(2, 2), dtype=np.float16), np.float16),
    (Tuple([Discrete(5), Discrete(10)]), np.int64),
    (
        Tuple([
            Discrete(5),
            Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float32),
        ]),
        np.float64,
    ),
    (Tuple((Discrete(5), Discrete(2), Discrete(2))), np.int64),
    (MultiDiscrete([2, 2, 100]), np.int64),
    (MultiBinary(10), np.int8),
    (
        Dict({
            "position": Discrete(5),
            "velocity": Box(low=np.array([0, 0]), high=np.array([1, 5]), dtype=np.float16),
        }),
        np.float64,
    ),
])
def test_dtypes(original_space, expected_flattened_dtype):
    """Check dtype agreement between a flattened space and a flattened sample.

    For each parametrized space this draws one sample, flattens it, and
    unflattens it again, then asserts that:

    * the flattened space contains the flattened sample,
    * the flattened space has the expected dtype,
    * the flattened sample has the same dtype as the flattened space,
    * the unflattened sample has the same types as the original sample
      (delegated to ``compare_sample_types``).
    """
    flat_space = utils.flatten_space(original_space)

    sample = original_space.sample()
    flat_sample = utils.flatten(original_space, sample)
    roundtrip_sample = utils.unflatten(original_space, flat_sample)

    assert flat_space.contains(flat_sample), (
        "Expected flattened_space to contain flattened_sample"
    )
    assert flat_space.dtype == expected_flattened_dtype, (
        "Expected flattened_space's dtype to equal {}".format(expected_flattened_dtype)
    )
    # The trailing space in this message is kept as in the original text.
    assert flat_sample.dtype == flat_space.dtype, (
        "Expected flattened_space's dtype to equal flattened_sample's dtype "
    )

    compare_sample_types(original_space, sample, roundtrip_sample)
def compare_sample_types(original_space, original_sample, unflattened_sample):
    """Recursively assert that an unflattened sample kept the original's types.

    * ``Discrete``: the unflattened sample must be a Python ``int``.
    * ``Tuple`` / ``Dict``: each sub-space is checked recursively against the
      matching element of both samples.
    * Any other space: the two samples must have equal ``dtype``.
    """
    if isinstance(original_space, Discrete):
        assert isinstance(unflattened_sample, int), (
            "Expected unflattened_sample to be an int. unflattened_sample: "
            "{} original_sample: {}"
        ).format(unflattened_sample, original_sample)
        return

    if isinstance(original_space, Tuple):
        # Walk sub-spaces by position; both samples are indexed the same way.
        for position, subspace in enumerate(original_space.spaces):
            compare_sample_types(
                subspace,
                original_sample[position],
                unflattened_sample[position],
            )
        return

    if isinstance(original_space, Dict):
        # Walk sub-spaces by key; both samples are keyed the same way.
        for name, subspace in original_space.spaces.items():
            compare_sample_types(
                subspace,
                original_sample[name],
                unflattened_sample[name],
            )
        return

    # Remaining spaces yield array samples, so compare their dtypes directly.
    assert unflattened_sample.dtype == original_sample.dtype, (
        "Expected unflattened_sample's dtype to equal "
        "original_sample's dtype. unflattened_sample: "
        "{} original_sample: {}"
    ).format(unflattened_sample, original_sample)

View File

@@ -43,10 +43,10 @@ def flatten(space, x):
``gym.spaces``.
"""
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).flatten()
return np.asarray(x, dtype=space.dtype).flatten()
elif isinstance(space, Discrete):
onehot = np.zeros(space.n, dtype=np.float32)
onehot[x] = 1.0
onehot = np.zeros(space.n, dtype=space.dtype)
onehot[x] = 1
return onehot
elif isinstance(space, Tuple):
return np.concatenate(
@@ -55,9 +55,9 @@ def flatten(space, x):
return np.concatenate(
[flatten(s, x[key]) for key, s in space.spaces.items()])
elif isinstance(space, MultiBinary):
return np.asarray(x).flatten()
return np.asarray(x, dtype=space.dtype).flatten()
elif isinstance(space, MultiDiscrete):
return np.asarray(x).flatten()
return np.asarray(x, dtype=space.dtype).flatten()
else:
raise NotImplementedError
@@ -73,7 +73,7 @@ def unflatten(space, x):
defined in ``gym.spaces``.
"""
if isinstance(space, Box):
return np.asarray(x, dtype=np.float32).reshape(space.shape)
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
elif isinstance(space, Discrete):
return int(np.nonzero(x)[0][0])
elif isinstance(space, Tuple):
@@ -94,9 +94,9 @@ def unflatten(space, x):
]
return OrderedDict(list_unflattened)
elif isinstance(space, MultiBinary):
return np.asarray(x).reshape(space.shape)
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
elif isinstance(space, MultiDiscrete):
return np.asarray(x).reshape(space.shape)
return np.asarray(x, dtype=space.dtype).reshape(space.shape)
else:
raise NotImplementedError
@@ -140,26 +140,33 @@ def flatten_space(space):
True
"""
if isinstance(space, Box):
return Box(space.low.flatten(), space.high.flatten())
return Box(space.low.flatten(), space.high.flatten(), dtype=space.dtype)
if isinstance(space, Discrete):
return Box(low=0, high=1, shape=(space.n, ))
return Box(low=0, high=1, shape=(space.n, ), dtype=space.dtype)
if isinstance(space, Tuple):
space = [flatten_space(s) for s in space.spaces]
return Box(
low=np.concatenate([s.low for s in space]),
high=np.concatenate([s.high for s in space]),
dtype=np.result_type(*[s.dtype for s in space])
)
if isinstance(space, Dict):
space = [flatten_space(s) for s in space.spaces.values()]
return Box(
low=np.concatenate([s.low for s in space]),
high=np.concatenate([s.high for s in space]),
dtype=np.result_type(*[s.dtype for s in space])
)
if isinstance(space, MultiBinary):
return Box(low=0, high=1, shape=(space.n, ))
return Box(low=0,
high=1,
shape=(space.n, ),
dtype=space.dtype
)
if isinstance(space, MultiDiscrete):
return Box(
low=np.zeros_like(space.nvec),
high=space.nvec,
dtype=space.dtype
)
raise NotImplementedError