Patch for baselines/trpo_mpi/trpo_mpi.py (two hunks).

Hunk 1 (inside learn): the tuple `args`, which is subsampled with [::5] into
  `fvpargs` and passed to compute_fvp by fisher_vector_product, now carries
  `atarg` as its third element instead of seg["adv"].
  NOTE(review): `atarg` is defined outside this hunk; presumably it is the
  post-processed advantage array derived from seg["adv"] -- confirm in the file.
Hunk 2 (flatten_lists): no code change; only adds the missing newline at the
  end of the file.

NOTE(review): this patch was restored from a copy whose line breaks had been
  collapsed. Leading indentation of the Python rows and the position of the
  blank context row in each hunk were rebuilt from the hunk line counts
  (7/7 and 4/4); run `git apply --check` against the target tree before use.

diff --git a/baselines/trpo_mpi/trpo_mpi.py b/baselines/trpo_mpi/trpo_mpi.py
index 92579ef..ef681c4 100644
--- a/baselines/trpo_mpi/trpo_mpi.py
+++ b/baselines/trpo_mpi/trpo_mpi.py
@@ -207,7 +207,7 @@ def learn(env, policy_func, *,
         if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
         if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy
 
-        args = seg["ob"], seg["ac"], seg["adv"]
+        args = seg["ob"], seg["ac"], atarg
         fvpargs = [arr[::5] for arr in args]
         def fisher_vector_product(p):
             return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p
@@ -288,4 +288,4 @@ def learn(env, policy_func, *,
             logger.dump_tabular()
 
 def flatten_lists(listoflists):
-    return [el for list_ in listoflists for el in list_]
\ No newline at end of file
+    return [el for list_ in listoflists for el in list_]