change random seeding to work with new gym version (#231)

* change random seeding to work with new gym version

* move seeding to seed() method

* fix mnistenv

* actually try some of the tests before pushing

* more deterministic fixed seq
This commit is contained in:
Christopher Hesse
2019-02-04 21:10:11 -08:00
committed by Peter Zhokhov
parent 82ebd4a153
commit 0dcaafd717
7 changed files with 27 additions and 15 deletions

View File

@@ -7,21 +7,20 @@ class FixedSequenceEnv(Env):
     def __init__(
             self,
             n_actions=10,
-            seed=0,
             episode_len=100
     ):
         self.np_random = np.random.RandomState()
-        self.np_random.seed(seed)
-        self.sequence = [self.np_random.randint(0, n_actions-1) for _ in range(episode_len)]
+        self.sequence = None
         self.action_space = Discrete(n_actions)
         self.observation_space = Discrete(1)
         self.episode_len = episode_len
         self.time = 0
-        self.reset()

     def reset(self):
+        if self.sequence is None:
+            self.sequence = [self.np_random.randint(0, self.action_space.n-1) for _ in range(self.episode_len)]
         self.time = 0
         return 0
@@ -35,6 +34,9 @@ class FixedSequenceEnv(Env):
         return 0, rew, done, {}

+    def seed(self, seed=None):
+        self.np_random.seed(seed)
+
     def _choose_next_state(self):
         self.time += 1

View File

@@ -10,6 +10,7 @@ class IdentityEnv(Env):
             episode_len=None
     ):
+        self.observation_space = self.action_space
         self.episode_len = episode_len
         self.time = 0
         self.reset()
@@ -17,7 +18,6 @@ class IdentityEnv(Env):
     def reset(self):
         self._choose_next_state()
         self.time = 0
-        self.observation_space = self.action_space
         return self.state
@@ -30,6 +30,9 @@ class IdentityEnv(Env):
         return self.state, rew, done, {}

+    def seed(self, seed=None):
+        self.action_space.seed(seed)
+
     def _choose_next_state(self):
         self.state = self.action_space.sample()
         self.time += 1

View File

@@ -9,7 +9,6 @@ from gym.spaces import Discrete, Box
 class MnistEnv(Env):
     def __init__(
             self,
-            seed=0,
             episode_len=None,
             no_images=None
     ):
@@ -23,7 +22,6 @@ class MnistEnv(Env):
         self.mnist = input_data.read_data_sets(mnist_path)
         self.np_random = np.random.RandomState()
-        self.np_random.seed(seed)

         self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
         self.action_space = Discrete(10)
@@ -50,6 +48,9 @@ class MnistEnv(Env):
         return self.state[0], rew, done, {}

+    def seed(self, seed=None):
+        self.np_random.seed(seed)
+
     def train_mode(self):
         self.dataset = self.mnist.train

View File

@@ -33,8 +33,7 @@ def test_fixed_sequence(alg, rnn):
     kwargs = learn_kwargs[alg]
     kwargs.update(common_kwargs)
-    episode_len = 5
-    env_fn = lambda: FixedSequenceEnv(10, episode_len=episode_len)
+    env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
     learn = lambda e: get_learn_function(alg)(
         env=e,
         network=rnn,

View File

@@ -41,7 +41,7 @@ def test_mnist(alg):
     learn = get_learn_function(alg)
     learn_fn = lambda e: learn(env=e, **learn_kwargs)
-    env_fn = lambda: MnistEnv(seed=0, episode_len=100)
+    env_fn = lambda: MnistEnv(episode_len=100)
     simple_test(env_fn, learn_fn, 0.6)

View File

@@ -44,7 +44,12 @@ def test_serialization(learn_fn, network_fn):
         # github issue: https://github.com/openai/baselines/issues/660
         return

-    env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
+    def make_env():
+        env = MnistEnv(episode_len=100)
+        env.seed(10)
+        return env
+
+    env = DummyVecEnv([make_env])
     ob = env.reset().copy()
     learn = get_learn_function(learn_fn)

View File

@@ -1,17 +1,19 @@
 import tensorflow as tf
 import numpy as np
-from gym.spaces import np_random
 from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

 N_TRIALS = 10000
 N_EPISODES = 100

 def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
+    def seeded_env_fn():
+        env = env_fn()
+        env.seed(0)
+        return env
+
     np.random.seed(0)
-    np_random.seed(0)
-    env = DummyVecEnv([env_fn])
+    env = DummyVecEnv([seeded_env_fn])
     with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
         tf.set_random_seed(0)