Compare commits
master ... tuple_pdty
1 commit: 401a89e515
@@ -275,6 +275,133 @@ class BernoulliPd(Pd):
    def fromflat(cls, flat):
        return cls(flat)


def _np_cast(x, dtype):
    """Numpy cast, equivalent to tf.cast."""
    return x.astype(dtype)

def decode_tuple_sample(pdtypes, x):
    """
    Cast and convert a sample from its dense concatenated state back to its constituent parts.

    Arguments
    ---------
    :param pdtypes: list<PdType>, a TuplePdType's child PdTypes.
    :param x: np.ndarray or tf.Tensor.
        Shape is [..., sum(pdtype.sample_shape for pdtype in pdtypes)]

    :return: output, list<np.ndarray> or list<tf.Tensor>, the split and correctly cast
        policy samples.
    """
    if isinstance(x, np.ndarray):
        cast_fn = _np_cast
        numpy_casting = True
    else:
        cast_fn = tf.cast
        numpy_casting = False

    so_far = 0
    xs = []
    for pdtype in pdtypes:
        sample_size = pdtype.sample_shape()[0] if len(pdtype.sample_shape()) > 0 else 1
        if len(pdtype.sample_shape()) == 0:
            # Scalar sample: index out the single column and drop the last axis.
            slided_x = x[..., so_far]
        else:
            slided_x = x[..., so_far:so_far + sample_size]

        desired_dtype = pdtype.sample_dtype()
        if numpy_casting:
            desired_dtype = desired_dtype.as_numpy_dtype
        if desired_dtype != x.dtype:
            slided_x = cast_fn(slided_x, desired_dtype)
        xs.append(slided_x)
        so_far += sample_size
    return xs

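For context, an illustrative usage sketch (not part of the diff), assuming the CategoricalPdType (scalar integer samples) and DiagGaussianPdType (float vector samples) defined elsewhere in this file:

    # Hypothetical child pdtypes: a 4-way categorical and a 2-D diagonal Gaussian.
    pdtypes = [CategoricalPdType(4), DiagGaussianPdType(2)]
    x = np.array([[2.0, 0.5, -1.3]])        # concatenated sample, shape [1, 3]
    parts = decode_tuple_sample(pdtypes, x)
    # parts[0]: shape [1],    integer dtype -> the categorical action (2)
    # parts[1]: shape [1, 2], float dtype   -> the Gaussian action ([0.5, -1.3])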
class TuplePd(Pd):
    def __init__(self, sample_dtype, pdtypes, logits):
        self.pdtypes = pdtypes
        self.sample_dtype = sample_dtype
        self.pds = []
        so_far = 0
        for pdtype in self.pdtypes:
            param_shape = pdtype.param_shape()[0]
            self.pds.append(pdtype.pdfromflat(logits[..., so_far:so_far + param_shape]))
            so_far += param_shape

    def flatparam(self):
        return tf.concat([pd.flatparam() for pd in self.pds], axis=-1)

    def mode(self):
        return self.tuple_sample_concat([pd.mode() for pd in self.pds])

    def tuple_sample_concat(self, samples):
        out = []
        for sample, pdtype in zip(samples, self.pdtypes):
            if len(pdtype.sample_shape()) == 0:
                # Scalar samples get a trailing axis so they can be concatenated.
                sample = tf.expand_dims(sample, axis=-1)
            if sample.dtype != self.sample_dtype:
                sample = tf.cast(sample, self.sample_dtype)
            out.append(sample)
        return tf.concat(out, axis=-1)

    def sample(self):
        return self.tuple_sample_concat([pd.sample() for pd in self.pds])

    def neglogp(self, x):
        return tf.add_n([pd.neglogp(xi) for pd, xi in zip(self.pds, decode_tuple_sample(self.pdtypes, x))])

    def entropy(self):
        return tf.add_n([pd.entropy() for pd in self.pds])

def _dtype_promotion(old, new):
    """
    Find the highest-precision common ground between two tensorflow datatypes.
    If old is None, it is ignored.
    """
    if old is None or (new.is_floating and old.is_integer):
        return new
    if old.is_floating and new.is_integer:
        return old
    if (old.is_floating and new.is_floating) or (old.is_integer and new.is_integer):
        # take the larger type (e.g. float64 over float32)
        return old if old.size > new.size else new
    raise ValueError("No idea how to promote {} and {}.".format(old, new))

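An illustrative sketch (not part of the diff) of what the promotion rules yield for tensorflow dtypes:

    _dtype_promotion(None, tf.int32)          # -> tf.int32   (a None "old" is ignored)
    _dtype_promotion(tf.int32, tf.float32)    # -> tf.float32 (float wins over int)
    _dtype_promotion(tf.float32, tf.float64)  # -> tf.float64 (larger of the two floats)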
class TuplePdType(PdType):
    def __init__(self, space):
        self.internal_pdtypes = [make_pdtype(space) for space in space.spaces]

    def decode_sample(self, x):
        return decode_tuple_sample(self.internal_pdtypes, x)

    def pdclass(self):
        return TuplePd

    def pdfromflat(self, flat):
        return TuplePd(self.sample_dtype(), self.internal_pdtypes, flat)

    def param_shape(self):
        return [sum([pdtype.param_shape()[0]
                     for pdtype in self.internal_pdtypes])]

    def sample_shape(self):
        return [sum([pdtype.sample_shape()[0] if len(pdtype.sample_shape()) > 0 else 1
                     for pdtype in self.internal_pdtypes])]

    def sample_dtype(self):
        # Child samples are concatenated into one tensor, so pick a dtype they can all share.
        dtype = None
        for pdtype in self.internal_pdtypes:
            dtype = _dtype_promotion(dtype, pdtype.sample_dtype())
        return dtype

def make_pdtype(ac_space):
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
@@ -286,9 +413,12 @@ def make_pdtype(ac_space):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    elif isinstance(ac_space, spaces.Tuple):
        return TuplePdType(ac_space)
    else:
        raise NotImplementedError

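With the spaces.Tuple branch above, an illustrative end-to-end sketch (not part of the diff), assuming the standard Discrete/Box pdtypes (CategoricalPdType / DiagGaussianPdType) defined elsewhere in this file:

    space = spaces.Tuple((spaces.Discrete(4),
                          spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)))
    pdtype = make_pdtype(space)   # -> TuplePdType
    pdtype.param_shape()          # [8]: 4 categorical logits + 2 means + 2 log-stds
    pdtype.sample_shape()         # [3]: 1 (scalar Discrete) + 2 (Box vector)
    pdtype.sample_dtype()         # promoted to a float dtype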
def shape_el(v, i):
    maybe = v.get_shape()[i]
    if maybe is not None: