diff --git a/baselines/gail/dataset/mujoco_dset.py b/baselines/gail/dataset/mujoco_dset.py index 2ada872..f5e9c27 100644 --- a/baselines/gail/dataset/mujoco_dset.py +++ b/baselines/gail/dataset/mujoco_dset.py @@ -50,7 +50,7 @@ class Mujoco_Dset(object): # obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length # and S is the environment observation/action space. # Flatten to (N * L, prod(S)) - if len(obs.shape[2:]) != 0: + if len(obs.shape) > 2: self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])]) self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])]) else: