fix bugs: obs_ph normalization in adversary.py (#823)

* fix bugs: obs_ph normalization in adversary.py

* fix bug in reshape obs and acs in Mujobo_Dset
This commit is contained in:
Mingfei
2019-04-02 06:44:31 +08:00
committed by pzhokhov
parent b1644157d6
commit 16136ddca7
2 changed files with 7 additions and 3 deletions

View File

@@ -66,7 +66,7 @@ class TransitionClassifier(object):
with tf.variable_scope("obfilter"): with tf.variable_scope("obfilter"):
self.obs_rms = RunningMeanStd(shape=self.observation_shape) self.obs_rms = RunningMeanStd(shape=self.observation_shape)
obs = (obs_ph - self.obs_rms.mean / self.obs_rms.std) obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std
_input = tf.concat([obs, acs_ph], axis=1) # concatenate the two input -> form a transition _input = tf.concat([obs, acs_ph], axis=1) # concatenate the two input -> form a transition
p_h1 = tf.contrib.layers.fully_connected(_input, self.hidden_size, activation_fn=tf.nn.tanh) p_h1 = tf.contrib.layers.fully_connected(_input, self.hidden_size, activation_fn=tf.nn.tanh)
p_h2 = tf.contrib.layers.fully_connected(p_h1, self.hidden_size, activation_fn=tf.nn.tanh) p_h2 = tf.contrib.layers.fully_connected(p_h1, self.hidden_size, activation_fn=tf.nn.tanh)

View File

@@ -50,8 +50,12 @@ class Mujoco_Dset(object):
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length # obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
# and S is the environment observation/action space. # and S is the environment observation/action space.
# Flatten to (N * L, prod(S)) # Flatten to (N * L, prod(S))
self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])]) if len(obs.shape[2:]) != 0:
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])]) self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
else:
self.obs = np.vstack(obs)
self.acs = np.vstack(acs)
self.rets = traj_data['ep_rets'][:traj_limitation] self.rets = traj_data['ep_rets'][:traj_limitation]
self.avg_ret = sum(self.rets)/len(self.rets) self.avg_ret = sum(self.rets)/len(self.rets)