microbatch fixes and test (#169)
* microbatch fixes and test
* tiny cleanup
* added assertions to the test
* vpg-related fix
@@ -103,9 +103,9 @@ def test_coexistence(learn_fn, network_fn):
     kwargs.update(learn_kwargs[learn_fn])
 
     learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
-    make_session(make_default=True, graph=tf.Graph());
+    make_session(make_default=True, graph=tf.Graph())
     model1 = learn(seed=1)
-    make_session(make_default=True, graph=tf.Graph());
+    make_session(make_default=True, graph=tf.Graph())
     model2 = learn(seed=2)
 
     model1.step(env.observation_space.sample())

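For context, test_coexistence builds each model in its own tf.Graph with its own default session. The sketch below is not part of the commit; it only illustrates, with standard TF1 calls, why identically named variables in separate graphs do not collide, which is what lets model1 and model2 coexist:

# Sketch only (standard TF1 behaviour, not code from this commit).
import tensorflow as tf

g1, g2 = tf.Graph(), tf.Graph()
with g1.as_default():
    a = tf.get_variable('pi/w', initializer=1.0)
    init1 = tf.global_variables_initializer()
with g2.as_default():
    b = tf.get_variable('pi/w', initializer=2.0)  # same name, different graph
    init2 = tf.global_variables_initializer()

with tf.Session(graph=g1) as s1:
    s1.run(init1)
    print(s1.run(a))  # 1.0
with tf.Session(graph=g2) as s2:
    s2.run(init2)
    print(s2.run(b))  # 2.0
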
@@ -131,27 +131,39 @@ class Model(object):
 
         # Normalize the advantages
         advs = (advs - advs.mean()) / (advs.std() + 1e-8)
-        td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
-                CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
-        if states is not None:
-            td_map[train_model.S] = states
-            td_map[train_model.M] = masks
 
         if microbatch_size == None or microbatch_size == obs.shape[0]:
+            td_map = {train_model.X:obs, A:actions, ADV:advs, R:returns, LR:lr,
+                    CLIPRANGE:cliprange, OLDNEGLOGPAC:neglogpacs, OLDVPRED:values}
+            if states is not None:
+                td_map[train_model.S] = states
+                td_map[train_model.M] = masks
+
             return sess.run(
                 [pg_loss, vf_loss, entropy, approxkl, clipfrac, _train],
                 td_map
             )[:-1]
         else:
-            sum_grad_v = []
+            assert states is None, "microbatches with recurrent models are not supported yet"
             pg_losses = []
             vf_losses = []
             entropies = []
             approx_kls = []
             clipfracs = []
-            for _ in range(nmicrobatches):
+            for microbatch_idx in range(nmicrobatches):
+                _sli = range(microbatch_idx * microbatch_size, (microbatch_idx+1) * microbatch_size)
+                td_map = {
+                    train_model.X: obs[_sli],
+                    A:actions[_sli],
+                    ADV:advs[_sli],
+                    R:returns[_sli],
+                    CLIPRANGE:cliprange,
+                    OLDNEGLOGPAC:neglogpacs[_sli],
+                    OLDVPRED:values[_sli]
+                }
+
                 grad_v, pg_loss_v, vf_loss_v, entropy_v, approx_kl_v, clipfrac_v = sess.run([grads, pg_loss, vf_loss, entropy, approxkl, clipfrac], td_map)
-                if len(sum_grad_v) == 0:
+                if microbatch_idx == 0:
                     sum_grad_v = grad_v
                 else:
                     for i, g in enumerate(grad_v):

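The else-branch above runs the gradient tensors on each microbatch slice and accumulates them in sum_grad_v; the averaging and apply step falls outside this excerpt. Below is a self-contained sketch of that accumulate-then-apply pattern on a hypothetical toy regression model (plain TF1 API, not the baselines PPO graph):

# Hypothetical toy model: sum per-microbatch gradients, average them, and feed
# the result through placeholders into a single apply_gradients step. For a
# mean loss over equal-size microbatches this reproduces one full-batch step.
import numpy as np
import tensorflow as tf

X = tf.placeholder(tf.float32, [None, 3])
Y = tf.placeholder(tf.float32, [None, 1])
w = tf.get_variable('w', [3, 1], initializer=tf.zeros_initializer())
loss = tf.reduce_mean(tf.square(tf.matmul(X, w) - Y))

optimizer = tf.train.GradientDescentOptimizer(0.1)
grads = tf.gradients(loss, [w])
grad_ph = [tf.placeholder(tf.float32, g.shape) for g in grads]
apply_op = optimizer.apply_gradients(list(zip(grad_ph, [w])))

obs = np.random.RandomState(0).randn(32, 3).astype(np.float32)
tgt = np.random.RandomState(1).randn(32, 1).astype(np.float32)
microbatch_size, nmicrobatches = 4, 8

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sum_grad_v = None
    for idx in range(nmicrobatches):
        sl = slice(idx * microbatch_size, (idx + 1) * microbatch_size)
        grad_v = sess.run(grads, {X: obs[sl], Y: tgt[sl]})
        # first microbatch initializes the running sum, later ones add to it
        sum_grad_v = grad_v if sum_grad_v is None else [s + g for s, g in zip(sum_grad_v, grad_v)]
    mean_grad_v = [s / nmicrobatches for s in sum_grad_v]
    sess.run(apply_op, feed_dict=dict(zip(grad_ph, mean_grad_v)))
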
baselines/ppo2/test_microbatches.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import gym
+import tensorflow as tf
+import numpy as np
+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.common.tf_util import make_session
+from baselines.ppo2.ppo2 import learn
+
+def test_microbatches():
+    def env_fn():
+        env = gym.make('CartPole-v0')
+        env.seed(0)
+        return env
+
+    env_ref = DummyVecEnv([env_fn])
+    sess_ref = make_session(make_default=True, graph=tf.Graph())
+    learn(env=env_ref, network='mlp', nsteps=32, total_timesteps=32, seed=0)
+    vars_ref = {v.name: sess_ref.run(v) for v in tf.trainable_variables()}
+
+    env_test = DummyVecEnv([env_fn])
+    sess_test = make_session(make_default=True, graph=tf.Graph())
+    learn(env=env_test, network='mlp', nsteps=32, total_timesteps=32, seed=0, microbatch_size=4)
+    vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()}
+
+    for v in vars_ref:
+        np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3)
+
+if __name__ == '__main__':
+    test_microbatches()
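To run the new check locally (assuming baselines, gym, and a TF1-era TensorFlow are installed), either command below works; the second relies on the __main__ guard above:

python -m pytest baselines/ppo2/test_microbatches.py
python baselines/ppo2/test_microbatches.py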