ACKTR + A2C

commit 3f676f7d1e (parent 882251878f)
Author: John Schulman
Date:   2017-08-18 09:25:39 -07:00

31 changed files with 2920 additions and 144 deletions


@@ -108,7 +108,7 @@ class BernoulliPdType(PdType):
# def flatparam(self):
# return self.logits
# def mode(self):
-# return U.argmax(self.logits, axis=1)
+# return U.argmax(self.logits, axis=-1)
# def logp(self, x):
# return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
# def kl(self, other):
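
Most of the edits in this file swap hard-coded axis=1 reductions for axis=-1. For the usual 2-D [batch, nactions] logits the two are identical; axis=-1 presumably just makes the distribution ops reduce over the action dimension no matter how many leading batch dimensions the input carries. A quick numpy illustration (not part of the commit; the 3-D shape is only an example):

    import numpy as np

    logits2d = np.random.randn(4, 6)        # [batch, nactions]
    assert (logits2d.argmax(axis=1) == logits2d.argmax(axis=-1)).all()

    logits3d = np.random.randn(5, 4, 6)     # e.g. [nsteps, nenvs, nactions]
    print(logits3d.argmax(axis=1).shape)    # (5, 6)  -- reduces the wrong dimension
    print(logits3d.argmax(axis=-1).shape)   # (5, 4)  -- one action index per step/env
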
@@ -118,7 +118,7 @@ class BernoulliPdType(PdType):
# return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
# def sample(self):
# u = tf.random_uniform(tf.shape(self.logits))
-# return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)
+# return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
class CategoricalPd(Pd):
def __init__(self, logits):
@@ -126,27 +126,33 @@ class CategoricalPd(Pd):
def flatparam(self):
return self.logits
def mode(self):
-return U.argmax(self.logits, axis=1)
+return U.argmax(self.logits, axis=-1)
def neglogp(self, x):
-return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
+# return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
+# Note: we can't use sparse_softmax_cross_entropy_with_logits because
+# the implementation does not allow second-order derivatives...
+one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
+return tf.nn.softmax_cross_entropy_with_logits(
+    logits=self.logits,
+    labels=one_hot_actions)
def kl(self, other):
-a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
-a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)
+a0 = self.logits - U.max(self.logits, axis=-1, keepdims=True)
+a1 = other.logits - U.max(other.logits, axis=-1, keepdims=True)
ea0 = tf.exp(a0)
ea1 = tf.exp(a1)
-z0 = U.sum(ea0, axis=1, keepdims=True)
-z1 = U.sum(ea1, axis=1, keepdims=True)
+z0 = U.sum(ea0, axis=-1, keepdims=True)
+z1 = U.sum(ea1, axis=-1, keepdims=True)
p0 = ea0 / z0
-return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)
+return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
def entropy(self):
-a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
+a0 = self.logits - U.max(self.logits, axis=-1, keepdims=True)
ea0 = tf.exp(a0)
-z0 = U.sum(ea0, axis=1, keepdims=True)
+z0 = U.sum(ea0, axis=-1, keepdims=True)
p0 = ea0 / z0
-return U.sum(p0 * (tf.log(z0) - a0), axis=1)
+return U.sum(p0 * (tf.log(z0) - a0), axis=-1)
def sample(self):
u = tf.random_uniform(tf.shape(self.logits))
-return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=1)
+return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
@classmethod
def fromflat(cls, flat):
return cls(flat)
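
The substantive change in this hunk is CategoricalPd.neglogp: the sparse cross-entropy op is replaced by a dense softmax_cross_entropy_with_logits over one-hot actions because, as the new comment notes, the sparse implementation does not allow second-order derivatives. Both constructions compute the same quantity, -log softmax(logits)[action]; a minimal standalone numpy sketch of that equivalence (illustrative only, not baselines code):

    import numpy as np

    def neglogp(logits, actions):
        # numerically stable log-softmax
        a = logits - logits.max(axis=-1, keepdims=True)
        logp = a - np.log(np.exp(a).sum(axis=-1, keepdims=True))
        # one-hot cross-entropy, mirroring the new CategoricalPd.neglogp
        one_hot = np.eye(logits.shape[-1])[actions]
        return -(one_hot * logp).sum(axis=-1)

    logits = np.array([[1.0, 2.0, 0.5]])
    actions = np.array([1])
    probs = np.exp(logits[0]) / np.exp(logits[0]).sum()
    print(neglogp(logits, actions), -np.log(probs[1]))   # both ~= -log p(action 1)
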
@@ -177,7 +183,7 @@ class MultiCategoricalPd(Pd):
class DiagGaussianPd(Pd):
def __init__(self, flat):
self.flat = flat
-mean, logstd = tf.split(axis=len(flat.get_shape()) - 1, num_or_size_splits=2, value=flat)
+mean, logstd = tf.split(axis=len(flat.shape)-1, num_or_size_splits=2, value=flat)
self.mean = mean
self.logstd = logstd
self.std = tf.exp(logstd)
@@ -186,14 +192,14 @@ class DiagGaussianPd(Pd):
def mode(self):
return self.mean
def neglogp(self, x):
-return 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=len(x.get_shape()) - 1) \
+return 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=-1) \
+ 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
-+ U.sum(self.logstd, axis=len(x.get_shape()) - 1)
++ U.sum(self.logstd, axis=-1)
def kl(self, other):
assert isinstance(other, DiagGaussianPd)
return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
def entropy(self):
-return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), -1)
+return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
def sample(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
@classmethod
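
For reference, the neglogp above is the negative log-density of a diagonal Gaussian. A small standalone numpy check of that closed form against scipy (illustrative only, not part of this commit):

    import numpy as np
    from scipy.stats import multivariate_normal

    def diag_gaussian_neglogp(x, mean, logstd):
        # same formula as DiagGaussianPd.neglogp, reduced over the last axis
        std = np.exp(logstd)
        return (0.5 * np.sum(np.square((x - mean) / std), axis=-1)
                + 0.5 * np.log(2.0 * np.pi) * x.shape[-1]
                + np.sum(logstd, axis=-1))

    x = np.array([0.3, -1.2])
    mean = np.zeros(2)
    logstd = np.log(np.array([1.0, 0.5]))
    print(diag_gaussian_neglogp(x, mean, logstd))
    print(-multivariate_normal(mean, np.diag(np.exp(2 * logstd))).logpdf(x))  # same value
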
@@ -209,11 +215,11 @@ class BernoulliPd(Pd):
def mode(self):
return tf.round(self.ps)
def neglogp(self, x):
-return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=1)
+return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1)
def kl(self, other):
-return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)
+return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
def entropy(self):
-return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)
+return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
def sample(self):
u = tf.random_uniform(tf.shape(self.ps))
return tf.to_float(math_ops.less(u, self.ps))
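
BernoulliPd.entropy above relies on the identity that sigmoid cross-entropy of the logits against their own probabilities ps = sigmoid(logits) equals the Bernoulli entropy -p*log(p) - (1-p)*log(1-p). A quick numpy check (illustrative, using the numerically stable cross-entropy form):

    import numpy as np

    logits = np.array([-1.5, 0.0, 2.0])
    p = 1.0 / (1.0 + np.exp(-logits))                  # ps = sigmoid(logits)
    # sigmoid_cross_entropy_with_logits(logits, labels=p), stable form
    xent = np.maximum(logits, 0) - logits * p + np.log1p(np.exp(-np.abs(logits)))
    ent = -p * np.log(p) - (1 - p) * np.log(1 - p)     # Bernoulli entropy
    print(xent)
    print(ent)                                         # elementwise equal up to float error
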
@@ -286,4 +292,3 @@ def validate_probtype(probtype, pdparam):
klval_ll = - entval - logliks.mean() #pylint: disable=E1101
klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
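
The assertion above compares a sampled KL estimate against the analytic one via the identity KL(p || q) = -H(p) - E_{x~p}[log q(x)], accepting agreement within three standard errors of the log-likelihood mean. An illustrative numpy version of the same check, with assumed toy distributions rather than baselines probtypes:

    import numpy as np

    p = np.array([0.6, 0.3, 0.1])
    q = np.array([0.3, 0.4, 0.3])
    kl_exact = np.sum(p * (np.log(p) - np.log(q)))
    entropy = -np.sum(p * np.log(p))

    samples = np.random.choice(3, size=200000, p=p)    # draw from p
    logliks = np.log(q[samples])                       # log q(x) at the samples
    kl_sampled = -entropy - logliks.mean()             # -H(p) - E_p[log q]
    stderr = logliks.std() / np.sqrt(len(samples))
    print(kl_exact, kl_sampled, 3 * stderr)            # estimates agree within ~3 sigmas
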