neaten up stacking logic in mujoco_dset in gail
This commit is contained in:
@@ -50,7 +50,7 @@ class Mujoco_Dset(object):
|
|||||||
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
|
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
|
||||||
# and S is the environment observation/action space.
|
# and S is the environment observation/action space.
|
||||||
# Flatten to (N * L, prod(S))
|
# Flatten to (N * L, prod(S))
|
||||||
if len(obs.shape[2:]) != 0:
|
if len(obs.shape) > 2:
|
||||||
self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
|
self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
|
||||||
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
|
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user