Merge branch 'master' of github.com:openai/baselines into internal

This commit is contained in:
Peter Zhokhov
2018-10-19 09:52:23 -07:00
7 changed files with 52 additions and 10 deletions

View File

@@ -53,7 +53,7 @@ class PolicyWithValue(object):
# Calculate the neg log of our probability
self.neglogp = self.pd.neglogp(self.action)
self.sess = sess
self.sess = sess or tf.get_default_session()
if estimate_q:
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
self.vf = self.vf[:,0]
def _evaluate(self, variables, observation, **extra_feed):
sess = self.sess or tf.get_default_session()
sess = self.sess
feed_dict = {self.X: adjust_shape(self.X, observation)}
for inpt_name, data in extra_feed.items():
if inpt_name in self.__dict__.keys():

View File

@@ -1,7 +1,6 @@
import os.path as osp
import numpy as np
import tempfile
import filelock
from gym import Env
from gym.spaces import Discrete, Box
@@ -14,6 +13,7 @@ class MnistEnv(Env):
episode_len=None,
no_images=None
):
import filelock
from tensorflow.examples.tutorials.mnist import input_data
# we could use temporary directory for this with a context manager and
# TemporaryDirecotry, but then each test that uses mnist would re-download the data

View File

@@ -1,4 +1,5 @@
import os
import gym
import tempfile
import pytest
import tensorflow as tf
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
# TODO make acktr work with recurrent policies
# and test
# github issue: https://github.com/openai/baselines/issues/194
# github issue: https://github.com/openai/baselines/issues/660
return
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
np.testing.assert_allclose(std1, std2, atol=0.5)
@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
@pytest.mark.parametrize("network_fn", ['mlp'])
def test_coexistence(learn_fn, network_fn):
'''
Test if more than one model can exist at a time
'''
if learn_fn == 'deepq':
# TODO enable multiple DQN models to be useable at the same time
# github issue https://github.com/openai/baselines/issues/656
return
if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
# TODO make acktr work with recurrent policies
# and test
# github issue: https://github.com/openai/baselines/issues/660
return
env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
learn = get_learn_function(learn_fn)
kwargs = {}
kwargs.update(network_kwargs[network_fn])
kwargs.update(learn_kwargs[learn_fn])
learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
make_session(make_default=True, graph=tf.Graph());
model1 = learn(seed=1)
make_session(make_default=True, graph=tf.Graph());
model2 = learn(seed=2)
model1.step(env.observation_space.sample())
model2.step(env.observation_space.sample())
def _serialize_variables():
sess = get_session()

View File

@@ -2,4 +2,4 @@
- Original paper: https://arxiv.org/abs/1509.02971
- Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/
- `python -m baselines.ddpg.main` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
- `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.

View File

@@ -47,6 +47,9 @@ class ActWrapper(object):
return self._act(*args, **kwargs)
def step(self, observation, **kwargs):
# DQN doesn't use RNNs so we ignore states and masks
kwargs.pop('S', None)
kwargs.pop('M', None)
return self._act([observation], **kwargs), None, None, None
def save_act(self, path=None):

View File

@@ -121,9 +121,11 @@ def build_env(args):
env = retro_wrappers.wrap_deepmind_retro(env)
else:
get_session(tf.ConfigProto(allow_soft_placement=True,
config = tf.ConfigProto(allow_soft_placement=True,
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1))
inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
get_session(config=config)
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)

View File

@@ -10,7 +10,8 @@ if sys.version_info.major != 3:
extras = {
'test': [
'filelock',
'pytest'
'pytest',
'atari-py'
],
'bullet': [
'pybullet',
@@ -27,7 +28,7 @@ setup(name='baselines',
packages=[package for package in find_packages()
if package.startswith('baselines')],
install_requires=[
'gym[mujoco,atari,classic_control,robotics]',
'gym',
'scipy',
'tqdm',
'joblib',