microbatch fixes and test (#169)
* microbatch fixes and test
* tiny cleanup
* added assertions to the test
* vpg-related fix
This commit is contained in:
@@ -103,9 +103,9 @@ def test_coexistence(learn_fn, network_fn):
|
||||
kwargs.update(learn_kwargs[learn_fn])
|
||||
|
||||
learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
|
||||
make_session(make_default=True, graph=tf.Graph());
|
||||
make_session(make_default=True, graph=tf.Graph())
|
||||
model1 = learn(seed=1)
|
||||
make_session(make_default=True, graph=tf.Graph());
|
||||
make_session(make_default=True, graph=tf.Graph())
|
||||
model2 = learn(seed=2)
|
||||
|
||||
model1.step(env.observation_space.sample())
|
||||
|
@@ -131,27 +131,39 @@ class Model(object):
|
||||
|
||||
# Normalize the advantages
|
||||
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
|
||||
td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
|
||||
CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
|
||||
if states is not None:
|
||||
td_map[train_model.S] = states
|
||||
td_map[train_model.M] = masks
|
||||
|
||||
if microbatch_size == None or microbatch_size == obs.shape[0]:
|
||||
td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
|
||||
CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
|
||||
if states is not None:
|
||||
td_map[train_model.S] = states
|
||||
td_map[train_model.M] = masks
|
||||
|
||||
return sess.run(
|
||||
[pg_loss, vf_loss, entropy, approxkl, clipfrac, _train],
|
||||
td_map
|
||||
)[:-1]
|
||||
else:
|
||||
sum_grad_v = []
|
||||
assert states is None, "microbatches with recurrent models are not supported yet"
|
||||
pg_losses = []
|
||||
vf_losses = []
|
||||
entropies = []
|
||||
approx_kls = []
|
||||
clipfracs = []
|
||||
for _ in range(nmicrobatches):
|
||||
for microbatch_idx in range(nmicrobatches):
|
||||
_sli = range(microbatch_idx * microbatch_size, (microbatch_idx+1) * microbatch_size)
|
||||
td_map = {
|
||||
train_model.X: obs[_sli],
|
||||
A:actions[_sli],
|
||||
ADV:advs[_sli],
|
||||
R:returns[_sli],
|
||||
CLIPRANGE:cliprange,
|
||||
OLDNEGLOGPAC:neglogpacs[_sli],
|
||||
OLDVPRED:values[_sli]
|
||||
}
|
||||
|
||||
grad_v, pg_loss_v, vf_loss_v, entropy_v, approx_kl_v, clipfrac_v = sess.run([grads, pg_loss, vf_loss, entropy, approxkl, clipfrac], td_map)
|
||||
if len(sum_grad_v) == 0:
|
||||
if microbatch_idx == 0:
|
||||
sum_grad_v = grad_v
|
||||
else:
|
||||
for i, g in enumerate(grad_v):
|
||||
|
baselines/ppo2/test_microbatches.py (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
import gym
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.common.tf_util import make_session
|
||||
from baselines.ppo2.ppo2 import learn
|
||||
|
||||
def test_microbatches():
    """Regression test: ppo2 trained with gradient accumulation over
    microbatches must produce the same parameters as full-batch training.

    Trains twice from identical seeds — once without and once with
    ``microbatch_size`` — and asserts every trainable variable matches
    to within a small absolute tolerance.
    """
    def env_fn():
        # Seeded CartPole so both training runs observe identical rollouts.
        env = gym.make('CartPole-v0')
        env.seed(0)
        return env

    def final_variables(**learn_kwargs):
        # Run one short ppo2 training session in a fresh default graph and
        # return a {name: value} snapshot of all trainable variables.
        venv = DummyVecEnv([env_fn])
        sess = make_session(make_default=True, graph=tf.Graph())
        learn(env=venv, network='mlp', nsteps=32, total_timesteps=32, seed=0, **learn_kwargs)
        return {v.name: sess.run(v) for v in tf.trainable_variables()}

    vars_ref = final_variables()
    vars_test = final_variables(microbatch_size=4)

    # atol=1e-3 absorbs float32 summation-order differences between the
    # single large batch and the accumulated microbatch gradients.
    for name in vars_ref:
        np.testing.assert_allclose(vars_ref[name], vars_test[name], atol=1e-3)
|
||||
|
||||
# Allow running this test file directly as a script.
if __name__ == '__main__':
    test_microbatches()
|
Reference in New Issue
Block a user