fix shuffling bug in ppo1

This commit is contained in:
Peter Zhokhov
2019-04-05 15:23:46 -07:00
parent fabbf2c611
commit 8a97e0df10

View File

@@ -167,7 +167,7 @@ def learn(env, policy_fn, *,
ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
vpredbefore = seg["vpred"] # predicted value function before udpate
atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), deterministic=pi.recurrent)
optim_batchsize = optim_batchsize or ob.shape[0]
if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy