ppo2 now has eval stats (#120)

* ppo2 now has eval stats

* fixed spaces

* fixed kwargs ordering

* whitespace fix
This commit is contained in:
Xingyou Song
2018-09-27 13:11:11 -07:00
committed by Peter Zhokhov
parent 858afa8d7e
commit e820b86fdc

View File

@@ -218,7 +218,7 @@ def constfn(val):
return val
return f
def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
save_interval=0, load_path=None, **network_kwargs):
@@ -307,8 +307,14 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
model.load(load_path)
# Instantiate the runner object
runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
if eval_env is not None:
eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam= lam)
epinfobuf = deque(maxlen=100)
if eval_env is not None:
eval_epinfobuf = deque(maxlen=100)
# Start total timer
tfirststart = time.time()
@@ -325,7 +331,12 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
cliprangenow = cliprange(frac)
# Get minibatch
obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
if eval_env is not None:
eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run() #pylint: disable=E0632
epinfobuf.extend(epinfos)
if eval_env is not None:
eval_epinfobuf.extend(eval_epinfos)
# Here what we're going to do is for each minibatch calculate the loss and append it.
mblossvals = []
@@ -375,6 +386,9 @@ def learn(*, network, env, total_timesteps, seed=None, nsteps=2048, ent_coef=0.0
logger.logkv("explained_variance", float(ev))
logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
if eval_env is not None:
logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]) )
logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]) )
logger.logkv('time_elapsed', tnow - tfirststart)
for (lossval, lossname) in zip(lossvals, model.loss_names):
logger.logkv(lossname, lossval)