Apparently Python3 disallows casting numpy scalars into ints. Work around that.

This commit is contained in:
Jonas Schneider
2016-04-27 18:31:32 -07:00
parent 5065950a09
commit 1fd5af8b4e

View File

@@ -10,7 +10,13 @@ class Discrete(Space):
def sample(self):
return np.random.randint(self.n)
def contains(self, x):
return isinstance(x, int) and x >= 0 and x < self.n
if isinstance(x, int):
as_int = x
elif x.dtype.kind in np.typecodes['AllInteger'] and x.shape == ():
as_int = int(x)
else:
return False
return as_int >= 0 and as_int < self.n
def __repr__(self):
return "Discrete(%d)" % self.n
def __eq__(self, other):