change random seeding to work with new gym version (#231)
* change random seeding to work with new gym version
* move seeding to seed() method
* fix mnistenv
* actually try some of the tests before pushing
* more deterministic fixed seq
This commit is contained in:
committed by
Peter Zhokhov
parent
82ebd4a153
commit
0dcaafd717
@@ -7,21 +7,20 @@ class FixedSequenceEnv(Env):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
n_actions=10,
|
n_actions=10,
|
||||||
seed=0,
|
|
||||||
episode_len=100
|
episode_len=100
|
||||||
):
|
):
|
||||||
self.np_random = np.random.RandomState()
|
self.np_random = np.random.RandomState()
|
||||||
self.np_random.seed(seed)
|
self.sequence = None
|
||||||
self.sequence = [self.np_random.randint(0, n_actions-1) for _ in range(episode_len)]
|
|
||||||
|
|
||||||
self.action_space = Discrete(n_actions)
|
self.action_space = Discrete(n_actions)
|
||||||
self.observation_space = Discrete(1)
|
self.observation_space = Discrete(1)
|
||||||
|
|
||||||
self.episode_len = episode_len
|
self.episode_len = episode_len
|
||||||
self.time = 0
|
self.time = 0
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
if self.sequence is None:
|
||||||
|
self.sequence = [self.np_random.randint(0, self.action_space.n-1) for _ in range(self.episode_len)]
|
||||||
self.time = 0
|
self.time = 0
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@@ -35,6 +34,9 @@ class FixedSequenceEnv(Env):
|
|||||||
|
|
||||||
return 0, rew, done, {}
|
return 0, rew, done, {}
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
self.np_random.seed(seed)
|
||||||
|
|
||||||
def _choose_next_state(self):
|
def _choose_next_state(self):
|
||||||
self.time += 1
|
self.time += 1
|
||||||
|
|
||||||
|
@@ -10,6 +10,7 @@ class IdentityEnv(Env):
|
|||||||
episode_len=None
|
episode_len=None
|
||||||
):
|
):
|
||||||
|
|
||||||
|
self.observation_space = self.action_space
|
||||||
self.episode_len = episode_len
|
self.episode_len = episode_len
|
||||||
self.time = 0
|
self.time = 0
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -17,7 +18,6 @@ class IdentityEnv(Env):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
self._choose_next_state()
|
self._choose_next_state()
|
||||||
self.time = 0
|
self.time = 0
|
||||||
self.observation_space = self.action_space
|
|
||||||
|
|
||||||
return self.state
|
return self.state
|
||||||
|
|
||||||
@@ -30,6 +30,9 @@ class IdentityEnv(Env):
|
|||||||
|
|
||||||
return self.state, rew, done, {}
|
return self.state, rew, done, {}
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
self.action_space.seed(seed)
|
||||||
|
|
||||||
def _choose_next_state(self):
|
def _choose_next_state(self):
|
||||||
self.state = self.action_space.sample()
|
self.state = self.action_space.sample()
|
||||||
self.time += 1
|
self.time += 1
|
||||||
|
@@ -9,7 +9,6 @@ from gym.spaces import Discrete, Box
|
|||||||
class MnistEnv(Env):
|
class MnistEnv(Env):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
seed=0,
|
|
||||||
episode_len=None,
|
episode_len=None,
|
||||||
no_images=None
|
no_images=None
|
||||||
):
|
):
|
||||||
@@ -23,7 +22,6 @@ class MnistEnv(Env):
|
|||||||
self.mnist = input_data.read_data_sets(mnist_path)
|
self.mnist = input_data.read_data_sets(mnist_path)
|
||||||
|
|
||||||
self.np_random = np.random.RandomState()
|
self.np_random = np.random.RandomState()
|
||||||
self.np_random.seed(seed)
|
|
||||||
|
|
||||||
self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
|
self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
|
||||||
self.action_space = Discrete(10)
|
self.action_space = Discrete(10)
|
||||||
@@ -50,6 +48,9 @@ class MnistEnv(Env):
|
|||||||
|
|
||||||
return self.state[0], rew, done, {}
|
return self.state[0], rew, done, {}
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
self.np_random.seed(seed)
|
||||||
|
|
||||||
def train_mode(self):
|
def train_mode(self):
|
||||||
self.dataset = self.mnist.train
|
self.dataset = self.mnist.train
|
||||||
|
|
||||||
|
@@ -33,8 +33,7 @@ def test_fixed_sequence(alg, rnn):
|
|||||||
kwargs = learn_kwargs[alg]
|
kwargs = learn_kwargs[alg]
|
||||||
kwargs.update(common_kwargs)
|
kwargs.update(common_kwargs)
|
||||||
|
|
||||||
episode_len = 5
|
env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
|
||||||
env_fn = lambda: FixedSequenceEnv(10, episode_len=episode_len)
|
|
||||||
learn = lambda e: get_learn_function(alg)(
|
learn = lambda e: get_learn_function(alg)(
|
||||||
env=e,
|
env=e,
|
||||||
network=rnn,
|
network=rnn,
|
||||||
|
@@ -41,7 +41,7 @@ def test_mnist(alg):
|
|||||||
|
|
||||||
learn = get_learn_function(alg)
|
learn = get_learn_function(alg)
|
||||||
learn_fn = lambda e: learn(env=e, **learn_kwargs)
|
learn_fn = lambda e: learn(env=e, **learn_kwargs)
|
||||||
env_fn = lambda: MnistEnv(seed=0, episode_len=100)
|
env_fn = lambda: MnistEnv(episode_len=100)
|
||||||
|
|
||||||
simple_test(env_fn, learn_fn, 0.6)
|
simple_test(env_fn, learn_fn, 0.6)
|
||||||
|
|
||||||
|
@@ -44,7 +44,12 @@ def test_serialization(learn_fn, network_fn):
|
|||||||
# github issue: https://github.com/openai/baselines/issues/660
|
# github issue: https://github.com/openai/baselines/issues/660
|
||||||
return
|
return
|
||||||
|
|
||||||
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
def make_env():
|
||||||
|
env = MnistEnv(episode_len=100)
|
||||||
|
env.seed(10)
|
||||||
|
return env
|
||||||
|
|
||||||
|
env = DummyVecEnv([make_env])
|
||||||
ob = env.reset().copy()
|
ob = env.reset().copy()
|
||||||
learn = get_learn_function(learn_fn)
|
learn = get_learn_function(learn_fn)
|
||||||
|
|
||||||
|
@@ -1,17 +1,19 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.spaces import np_random
|
|
||||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
|
|
||||||
N_TRIALS = 10000
|
N_TRIALS = 10000
|
||||||
N_EPISODES = 100
|
N_EPISODES = 100
|
||||||
|
|
||||||
def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||||
|
def seeded_env_fn():
|
||||||
|
env = env_fn()
|
||||||
|
env.seed(0)
|
||||||
|
return env
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
np_random.seed(0)
|
|
||||||
|
|
||||||
env = DummyVecEnv([env_fn])
|
|
||||||
|
|
||||||
|
env = DummyVecEnv([seeded_env_fn])
|
||||||
|
|
||||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||||
tf.set_random_seed(0)
|
tf.set_random_seed(0)
|
||||||
|
Reference in New Issue
Block a user