diff --git a/baselines/ppo2/test_microbatches.py b/baselines/ppo2/test_microbatches.py index 291c2d2..829e0a9 100644 --- a/baselines/ppo2/test_microbatches.py +++ b/baselines/ppo2/test_microbatches.py @@ -25,10 +25,11 @@ def test_microbatches(): env_test = DummyVecEnv([env_fn]) sess_test = make_session(make_default=True, graph=tf.Graph()) learn_fn(env=env_test, model_fn=partial(MicrobatchedModel, microbatch_size=2)) + # learn_fn(env=env_test) vars_test = {v.name: sess_test.run(v) for v in tf.trainable_variables()} for v in vars_ref: - np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=1e-3) + np.testing.assert_allclose(vars_ref[v], vars_test[v], atol=3e-3) if __name__ == '__main__': test_microbatches()