Compare commits

..

4 Commits

Author SHA1 Message Date
Peter Zhokhov
2c818245d6 dummy commit to RUN BENCHMARKS 2018-07-25 18:09:30 -07:00
Peter Zhokhov
ae8e7fd16b dummy commit to RUN BENCHMARKS 2018-07-25 18:07:56 -07:00
Adam Gleave
f272969325 GAIL: bugfix in dataset loading (#447)
* Fix silly typo

* Replace ad-hoc function with NumPy code
2018-07-06 16:12:14 -07:00
pzhokhov
a6b1bc70f1 re-import internal; fix missing tile_images.py (#427)
* import rl-algs from 2e3a166 commit

* extra import of the baselines badge

* exported commit with identity test

* proper rng seeding in the test_identity

* import internal

* adding missing tile_images.py
2018-06-08 09:41:45 -07:00

View File

@@ -47,18 +47,12 @@ class Mujoco_Dset(object):
obs = traj_data['obs'][:traj_limitation]
acs = traj_data['acs'][:traj_limitation]
def flatten(x):
    """Stack per-episode arrays into one 2-D array of time steps.

    Args:
        x: A non-empty sequence of episodes (a list, an object array of
            arrays, or a regular ndarray). Each episode is indexed by time
            step along its first axis; every remaining axis is treated as
            part of the per-step feature space and is flattened.

    Returns:
        A float64 ndarray of shape (total_steps, feature_size), where
        total_steps is the sum of the episode lengths and feature_size is
        the product of the trailing dimensions of the first episode.

    Raises:
        IndexError: if ``x`` contains no episodes.
        ValueError: if an episode's per-step size differs from the first
            episode's.
    """
    # Per-step feature size comes from the first episode; 1-D episodes
    # (one scalar per step) yield a size of 1.
    size = int(np.prod(np.shape(x[0])[1:]))
    # Reshape with an explicit size (not -1) so zero-length episodes are
    # still accepted and mismatched episodes fail loudly.
    rows = [np.reshape(episode, (len(episode), size)) for episode in x]
    # Cast to float64 to match the dtype of the np.zeros buffer that the
    # hand-written copy loop used to fill.
    return np.concatenate(rows, axis=0).astype(np.float64)
self.obs = np.array(flatten(obs))
self.acs = np.array(flatten(acs))
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
# and S is the environment observation/action space.
# Flatten to (N * L, prod(S))
self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
self.rets = traj_data['ep_rets'][:traj_limitation]
self.avg_ret = sum(self.rets)/len(self.rets)
self.std_ret = np.std(np.array(self.rets))