Report episode rewards/length in A2C and ACKTR (#856)
This commit is contained in:
@@ -11,6 +11,8 @@ from baselines.common.policies import build_policy
|
||||
|
||||
from baselines.a2c.utils import Scheduler, find_trainable_variables
|
||||
from baselines.a2c.runner import Runner
|
||||
from baselines.ppo2.ppo2 import safemean
|
||||
from collections import deque
|
||||
|
||||
from tensorflow import losses
|
||||
|
||||
@@ -195,6 +197,7 @@ def learn(
|
||||
|
||||
# Instantiate the runner object
|
||||
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
|
||||
epinfobuf = deque(maxlen=100)
|
||||
|
||||
# Calculate the batch_size
|
||||
nbatch = nenvs*nsteps
|
||||
@@ -204,7 +207,8 @@ def learn(
|
||||
|
||||
for update in range(1, total_timesteps//nbatch+1):
|
||||
# Get mini batch of experiences
|
||||
obs, states, rewards, masks, actions, values = runner.run()
|
||||
obs, states, rewards, masks, actions, values, epinfos = runner.run()
|
||||
epinfobuf.extend(epinfos)
|
||||
|
||||
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
|
||||
nseconds = time.time()-tstart
|
||||
@@ -221,6 +225,8 @@ def learn(
|
||||
logger.record_tabular("policy_entropy", float(policy_entropy))
|
||||
logger.record_tabular("value_loss", float(value_loss))
|
||||
logger.record_tabular("explained_variance", float(ev))
|
||||
logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
|
||||
logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
|
||||
logger.dump_tabular()
|
||||
return model
|
||||
|
||||
|
@@ -22,6 +22,7 @@ class Runner(AbstractEnvRunner):
|
||||
# We initialize the lists that will contain the mb of experiences
|
||||
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
|
||||
mb_states = self.states
|
||||
epinfos = []
|
||||
for n in range(self.nsteps):
|
||||
# Given observations, take action and value (V(s))
|
||||
# We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
|
||||
@@ -34,7 +35,10 @@ class Runner(AbstractEnvRunner):
|
||||
mb_dones.append(self.dones)
|
||||
|
||||
# Take actions in env and look the results
|
||||
obs, rewards, dones, _ = self.env.step(actions)
|
||||
obs, rewards, dones, infos = self.env.step(actions)
|
||||
for info in infos:
|
||||
maybeepinfo = info.get('episode')
|
||||
if maybeepinfo: epinfos.append(maybeepinfo)
|
||||
self.states = states
|
||||
self.dones = dones
|
||||
self.obs = obs
|
||||
@@ -69,4 +73,4 @@ class Runner(AbstractEnvRunner):
|
||||
mb_rewards = mb_rewards.flatten()
|
||||
mb_values = mb_values.flatten()
|
||||
mb_masks = mb_masks.flatten()
|
||||
return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values
|
||||
return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos
|
||||
|
@@ -11,6 +11,8 @@ from baselines.common.tf_util import get_session, save_variables, load_variables
|
||||
from baselines.a2c.runner import Runner
|
||||
from baselines.a2c.utils import Scheduler, find_trainable_variables
|
||||
from baselines.acktr import kfac
|
||||
from baselines.ppo2.ppo2 import safemean
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Model(object):
|
||||
@@ -118,6 +120,7 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
|
||||
model.load(load_path)
|
||||
|
||||
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
|
||||
epinfobuf = deque(maxlen=100)
|
||||
nbatch = nenvs*nsteps
|
||||
tstart = time.time()
|
||||
coord = tf.train.Coordinator()
|
||||
@@ -127,7 +130,8 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
|
||||
enqueue_threads = []
|
||||
|
||||
for update in range(1, total_timesteps//nbatch+1):
|
||||
obs, states, rewards, masks, actions, values = runner.run()
|
||||
obs, states, rewards, masks, actions, values, epinfos = runner.run()
|
||||
epinfobuf.extend(epinfos)
|
||||
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
|
||||
model.old_obs = obs
|
||||
nseconds = time.time()-tstart
|
||||
@@ -141,6 +145,8 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
|
||||
logger.record_tabular("policy_loss", float(policy_loss))
|
||||
logger.record_tabular("value_loss", float(value_loss))
|
||||
logger.record_tabular("explained_variance", float(ev))
|
||||
logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
|
||||
logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
|
||||
logger.dump_tabular()
|
||||
|
||||
if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
|
||||
|
Reference in New Issue
Block a user