From 16136ddca76b466ebf6a9e363d1d28c694eb50c5 Mon Sep 17 00:00:00 2001 From: Mingfei Date: Tue, 2 Apr 2019 06:44:31 +0800 Subject: [PATCH] fix bugs: obs_ph normalization in adversary.py (#823) * fix bugs: obs_ph normalization in adversary.py * fix bug in reshape obs and acs in Mujobo_Dset --- baselines/gail/adversary.py | 2 +- baselines/gail/dataset/mujoco_dset.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/baselines/gail/adversary.py b/baselines/gail/adversary.py index 18df69c..96b8a4c 100644 --- a/baselines/gail/adversary.py +++ b/baselines/gail/adversary.py @@ -66,7 +66,7 @@ class TransitionClassifier(object): with tf.variable_scope("obfilter"): self.obs_rms = RunningMeanStd(shape=self.observation_shape) - obs = (obs_ph - self.obs_rms.mean / self.obs_rms.std) + obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std _input = tf.concat([obs, acs_ph], axis=1) # concatenate the two input -> form a transition p_h1 = tf.contrib.layers.fully_connected(_input, self.hidden_size, activation_fn=tf.nn.tanh) p_h2 = tf.contrib.layers.fully_connected(p_h1, self.hidden_size, activation_fn=tf.nn.tanh) diff --git a/baselines/gail/dataset/mujoco_dset.py b/baselines/gail/dataset/mujoco_dset.py index 0693262..2ada872 100644 --- a/baselines/gail/dataset/mujoco_dset.py +++ b/baselines/gail/dataset/mujoco_dset.py @@ -50,8 +50,12 @@ class Mujoco_Dset(object): # obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length # and S is the environment observation/action space. # Flatten to (N * L, prod(S)) - self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])]) - self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])]) + if len(obs.shape[2:]) != 0: + self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])]) + self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])]) + else: + self.obs = np.vstack(obs) + self.acs = np.vstack(acs) self.rets = traj_data['ep_rets'][:traj_limitation] self.avg_ret = sum(self.rets)/len(self.rets)