Use standardized advantages in TRPO.

Author: Jan Humplik
Date:   2017-07-23 22:42:55 +02:00
parent df82a15fd3
commit 4862140cea


@@ -207,7 +207,7 @@ def learn(env, policy_func, *,
 if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
 if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy

-args = seg["ob"], seg["ac"], seg["adv"]
+args = seg["ob"], seg["ac"], atarg
 fvpargs = [arr[::5] for arr in args]
 def fisher_vector_product(p):
     return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p
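For context: `atarg` is the standardized advantage estimate. In the baselines `trpo_mpi.learn()`, the raw GAE advantages in `seg["adv"]` are normalized to zero mean and unit standard deviation a few lines above this hunk (not shown here), so passing `atarg` instead of `seg["adv"]` feeds standardized advantages into the surrogate loss and the Fisher-vector products. A minimal sketch of that normalization, with dummy data standing in for a real rollout batch:

```python
import numpy as np

# Sketch of the advantage standardization that atarg refers to, assuming
# seg["adv"] holds the raw GAE advantage estimates for the current batch.
# The batch below is dummy data, purely illustrative.
seg = {"adv": np.random.randn(2048).astype(np.float32)}

atarg = seg["adv"]
atarg = (atarg - atarg.mean()) / atarg.std()  # zero mean, unit std

print(atarg.mean(), atarg.std())  # approximately 0.0 and 1.0
```

Standardizing per batch keeps the gradient and step scales roughly constant across iterations regardless of reward magnitude, which makes the fixed KL trust region behave more predictably.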
@@ -288,4 +288,4 @@ def learn(env, policy_func, *,
 logger.dump_tabular()

 def flatten_lists(listoflists):
-    return [el for list_ in listoflists for el in list_]
\ No newline at end of file
+    return [el for list_ in listoflists for el in list_]
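The removed and added lines in this last hunk are textually identical; given the matching line counts in the hunk header, the change is most plausibly a trailing-newline fix at the end of the file. For reference, `flatten_lists` flattens one level of nesting; in this file it appears to merge per-worker episode statistics gathered over MPI. A quick self-contained check:

```python
# flatten_lists as defined in the hunk above: flattens one level of
# nesting, e.g. merging per-worker episode lists into a single list.
def flatten_lists(listoflists):
    return [el for list_ in listoflists for el in list_]

# Hypothetical usage check (the sample data is illustrative only).
assert flatten_lists([[1, 2], [3], []]) == [1, 2, 3]
assert flatten_lists([]) == []
```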