change random seeding to work with new gym version (#231)
* change random seeding to work with new gym version * move seeding to seed() method * fix mnistenv * actually try some of the tests before pushing * more deterministic fixed seq
This commit is contained in:
committed by
Peter Zhokhov
parent
82ebd4a153
commit
0dcaafd717
@@ -7,21 +7,20 @@ class FixedSequenceEnv(Env):
|
||||
def __init__(
|
||||
self,
|
||||
n_actions=10,
|
||||
seed=0,
|
||||
episode_len=100
|
||||
):
|
||||
self.np_random = np.random.RandomState()
|
||||
self.np_random.seed(seed)
|
||||
self.sequence = [self.np_random.randint(0, n_actions-1) for _ in range(episode_len)]
|
||||
self.sequence = None
|
||||
|
||||
self.action_space = Discrete(n_actions)
|
||||
self.observation_space = Discrete(1)
|
||||
|
||||
self.episode_len = episode_len
|
||||
self.time = 0
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
if self.sequence is None:
|
||||
self.sequence = [self.np_random.randint(0, self.action_space.n-1) for _ in range(self.episode_len)]
|
||||
self.time = 0
|
||||
return 0
|
||||
|
||||
@@ -35,6 +34,9 @@ class FixedSequenceEnv(Env):
|
||||
|
||||
return 0, rew, done, {}
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random.seed(seed)
|
||||
|
||||
def _choose_next_state(self):
|
||||
self.time += 1
|
||||
|
||||
|
@@ -10,6 +10,7 @@ class IdentityEnv(Env):
|
||||
episode_len=None
|
||||
):
|
||||
|
||||
self.observation_space = self.action_space
|
||||
self.episode_len = episode_len
|
||||
self.time = 0
|
||||
self.reset()
|
||||
@@ -17,7 +18,6 @@ class IdentityEnv(Env):
|
||||
def reset(self):
|
||||
self._choose_next_state()
|
||||
self.time = 0
|
||||
self.observation_space = self.action_space
|
||||
|
||||
return self.state
|
||||
|
||||
@@ -30,6 +30,9 @@ class IdentityEnv(Env):
|
||||
|
||||
return self.state, rew, done, {}
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.action_space.seed(seed)
|
||||
|
||||
def _choose_next_state(self):
|
||||
self.state = self.action_space.sample()
|
||||
self.time += 1
|
||||
|
@@ -9,7 +9,6 @@ from gym.spaces import Discrete, Box
|
||||
class MnistEnv(Env):
|
||||
def __init__(
|
||||
self,
|
||||
seed=0,
|
||||
episode_len=None,
|
||||
no_images=None
|
||||
):
|
||||
@@ -23,7 +22,6 @@ class MnistEnv(Env):
|
||||
self.mnist = input_data.read_data_sets(mnist_path)
|
||||
|
||||
self.np_random = np.random.RandomState()
|
||||
self.np_random.seed(seed)
|
||||
|
||||
self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
|
||||
self.action_space = Discrete(10)
|
||||
@@ -50,6 +48,9 @@ class MnistEnv(Env):
|
||||
|
||||
return self.state[0], rew, done, {}
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random.seed(seed)
|
||||
|
||||
def train_mode(self):
|
||||
self.dataset = self.mnist.train
|
||||
|
||||
|
@@ -33,8 +33,7 @@ def test_fixed_sequence(alg, rnn):
|
||||
kwargs = learn_kwargs[alg]
|
||||
kwargs.update(common_kwargs)
|
||||
|
||||
episode_len = 5
|
||||
env_fn = lambda: FixedSequenceEnv(10, episode_len=episode_len)
|
||||
env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
|
||||
learn = lambda e: get_learn_function(alg)(
|
||||
env=e,
|
||||
network=rnn,
|
||||
|
@@ -41,7 +41,7 @@ def test_mnist(alg):
|
||||
|
||||
learn = get_learn_function(alg)
|
||||
learn_fn = lambda e: learn(env=e, **learn_kwargs)
|
||||
env_fn = lambda: MnistEnv(seed=0, episode_len=100)
|
||||
env_fn = lambda: MnistEnv(episode_len=100)
|
||||
|
||||
simple_test(env_fn, learn_fn, 0.6)
|
||||
|
||||
|
@@ -44,7 +44,12 @@ def test_serialization(learn_fn, network_fn):
|
||||
# github issue: https://github.com/openai/baselines/issues/660
|
||||
return
|
||||
|
||||
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
||||
def make_env():
|
||||
env = MnistEnv(episode_len=100)
|
||||
env.seed(10)
|
||||
return env
|
||||
|
||||
env = DummyVecEnv([make_env])
|
||||
ob = env.reset().copy()
|
||||
learn = get_learn_function(learn_fn)
|
||||
|
||||
|
@@ -1,17 +1,19 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from gym.spaces import np_random
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
|
||||
N_TRIALS = 10000
|
||||
N_EPISODES = 100
|
||||
|
||||
def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||
def seeded_env_fn():
|
||||
env = env_fn()
|
||||
env.seed(0)
|
||||
return env
|
||||
|
||||
np.random.seed(0)
|
||||
np_random.seed(0)
|
||||
|
||||
env = DummyVecEnv([env_fn])
|
||||
|
||||
env = DummyVecEnv([seeded_env_fn])
|
||||
|
||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||
tf.set_random_seed(0)
|
||||
|
Reference in New Issue
Block a user