clean the scope of the ppo2 policy model.

Author: gyunt
Date: 2019-03-23 05:23:11 +09:00
Parent: a9d3b1c727
Commit: 06cef53de3


@@ -41,19 +41,19 @@ class PolicyWithValue(object):
         vf_latent = vf_latent if vf_latent is not None else latent
-        with tf.name_scope('action_space'):
+        with tf.variable_scope('policy'):
             latent = tf.layers.flatten(latent)
             # Based on the action space, will select what probability distribution type
             self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)
-        with tf.name_scope('sample_action'):
+        with tf.variable_scope('sample_action'):
             self.action = self.pd.sample()
-        with tf.name_scope('negative_log_probability'):
+        with tf.variable_scope('negative_log_probability'):
             # Calculate the neg log of our probability
             self.neglogp = self.pd.neglogp(self.action)
-        with tf.name_scope('value_estimator'):
+        with tf.variable_scope('value'):
             vf_latent = tf.layers.flatten(vf_latent)
             if estimate_q:
@@ -124,12 +124,12 @@ def build_ppo_policy(env, policy_network, value_network=None, estimate_q=False,
         encoded_x = encode_observation(ob_space, X)
-        if is_rnn_network(policy_network):
-            policy_state, policy_network_ = policy_network(encoded_x, dones)
-        else:
-            policy_network_ = policy_network
+        with tf.variable_scope('load_rnn_memory'):
+            if is_rnn_network(policy_network):
+                policy_state, policy_network_ = policy_network(encoded_x, dones)
+            else:
+                policy_network_ = policy_network
         if value_network == 'shared':
             value_network_ = value_network
         else:
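
Note on the change: in TensorFlow 1.x, tf.name_scope only prefixes operation names, while tf.variable_scope also prefixes variables created via tf.get_variable. Wrapping the policy and value heads in variable_scope('policy') / variable_scope('value') therefore gives the model's variables clean, predictable names for saving, loading, and variable filtering. Below is a minimal illustrative sketch of that difference; it is generic TF 1.x code, not part of this commit, and the scope and variable names in it are made up.

    # Illustrative sketch only (not from this commit).
    # name_scope does NOT affect variables made with tf.get_variable;
    # variable_scope does -- which is why the diff swaps one for the other.
    import tensorflow as tf

    with tf.name_scope('name_only'):
        a = tf.get_variable('w_a', shape=[1])   # variable name: 'w_a:0' (scope ignored)
        op_a = tf.identity(a)                   # op name: 'name_only/Identity'

    with tf.variable_scope('policy'):
        b = tf.get_variable('w_b', shape=[1])   # variable name: 'policy/w_b:0' (scope applied)

    print(a.name)     # w_a:0
    print(b.name)     # policy/w_b:0
    print(op_a.name)  # name_only/Identity:0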