define mean for CategoricalPd (as softmax of logits)

This commit is contained in:
Peter Zhokhov
2018-09-13 15:37:04 -07:00
parent fe06c6b4db
commit e790f5214b

View File

@@ -146,6 +146,10 @@ class CategoricalPd(Pd):
return self.logits
def mode(self):
return tf.argmax(self.logits, axis=-1)
@property
def mean(self):
return tf.nn.softmax(self.logits)
def neglogp(self, x):
# return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
# Note: we can't use sparse_softmax_cross_entropy_with_logits because