# Patch summary (free text before the "diff --git" header; patch tools skip it).
#
# Target file: baselines/common/distributions.py
#
# What changes:
#   * The pdfromlatent methods of CategoricalPdType, MultiCategoricalPdType,
#     DiagGaussianPdType and BernoulliPdType call _matching_fc(...) instead
#     of calling fc(...) directly. The arguments passed are unchanged.
#   * A new module-level helper _matching_fc(tensor, name, size, init_scale,
#     init_bias) is appended to the file:
#       - if tensor.shape[-1] == size, the input tensor is returned as is;
#         fc is not called and name, init_scale and init_bias are unused;
#       - otherwise it returns
#         fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias).
#
# NOTE(review): leading indentation of the hunk lines was reconstructed on the
# assumption of 4-space Python indentation; verify against the target file
# before applying.
diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py
index 5b3e7be..554a2f1 100644
--- a/baselines/common/distributions.py
+++ b/baselines/common/distributions.py
@@ -62,7 +62,7 @@ class CategoricalPdType(PdType):
     def pdclass(self):
         return CategoricalPd
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
+        pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam
 
     def param_shape(self):
@@ -82,7 +82,7 @@ class MultiCategoricalPdType(PdType):
         return MultiCategoricalPd(self.ncats, flat)
 
     def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
-        pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
+        pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam
 
     def param_shape(self):
@@ -99,7 +99,7 @@ class DiagGaussianPdType(PdType):
         return DiagGaussianPd
 
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
+        mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
         logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
         pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
         return self.pdfromflat(pdparam), mean
@@ -123,7 +123,7 @@ class BernoulliPdType(PdType):
     def sample_dtype(self):
         return tf.int32
     def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
-        pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
+        pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
         return self.pdfromflat(pdparam), pdparam
 
 # WRONG SECOND DERIVATIVES
@@ -345,3 +345,9 @@ def validate_probtype(probtype, pdparam):
     assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
     print('ok on', probtype, pdparam)
 
+
+def _matching_fc(tensor, name, size, init_scale, init_bias):
+    if tensor.shape[-1] == size:
+        return tensor
+    else:
+        return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)