diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index ab94ca0..8366eb5 100644 --- a/baselines/common/distributions.py +++ b/baselines/common/distributions.py @@ -28,6 +28,8 @@ class Pd(object): @property def shape(self): return self.get_shape() + def __getitem__(self, idx): + return self.__class__(self.flatparam()[idx]) class PdType(object): """ @@ -237,8 +239,6 @@ class DiagGaussianPd(Pd): @classmethod def fromflat(cls, flat): return cls(flat) - def __getitem__(self, idx): - return DiagGaussianPd(self.flat[idx]) class BernoulliPd(Pd): @@ -246,7 +246,7 @@ class BernoulliPd(Pd): self.logits = logits self.ps = tf.sigmoid(logits) def flatparam(self): - return self.logit + return self.logits @property def mean(self): return self.ps