Compare commits

...

1 Commits

Author SHA1 Message Date
Jonathan Raiman
401a89e515 add tuple pdtype 2019-05-23 15:43:48 -07:00

View File

@@ -275,6 +275,133 @@ class BernoulliPd(Pd):
def fromflat(cls, flat):
return cls(flat)
def _np_cast(x, dtype):
"""Numpy cast, equivalent to tf.cast"""
return x.astype(dtype)
def decode_tuple_sample(pdtypes, x):
"""
Cast and convert a sample from its dense concatenated state back to constituent parts.
Arguments
---------
:param pdtypes: list<PdType>, a TuplePdType's child PdTypes.
:param x: np.ndarray or tf.Tensor.
Shape is [..., sum(pdtype.sample_shape for pdtype in pdtypes)]
:return output, list<np.ndarray> or list<tf.Tensor>, the split and correctly casted
policy samples.
"""
if isinstance(x, np.ndarray):
cast_fn = _np_cast
numpy_casting = True
else:
cast_fn = tf.cast
numpy_casting = False
so_far = 0
xs = []
for pdtype in pdtypes:
sample_size = pdtype.sample_shape()[0] if len(pdtype.sample_shape()) > 0 else 1
if len(pdtype.sample_shape()) == 0:
slided_x = x[..., so_far]
else:
slided_x = x[..., so_far:so_far + sample_size]
desired_dtype = pdtype.sample_dtype()
if numpy_casting:
desired_dtype = desired_dtype.as_numpy_dtype
if desired_dtype != x:
slided_x = cast_fn(slided_x, desired_dtype)
xs.append(slided_x)
so_far += sample_size
return xs
class TuplePd(Pd):
def __init__(self, sample_dtype, pdtypes, logits):
self.pdtypes = pdtypes
self.sample_dtype = sample_dtype
self.pds = []
so_far = 0
for pdtype in self.pdtypes:
param_shape = pdtype.param_shape()[0]
self.pds.append(pdtype.pdfromflat(logits[..., so_far:so_far + param_shape]))
so_far += param_shape
def flatparam(self):
return tf.concat([pd.flatparam() for pd in self.pds], axis=-1)
def mode(self):
return self.tuple_sample_concat([pd.mode() for pd in self.pds])
def tuple_sample_concat(self, samples):
out = []
for sample, pdtype in zip(samples, self.pdtypes):
if len(pdtype.sample_shape()) == 0:
sample = tf.expand_dims(sample, axis=-1)
if sample.dtype != self.sample_dtype:
sample = tf.cast(sample, self.sample_dtype)
out.append(sample)
return tf.concat(out, axis=-1)
def sample(self):
return self.tuple_sample_concat([pd.sample() for pd in self.pds])
def neglogp(self, x):
return tf.add_n([pd.neglogp(xi) for pd, xi in zip(self.pds, decode_tuple_sample(self.pdtypes, x))])
def entropy(self):
return tf.add_n([pd.entropy() for pd in self.pds])
def _dtype_promotion(old, new):
"""
Find the highest precision common ground between two tensorflow datatypes.
if old is None, it is ignored.
"""
if old is None or (new.is_floating and old.is_integer):
return new
if old.is_floating and old.is_integer:
return old
if (old.is_floating and new.is_floating) or (new.is_integer and new.is_integer):
# take the largest type (e.g. float64 over float32)
return old if old.size > new.size else new
raise ValueError("No idea how to promote {} and {}.".format(old, new))
class TuplePdType(PdType):
def __init__(self, space):
self.internal_pdtypes = [make_pdtype(space) for space in space.spaces]
def decode_sample(self, x):
return decode_tuple_sample(self.internal_pdtypes, x)
def pdclass(self):
return TuplePd
def pdfromflat(self, flat):
return TuplePd(self.sample_dtype(), self.internal_pdtypes, flat)
def param_shape(self):
return [sum([pdtype.param_shape()[0]
for pdtype in self.internal_pdtypes])]
def sample_shape(self):
return [sum([pdtype.sample_shape()[0] if len(pdtype.sample_shape()) > 0 else 1
for pdtype in self.internal_pdtypes])]
def sample_dtype(self):
dtype = None
for pdtype in self.internal_pdtypes:
dtype = _dtype_promotion(dtype, pdtype.sample_dtype())
return dtype
def make_pdtype(ac_space):
from gym import spaces
if isinstance(ac_space, spaces.Box):
@@ -286,9 +413,12 @@ def make_pdtype(ac_space):
return MultiCategoricalPdType(ac_space.nvec)
elif isinstance(ac_space, spaces.MultiBinary):
return BernoulliPdType(ac_space.n)
elif isinstance(ac_space, spaces.Tuple):
return TuplePdType(ac_space)
else:
raise NotImplementedError
def shape_el(v, i):
maybe = v.get_shape()[i]
if maybe is not None: