GAIL: bugfix in dataset loading (#447)

* Fix silly typo

* Replace ad-hoc function with NumPy code
This commit is contained in:
Adam Gleave
2018-07-06 16:12:14 -07:00
committed by pzhokhov
parent a6b1bc70f1
commit f272969325

View File

@@ -47,18 +47,12 @@ class Mujoco_Dset(object):
obs = traj_data['obs'][:traj_limitation]
acs = traj_data['acs'][:traj_limitation]
def flatten(x):
# x.shape = (E,), or (E, L, D)
_, size = x[0].shape
episode_length = [len(i) for i in x]
y = np.zeros((sum(episode_length), size))
start_idx = 0
for l, x_i in zip(episode_length, x):
y[start_idx:(start_idx+l)] = x_i
start_idx += l
return y
self.obs = np.array(flatten(obs))
self.acs = np.array(flatten(acs))
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
# and S is the environment observation/action space.
# Flatten to (N * L, prod(S))
self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
self.rets = traj_data['ep_rets'][:traj_limitation]
self.avg_ret = sum(self.rets)/len(self.rets)
self.std_ret = np.std(np.array(self.rets))