From 417c52bf5f7dba607dc761c23477d15581c429d4 Mon Sep 17 00:00:00 2001 From: gyunt Date: Mon, 8 Apr 2019 22:53:40 +0900 Subject: [PATCH] remove redundant lines. --- baselines/common/tests/test_serialization.py | 23 +------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/baselines/common/tests/test_serialization.py b/baselines/common/tests/test_serialization.py index 0a8d280..a770f93 100644 --- a/baselines/common/tests/test_serialization.py +++ b/baselines/common/tests/test_serialization.py @@ -6,6 +6,7 @@ import gym import numpy as np import pytest import tensorflow as tf + from baselines.common.tests.envs.mnist_env import MnistEnv from baselines.common.tf_util import make_session, get_session from baselines.common.vec_env.dummy_vec_env import DummyVecEnv @@ -135,25 +136,3 @@ def _get_action_stats(model, ob): std = np.std(actions, axis=0) return mean, std - -if __name__ == '__main__': - learn_kwargs = { - 'deepq': {}, - 'a2c': {}, - 'acktr': {}, - 'acer': {}, - 'ppo2': {'nminibatches': 1, 'nsteps': 10}, - 'trpo_mpi': {}, - } - - network_kwargs = { - 'mlp': {}, - 'cnn': {'pad': 'SAME'}, - 'lstm': {}, - 'cnn_lnlstm': {'pad': 'SAME'} - } - - - # @pytest.mark.parametrize("learn_fn", learn_kwargs.keys()) - # @pytest.mark.parametrize("network_fn", network_kwargs.keys()) - test_serialization('ppo2', 'cnn')