match network output with action distribution via a linear layer only if necessary (#167)

Author: pzhokhov
Date: 2018-10-31 11:30:11 -07:00
Committed by: Peter Zhokhov
Parent: 5878eb3862
Commit: e619e42364


@@ -62,7 +62,7 @@ class CategoricalPdType(PdType):
     def pdclass(self):
         return CategoricalPd
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
+        pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam

     def param_shape(self):
@@ -82,7 +82,7 @@ class MultiCategoricalPdType(PdType):
         return MultiCategoricalPd(self.ncats, flat)
     def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
-        pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
+        pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam

     def param_shape(self):
@@ -99,7 +99,7 @@ class DiagGaussianPdType(PdType):
return DiagGaussianPd return DiagGaussianPd
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias) mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer()) logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1) pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
return self.pdfromflat(pdparam), mean return self.pdfromflat(pdparam), mean
@@ -123,7 +123,7 @@ class BernoulliPdType(PdType):
     def sample_dtype(self):
         return tf.int32
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
+        pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam

 # WRONG SECOND DERIVATIVES
@@ -345,3 +345,9 @@ def validate_probtype(probtype, pdparam):
     assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
     print('ok on', probtype, pdparam)

+
+def _matching_fc(tensor, name, size, init_scale, init_bias):
+    if tensor.shape[-1] == size:
+        return tensor
+    else:
+        return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)
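
For context (not part of the commit itself): the helper makes every pdfromlatent add the final 'pi' linear layer only when the latent tensor's last dimension differs from the number of distribution parameters; a latent that already has the right width is used directly as pdparam (or as the mean, for DiagGaussianPdType). Below is a minimal NumPy sketch of that decision rule, with a hypothetical dense() standing in for baselines' TF1 fc layer; names and shapes are illustrative assumptions, not code from the commit.

import numpy as np

def dense(x, size):
    # hypothetical stand-in for baselines.a2c.utils.fc: a plain linear projection
    w = np.random.randn(x.shape[-1], size) * 0.01
    return x @ w

def matching_fc(x, size):
    # same rule as _matching_fc above: skip the extra linear layer when the
    # latent width already equals the distribution's parameter size
    return x if x.shape[-1] == size else dense(x, size)

latent_ok = np.zeros((8, 4))    # width already matches, e.g. ncat == 4
latent_mlp = np.zeros((8, 64))  # typical MLP latent, needs a projection

assert matching_fc(latent_ok, 4) is latent_ok      # passed through unchanged
assert matching_fc(latent_mlp, 4).shape == (8, 4)  # 'pi' layer inserted

The upshot is that a policy network whose output layer is already sized to the action distribution no longer gets a redundant linear layer stacked on top of it.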