rm unnecessary __contains__ duplicate code (#1147)

`contains` really should not exist when it does exactly what the builtin
magic method `__contains__` was meant for, but that would break backward
compatibility.
This commit is contained in:
Alok Singh
2018-08-28 10:51:28 -07:00
committed by pzhokhov
parent 750063055f
commit 6332d4f113
7 changed files with 3 additions and 13 deletions

View File

@@ -215,6 +215,8 @@ class Space(object):
""" """
raise NotImplementedError raise NotImplementedError
__contains__ = contains
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
"""Convert a batch of samples from this space to a JSONable data type.""" """Convert a batch of samples from this space to a JSONable data type."""
# By default, assume identity is JSONable # By default, assume identity is JSONable

View File

@@ -38,8 +38,6 @@ class Box(gym.Space):
def contains(self, x): def contains(self, x):
return x.shape == self.shape and (x >= self.low).all() and (x <= self.high).all() return x.shape == self.shape and (x >= self.low).all() and (x <= self.high).all()
__contains__ = contains
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
return np.array(sample_n).tolist() return np.array(sample_n).tolist()
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n):

View File

@@ -51,8 +51,6 @@ class Dict(gym.Space):
return False return False
return True return True
__contains__ = contains
def __repr__(self): def __repr__(self):
return "Dict(" + ", ". join([k + ":" + str(s) for k, s in self.spaces.items()]) + ")" return "Dict(" + ", ". join([k + ":" + str(s) for k, s in self.spaces.items()]) + ")"

View File

@@ -22,8 +22,6 @@ class Discrete(gym.Space):
return False return False
return as_int >= 0 and as_int < self.n return as_int >= 0 and as_int < self.n
__contains__ = contains
def __repr__(self): def __repr__(self):
return "Discrete(%d)" % self.n return "Discrete(%d)" % self.n
def __eq__(self, other): def __eq__(self, other):

View File

@@ -10,8 +10,6 @@ class MultiBinary(gym.Space):
def contains(self, x): def contains(self, x):
return ((x==0) | (x==1)).all() return ((x==0) | (x==1)).all()
__contains__ = contains
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
return np.array(sample_n).tolist() return np.array(sample_n).tolist()
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n):

View File

@@ -12,9 +12,7 @@ class MultiDiscrete(gym.Space):
return (gym.spaces.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype) return (gym.spaces.np_random.random_sample(self.nvec.shape) * self.nvec).astype(self.dtype)
def contains(self, x): def contains(self, x):
return (0 <= x).all() and (x < self.nvec).all() and x.dtype.kind in 'ui' return (0 <= x).all() and (x < self.nvec).all() and x.dtype.kind in 'ui'
__contains__ = contains
def to_jsonable(self, sample_n): def to_jsonable(self, sample_n):
return [sample.tolist() for sample in sample_n] return [sample.tolist() for sample in sample_n]
def from_jsonable(self, sample_n): def from_jsonable(self, sample_n):

View File

@@ -20,8 +20,6 @@ class Tuple(gym.Space):
return isinstance(x, tuple) and len(x) == len(self.spaces) and all( return isinstance(x, tuple) and len(x) == len(self.spaces) and all(
space.contains(part) for (space,part) in zip(self.spaces,x)) space.contains(part) for (space,part) in zip(self.spaces,x))
__contains__ = contains
def __repr__(self): def __repr__(self):
return "Tuple(" + ", ". join([str(s) for s in self.spaces]) + ")" return "Tuple(" + ", ". join([str(s) for s in self.spaces]) + ")"