remove redundant lines.
This commit is contained in:
@@ -6,6 +6,7 @@ import gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from baselines.common.tests.envs.mnist_env import MnistEnv
|
from baselines.common.tests.envs.mnist_env import MnistEnv
|
||||||
from baselines.common.tf_util import make_session, get_session
|
from baselines.common.tf_util import make_session, get_session
|
||||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
@@ -135,25 +136,3 @@ def _get_action_stats(model, ob):
|
|||||||
std = np.std(actions, axis=0)
|
std = np.std(actions, axis=0)
|
||||||
|
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
learn_kwargs = {
|
|
||||||
'deepq': {},
|
|
||||||
'a2c': {},
|
|
||||||
'acktr': {},
|
|
||||||
'acer': {},
|
|
||||||
'ppo2': {'nminibatches': 1, 'nsteps': 10},
|
|
||||||
'trpo_mpi': {},
|
|
||||||
}
|
|
||||||
|
|
||||||
network_kwargs = {
|
|
||||||
'mlp': {},
|
|
||||||
'cnn': {'pad': 'SAME'},
|
|
||||||
'lstm': {},
|
|
||||||
'cnn_lnlstm': {'pad': 'SAME'}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
|
|
||||||
# @pytest.mark.parametrize("network_fn", network_kwargs.keys())
|
|
||||||
test_serialization('ppo2', 'cnn')
|
|
||||||
|
Reference in New Issue
Block a user