neaten up stacking logic in mujoco_dset in gail

This commit is contained in:
Peter Zhokhov
2019-04-01 15:47:13 -07:00
parent 16136ddca7
commit 096f4d9cf0

View File

@@ -50,7 +50,7 @@ class Mujoco_Dset(object):
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
# and S is the environment observation/action space.
# Flatten to (N * L, prod(S))
if len(obs.shape[2:]) != 0:
if len(obs.shape) > 2:
self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
else: