Merge pull request #120 from hamzamerzic/tensorflow_global_variable

Deprecated VARIABLES -> GLOBAL_VARIABLES.
This commit is contained in:
John Schulman
2017-08-28 21:27:23 -07:00
committed by GitHub
3 changed files with 8 additions and 8 deletions

View File

@@ -49,7 +49,7 @@ class CnnPolicy(object):
ac1, vpred1 = self._act(stochastic, ob[None])
return ac1[0], vpred1[0]
def get_variables(self):
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_initial_state(self):

View File

@@ -51,7 +51,7 @@ class MlpPolicy(object):
ac1, vpred1 = self._act(stochastic, ob[None])
return ac1[0], vpred1[0]
def get_variables(self):
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_initial_state(self):

View File

@@ -49,7 +49,7 @@ class CnnPolicy(object):
ac1, vpred1 = self._act(stochastic, ob[None])
return ac1[0], vpred1[0]
def get_variables(self):
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)
def get_trainable_variables(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
def get_initial_state(self):