codegen test fixes (#95)
* fix discovered test failures * autopep8 * test indices up to 123 * testing from index 124 on * add scope to logstd * fix flakiness in test_train_mle * autopep8
This commit is contained in:
@@ -28,6 +28,8 @@ class Pd(object):
|
|||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
return self.get_shape()
|
return self.get_shape()
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.__class__(self.flatparam()[idx])
|
||||||
|
|
||||||
class PdType(object):
|
class PdType(object):
|
||||||
"""
|
"""
|
||||||
@@ -237,8 +239,6 @@ class DiagGaussianPd(Pd):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def fromflat(cls, flat):
|
def fromflat(cls, flat):
|
||||||
return cls(flat)
|
return cls(flat)
|
||||||
def __getitem__(self, idx):
|
|
||||||
return DiagGaussianPd(self.flat[idx])
|
|
||||||
|
|
||||||
|
|
||||||
class BernoulliPd(Pd):
|
class BernoulliPd(Pd):
|
||||||
@@ -246,7 +246,7 @@ class BernoulliPd(Pd):
|
|||||||
self.logits = logits
|
self.logits = logits
|
||||||
self.ps = tf.sigmoid(logits)
|
self.ps = tf.sigmoid(logits)
|
||||||
def flatparam(self):
|
def flatparam(self):
|
||||||
return self.logit
|
return self.logits
|
||||||
@property
|
@property
|
||||||
def mean(self):
|
def mean(self):
|
||||||
return self.ps
|
return self.ps
|
||||||
|
Reference in New Issue
Block a user