From 401a89e515c995038bd4e5879c3f21731d4d4ab0 Mon Sep 17 00:00:00 2001
From: Jonathan Raiman
Date: Thu, 23 May 2019 15:43:48 -0700
Subject: [PATCH] add tuple pdtype

---
 baselines/common/distributions.py | 130 ++++++++++++++++++++++++++++++
 1 file changed, 130 insertions(+)

diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py
index 0b5fc76..5b5ba80 100644
--- a/baselines/common/distributions.py
+++ b/baselines/common/distributions.py
@@ -275,6 +275,133 @@ class BernoulliPd(Pd):
     def fromflat(cls, flat):
         return cls(flat)
 
+
+
+def _np_cast(x, dtype):
+    """Numpy cast, equivalent to tf.cast"""
+    return x.astype(dtype)
+
+
+def decode_tuple_sample(pdtypes, x):
+    """
+    Cast and convert a sample from its dense concatenated state back to constituent parts.
+
+    Arguments
+    ---------
+
+    :param pdtypes: list, a TuplePdType's child PdTypes.
+    :param x: np.ndarray or tf.Tensor.
+        Shape is [..., sum(pdtype.sample_shape for pdtype in pdtypes)]
+
+    :return output, list of np.ndarray or tf.Tensor, the split and correctly cast
+        policy samples.
+    """
+    if isinstance(x, np.ndarray):
+        cast_fn = _np_cast
+        numpy_casting = True
+    else:
+        cast_fn = tf.cast
+        numpy_casting = False
+
+    so_far = 0
+    xs = []
+    for pdtype in pdtypes:
+        sample_size = pdtype.sample_shape()[0] if len(pdtype.sample_shape()) > 0 else 1
+        if len(pdtype.sample_shape()) == 0:
+            sliced_x = x[..., so_far]
+        else:
+            sliced_x = x[..., so_far:so_far + sample_size]
+
+        desired_dtype = pdtype.sample_dtype()
+        if numpy_casting:
+            desired_dtype = desired_dtype.as_numpy_dtype
+        if sliced_x.dtype != desired_dtype:
+            sliced_x = cast_fn(sliced_x, desired_dtype)
+        xs.append(sliced_x)
+        so_far += sample_size
+    return xs
+
+
+class TuplePd(Pd):
+    def __init__(self, sample_dtype, pdtypes, logits):
+        self.pdtypes = pdtypes
+        self.sample_dtype = sample_dtype
+        self.pds = []
+        so_far = 0
+        for pdtype in self.pdtypes:
+            param_shape = pdtype.param_shape()[0]
+            self.pds.append(pdtype.pdfromflat(logits[..., so_far:so_far + param_shape]))
+            so_far += param_shape
+
+    def flatparam(self):
+        return tf.concat([pd.flatparam() for pd in self.pds], axis=-1)
+
+    def mode(self):
+        return self.tuple_sample_concat([pd.mode() for pd in self.pds])
+
+    def tuple_sample_concat(self, samples):
+        out = []
+        for sample, pdtype in zip(samples, self.pdtypes):
+            if len(pdtype.sample_shape()) == 0:
+                sample = tf.expand_dims(sample, axis=-1)
+            if sample.dtype != self.sample_dtype:
+                sample = tf.cast(sample, self.sample_dtype)
+            out.append(sample)
+        return tf.concat(out, axis=-1)
+
+    def sample(self):
+        return self.tuple_sample_concat([pd.sample() for pd in self.pds])
+
+    def neglogp(self, x):
+        return tf.add_n([pd.neglogp(xi) for pd, xi in zip(self.pds, decode_tuple_sample(self.pdtypes, x))])
+
+    def entropy(self):
+        return tf.add_n([pd.entropy() for pd in self.pds])
+
+
+def _dtype_promotion(old, new):
+    """
+    Return the highest-precision TensorFlow dtype able to hold both old and new.
+    If old is None it is ignored and new is returned.
+    """
+    if old is None or (new.is_floating and old.is_integer):
+        return new
+    if old.is_floating and new.is_integer:
+        return old
+    if (old.is_floating and new.is_floating) or (old.is_integer and new.is_integer):
+        # take the largest type (e.g. float64 over float32)
+        return old if old.size > new.size else new
+    raise ValueError("No idea how to promote {} and {}.".format(old, new))
+
+
+class TuplePdType(PdType):
+    def __init__(self, space):
+        self.internal_pdtypes = [make_pdtype(space) for space in space.spaces]
+
+    def decode_sample(self, x):
+        return decode_tuple_sample(self.internal_pdtypes, x)
+
+    def pdclass(self):
+        return TuplePd
+
+    def pdfromflat(self, flat):
+        return TuplePd(self.sample_dtype(), self.internal_pdtypes, flat)
+
+    def param_shape(self):
+        return [sum([pdtype.param_shape()[0]
+                     for pdtype in self.internal_pdtypes])]
+
+    def sample_shape(self):
+        return [sum([pdtype.sample_shape()[0] if len(pdtype.sample_shape()) > 0 else 1
+                     for pdtype in self.internal_pdtypes])]
+
+    def sample_dtype(self):
+        dtype = None
+        for pdtype in self.internal_pdtypes:
+            dtype = _dtype_promotion(dtype, pdtype.sample_dtype())
+        return dtype
+
+
 def make_pdtype(ac_space):
     from gym import spaces
     if isinstance(ac_space, spaces.Box):
@@ -286,9 +413,12 @@ def make_pdtype(ac_space):
         return MultiCategoricalPdType(ac_space.nvec)
     elif isinstance(ac_space, spaces.MultiBinary):
         return BernoulliPdType(ac_space.n)
+    elif isinstance(ac_space, spaces.Tuple):
+        return TuplePdType(ac_space)
     else:
         raise NotImplementedError
+
 
 def shape_el(v, i):
     maybe = v.get_shape()[i]
     if maybe is not None: