fix shuffling bug in ppo1
This commit is contained in:
@@ -167,7 +167,7 @@ def learn(env, policy_fn, *,
|
||||
ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
|
||||
vpredbefore = seg["vpred"] # predicted value function before update
|
||||
atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
|
||||
d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
|
||||
d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), deterministic=pi.recurrent)
|
||||
optim_batchsize = optim_batchsize or ob.shape[0]
|
||||
|
||||
if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy
|
||||
|
Reference in New Issue
Block a user