codegen test fixes (#95)
* fix discovered test failures * autopep8 * test indices up to 123 * testing from index 124 on * add scope to logstd * fix flakiness in test_train_mle * autopep8
This commit is contained in:
@@ -28,6 +28,8 @@ class Pd(object):
|
||||
@property
|
||||
def shape(self):
|
||||
return self.get_shape()
|
||||
def __getitem__(self, idx):
|
||||
return self.__class__(self.flatparam()[idx])
|
||||
|
||||
class PdType(object):
|
||||
"""
|
||||
@@ -237,8 +239,6 @@ class DiagGaussianPd(Pd):
|
||||
@classmethod
|
||||
def fromflat(cls, flat):
|
||||
return cls(flat)
|
||||
def __getitem__(self, idx):
|
||||
return DiagGaussianPd(self.flat[idx])
|
||||
|
||||
|
||||
class BernoulliPd(Pd):
|
||||
@@ -246,7 +246,7 @@ class BernoulliPd(Pd):
|
||||
self.logits = logits
|
||||
self.ps = tf.sigmoid(logits)
|
||||
def flatparam(self):
|
||||
return self.logit
|
||||
return self.logits
|
||||
@property
|
||||
def mean(self):
|
||||
return self.ps
|
||||
|
Reference in New Issue
Block a user