From c424f9889dfe51676ff634418330bc5a6e74b183 Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Thu, 1 Nov 2018 14:54:49 -0700 Subject: [PATCH] microbatch fixes and test (#169) * microbatch fixes and test * tiny cleanup * added assertions to the test * vpg-related fix --- baselines/common/tests/test_serialization.py | 4 +-- baselines/ppo2/ppo2.py | 28 ++++++++++++++------ baselines/ppo2/test_microbatches.py | 28 ++++++++++++++++++++ 3 files changed, 50 insertions(+), 10 deletions(-) create mode 100644 baselines/ppo2/test_microbatches.py diff --git a/baselines/common/tests/test_serialization.py b/baselines/common/tests/test_serialization.py index fac4929..73d29e9 100644 --- a/baselines/common/tests/test_serialization.py +++ b/baselines/common/tests/test_serialization.py @@ -103,9 +103,9 @@ def test_coexistence(learn_fn, network_fn): kwargs.update(learn_kwargs[learn_fn]) learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs) - make_session(make_default=True, graph=tf.Graph()); + make_session(make_default=True, graph=tf.Graph()) model1 = learn(seed=1) - make_session(make_default=True, graph=tf.Graph()); + make_session(make_default=True, graph=tf.Graph()) model2 = learn(seed=2) model1.step(env.observation_space.sample()) diff --git a/baselines/ppo2/ppo2.py b/baselines/ppo2/ppo2.py index a0adadc..eaf5a72 100644 --- a/baselines/ppo2/ppo2.py +++ b/baselines/ppo2/ppo2.py @@ -131,27 +131,39 @@ class Model(object): # Normalize the advantages advs = (advs - advs.mean()) / (advs.std() + 1e-8) - td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr, - CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values} - if states is not None: - td_map[train_model.S] = states - td_map[train_model.M] = masks if microbatch_size == None or microbatch_size == obs.shape[0]: + td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr, + CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values} + if states is not None: + td_map[train_model.S] = states + td_map[train_model.M] = masks + return sess.run( [pg_loss, vf_loss, entropy, approxkl, clipfrac, _train], td_map )[:-1] else: - sum_grad_v = [] + assert states is None, "microbatches with recurrent models are not supported yet" pg_losses = [] vf_losses = [] entropies = [] approx_kls = [] clipfracs = [] - for _ in range(nmicrobatches): + for microbatch_idx in range(nmicrobatches): + _sli = range(microbatch_idx * microbatch_size, (microbatch_idx+1) * microbatch_size) + td_map = { + train_model.X: obs[_sli], + A:actions[_sli], + ADV:advs[_sli], + R:returns[_sli], + CLIPRANGE:cliprange, + OLDNEGLOGPAC:neglogpacs[_sli], + OLDVPRED:values[_sli] + } + grad_v, pg_loss_v, vf_loss_v, entropy_v, approx_kl_v, clipfrac_v = sess.run([grads, pg_loss, vf_loss, entropy, approxkl, clipfrac], td_map) - if len(sum_grad_v) == 0: + if microbatch_idx == 0: sum_grad_v = grad_v else: for i, g in enumerate(grad_v): diff --git a/baselines/ppo2/test_microbatches.py b/baselines/ppo2/test_microbatches.py new file mode 100644 index 0000000..86a9cec --- /dev/null +++ b/baselines/ppo2/test_microbatches.py @@ -0,0 +1,28 @@ +import gym +import tensorflow as tf +import numpy as np +from baselines.common.vec_env.dummy_vec_env import DummyVecEnv +from baselines.common.tf_util import make_session +from baselines.ppo2.ppo2 import learn + +def test_microbatches(): + def env_fn(): + env = gym.make('CartPole-v0') + env.seed(0) + return env + + env_ref = DummyVecEnv([env_fn]) + sess_ref = make_session(make_default=True, graph=tf.Graph()) + learn(env=env_ref, network='mlp', nsteps=32, total_timesteps=32, seed=0) + vars_ref = {v.name: sess_ref.run(v) for v in tf.trainable_variables()} + + env_test = DummyVecEnv([env_fn]) + sess_test = make_session(make_default=True, graph=tf.Graph()) + learn(env=env_test, network='mlp', nsteps=32, total_timesteps=32, seed=0, microbatch_size=4) + vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()} + + for v in vars_ref: + np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3) + +if __name__ == '__main__': + test_microbatches()