remove redundant lines.
This commit is contained in:
@@ -6,6 +6,7 @@ import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tensorflow as tf
|
||||
|
||||
from baselines.common.tests.envs.mnist_env import MnistEnv
|
||||
from baselines.common.tf_util import make_session, get_session
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
@@ -135,25 +136,3 @@ def _get_action_stats(model, ob):
|
||||
std = np.std(actions, axis=0)
|
||||
|
||||
return mean, std
|
||||
|
||||
if __name__ == '__main__':
|
||||
learn_kwargs = {
|
||||
'deepq': {},
|
||||
'a2c': {},
|
||||
'acktr': {},
|
||||
'acer': {},
|
||||
'ppo2': {'nminibatches': 1, 'nsteps': 10},
|
||||
'trpo_mpi': {},
|
||||
}
|
||||
|
||||
network_kwargs = {
|
||||
'mlp': {},
|
||||
'cnn': {'pad': 'SAME'},
|
||||
'lstm': {},
|
||||
'cnn_lnlstm': {'pad': 'SAME'}
|
||||
}
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
|
||||
# @pytest.mark.parametrize("network_fn", network_kwargs.keys())
|
||||
test_serialization('ppo2', 'cnn')
|
||||
|
Reference in New Issue
Block a user