microbatch fixes and test (#169)

* microbatch fixes and test

* tiny cleanup

* added assertions to the test

* vpg-related fix
This commit is contained in:
pzhokhov
2018-11-01 14:54:49 -07:00
committed by Peter Zhokhov
parent a1cef656b8
commit c424f9889d
3 changed files with 50 additions and 10 deletions

View File

@@ -103,9 +103,9 @@ def test_coexistence(learn_fn, network_fn):
kwargs.update(learn_kwargs[learn_fn])
learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
make_session(make_default=True, graph=tf.Graph());
make_session(make_default=True, graph=tf.Graph())
model1 = learn(seed=1)
make_session(make_default=True, graph=tf.Graph());
make_session(make_default=True, graph=tf.Graph())
model2 = learn(seed=2)
model1.step(env.observation_space.sample())

View File

@@ -131,27 +131,39 @@ class Model(object):
# Normalize the advantages
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
if microbatch_size == None or microbatch_size == obs.shape[0]:
td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
if states is not None:
td_map[train_model.S] = states
td_map[train_model.M] = masks
if microbatch_size == None or microbatch_size == obs.shape[0]:
return sess.run(
[pg_loss, vf_loss, entropy, approxkl, clipfrac, _train],
td_map
)[:-1]
else:
sum_grad_v = []
assert states is None, "microbatches with recurrent models are not supported yet"
pg_losses = []
vf_losses = []
entropies = []
approx_kls = []
clipfracs = []
for _ in range(nmicrobatches):
for microbatch_idx in range(nmicrobatches):
_sli = range(microbatch_idx * microbatch_size, (microbatch_idx+1) * microbatch_size)
td_map = {
train_model.X: obs[_sli],
A:actions[_sli],
ADV:advs[_sli],
R:returns[_sli],
CLIPRANGE:cliprange,
OLDNEGLOGPAC:neglogpacs[_sli],
OLDVPRED:values[_sli]
}
grad_v, pg_loss_v, vf_loss_v, entropy_v, approx_kl_v, clipfrac_v = sess.run([grads, pg_loss, vf_loss, entropy, approxkl, clipfrac], td_map)
if len(sum_grad_v) == 0:
if microbatch_idx == 0:
sum_grad_v = grad_v
else:
for i, g in enumerate(grad_v):

View File

@@ -0,0 +1,28 @@
import gym
import tensorflow as tf
import numpy as np
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.tf_util import make_session
from baselines.ppo2.ppo2 import learn
def test_microbatches():
    """Verify that training with microbatches reproduces full-batch training.

    Runs ppo2.learn twice with identical seeds — once normally and once with
    microbatch_size=4 — and asserts that every trainable variable ends up with
    (numerically close to) the same value in both runs.
    """
    def make_env():
        e = gym.make('CartPole-v0')
        e.seed(0)
        return e

    def train_and_snapshot(extra_kwargs):
        # Fresh graph/session per run so the two trainings are fully isolated.
        env = DummyVecEnv([make_env])
        sess = make_session(make_default=True, graph=tf.Graph())
        learn(env=env, network='mlp', nsteps=32, total_timesteps=32, seed=0, **extra_kwargs)
        # Snapshot all trainable variables by name for later comparison.
        return {v.name: sess.run(v) for v in tf.trainable_variables()}

    vars_ref = train_and_snapshot({})
    vars_test = train_and_snapshot({'microbatch_size': 4})

    for name, ref_value in vars_ref.items():
        np.testing.assert_allclose(ref_value, vars_test[name], atol=1e-3)

if __name__ == '__main__':
    test_microbatches()