define mean for CategoricalPd (as softmax of logits)
This commit is contained in:
@@ -146,6 +146,10 @@ class CategoricalPd(Pd):
|
||||
return self.logits
|
||||
def mode(self):
|
||||
return tf.argmax(self.logits, axis=-1)
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return tf.nn.softmax(self.logits)
|
||||
def neglogp(self, x):
|
||||
# return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
|
||||
# Note: we can't use sparse_softmax_cross_entropy_with_logits because
|
||||
|
Reference in New Issue
Block a user