Merge branch 'master' of github.com:openai/baselines into internal

Author: Peter Zhokhov
Date:   2018-10-19 09:52:23 -07:00

7 changed files with 52 additions and 10 deletions

View File

@@ -53,7 +53,7 @@ class PolicyWithValue(object):
         # Calculate the neg log of our probability
         self.neglogp = self.pd.neglogp(self.action)
-        self.sess = sess
+        self.sess = sess or tf.get_default_session()

         if estimate_q:
             assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
             self.vf = self.vf[:,0]

     def _evaluate(self, variables, observation, **extra_feed):
-        sess = self.sess or tf.get_default_session()
+        sess = self.sess
         feed_dict = {self.X: adjust_shape(self.X, observation)}
         for inpt_name, data in extra_feed.items():
             if inpt_name in self.__dict__.keys():
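
Note: the two hunks above move session resolution from call time to construction time — `PolicyWithValue` now captures a session once in `__init__` (falling back to the default session) instead of looking one up on every `_evaluate`. A minimal standalone sketch of the pattern in plain TF1, not the baselines API (class and variable names are illustrative):

    import tensorflow as tf

    class Evaluator(object):
        def __init__(self, sess=None):
            # Resolve the session once, at construction time; fall back
            # to whatever session is the default when the object is built.
            self.sess = sess or tf.get_default_session()
            self.x = tf.placeholder(tf.float32, [])
            self.y = self.x * 2.0

        def evaluate(self, value):
            # Use the captured session instead of re-querying the default,
            # so calls still work when no default session is active.
            return self.sess.run(self.y, {self.x: value})

    sess = tf.Session()
    with sess.as_default():
        ev = Evaluator()     # captures `sess` via tf.get_default_session()
    print(ev.evaluate(3.0))  # 6.0, even outside the as_default() block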

View File

@@ -1,7 +1,6 @@
 import os.path as osp
 import numpy as np
 import tempfile
-import filelock
 from gym import Env
 from gym.spaces import Discrete, Box
@@ -14,6 +13,7 @@ class MnistEnv(Env):
             episode_len=None,
             no_images=None
     ):
+        import filelock
         from tensorflow.examples.tutorials.mnist import input_data
         # we could use a temporary directory for this with a context manager and
         # TemporaryDirectory, but then each test that uses mnist would re-download the data
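
Note: moving `import filelock` from module scope into `__init__` (the two hunks above) is the usual deferred-import pattern — the test module can be imported even when the optional dependency is missing, and the `ImportError` only fires if a `MnistEnv` is actually constructed. A small generic sketch of the same idea (hypothetical class, not baselines code):

    class LockedWriter(object):
        def __init__(self, path):
            # Deferred import: merely importing this module never requires
            # filelock; only constructing the class does.
            import filelock
            self._path = path
            self._lock = filelock.FileLock(path + '.lock')

        def write(self, data):
            # Serialize concurrent writers across processes via the lock file.
            with self._lock:
                with open(self._path, 'w') as f:
                    f.write(data)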

View File

@@ -1,4 +1,5 @@
 import os
+import gym
 import tempfile
 import pytest
 import tensorflow as tf
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
     if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
-        # github issue: https://github.com/openai/baselines/issues/194
+        # github issue: https://github.com/openai/baselines/issues/660
         return

     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
     np.testing.assert_allclose(std1, std2, atol=0.5)


+@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
+@pytest.mark.parametrize("network_fn", ['mlp'])
+def test_coexistence(learn_fn, network_fn):
+    '''
+    Test if more than one model can exist at a time
+    '''
+
+    if learn_fn == 'deepq':
+        # TODO enable multiple DQN models to be useable at the same time
+        # github issue https://github.com/openai/baselines/issues/656
+        return
+
+    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
+        # TODO make acktr work with recurrent policies
+        # and test
+        # github issue: https://github.com/openai/baselines/issues/660
+        return
+
+    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
+    learn = get_learn_function(learn_fn)
+
+    kwargs = {}
+    kwargs.update(network_kwargs[network_fn])
+    kwargs.update(learn_kwargs[learn_fn])
+
+    learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
+    make_session(make_default=True, graph=tf.Graph())
+    model1 = learn(seed=1)
+    make_session(make_default=True, graph=tf.Graph())
+    model2 = learn(seed=2)
+
+    model1.step(env.observation_space.sample())
+    model2.step(env.observation_space.sample())
+
+
 def _serialize_variables():
     sess = get_session()
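
Note: the new `test_coexistence` leans on `make_session(make_default=True, graph=tf.Graph())` to give each learner its own graph and default session, so their variables cannot collide. Roughly the same isolation in plain TF1, without the baselines helpers (a sketch; shapes and names are arbitrary):

    import numpy as np
    import tensorflow as tf

    def build_model(seed):
        # A fresh graph per model: variables with identical names can
        # coexist because each lives in a different graph.
        graph = tf.Graph()
        with graph.as_default():
            tf.set_random_seed(seed)
            x = tf.placeholder(tf.float32, [None, 4])
            w = tf.get_variable('w', [4, 2])
            y = tf.matmul(x, w)
            sess = tf.Session(graph=graph)
            sess.run(tf.global_variables_initializer())
        return sess, x, y

    sess1, x1, y1 = build_model(seed=1)
    sess2, x2, y2 = build_model(seed=2)

    obs = np.ones((1, 4), dtype=np.float32)
    print(sess1.run(y1, {x1: obs}))  # both models step independently
    print(sess2.run(y2, {x2: obs}))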

View File

@@ -2,4 +2,4 @@
 - Original paper: https://arxiv.org/abs/1509.02971
 - Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/
-- `python -m baselines.ddpg.main` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
+- `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.

View File

@@ -47,6 +47,9 @@ class ActWrapper(object):
         return self._act(*args, **kwargs)

     def step(self, observation, **kwargs):
+        # DQN doesn't use RNNs so we ignore states and masks
+        kwargs.pop('S', None)
+        kwargs.pop('M', None)
         return self._act([observation], **kwargs), None, None, None

     def save_act(self, path=None):
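
Note: context for the hunk above — the unified policy interface passes recurrent state and episode-done masks into `step` as `S` and `M`; DQN is feed-forward, so `ActWrapper.step` now drops those kwargs instead of raising a TypeError. An illustrative rollout loop showing the calling convention being accommodated (a sketch only, not the exact baselines runner):

    def rollout(model, env, nsteps):
        # Works for recurrent and feed-forward models alike: recurrent
        # policies consume and return state; feed-forward ones (such as
        # deepq's ActWrapper) simply ignore S and M.
        obs = env.reset()
        state, dones = None, False
        for _ in range(nsteps):
            actions, values, state, neglogp = model.step(obs, S=state, M=dones)
            obs, rewards, dones, infos = env.step(actions)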

View File

@@ -121,9 +121,11 @@ def build_env(args):
             env = retro_wrappers.wrap_deepmind_retro(env)
         else:
-            get_session(tf.ConfigProto(allow_soft_placement=True,
-                                       intra_op_parallelism_threads=1,
-                                       inter_op_parallelism_threads=1))
+            config = tf.ConfigProto(allow_soft_placement=True,
+                                    intra_op_parallelism_threads=1,
+                                    inter_op_parallelism_threads=1)
+            config.gpu_options.allow_growth = True
+            get_session(config=config)

             env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
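
Note: the `build_env` hunk above keeps the single-threaded session but builds the `ConfigProto` explicitly so it can also set `gpu_options.allow_growth`, which makes TensorFlow allocate GPU memory on demand rather than reserving it all at startup. A standalone TF1 equivalent of the resulting session setup:

    import tensorflow as tf

    config = tf.ConfigProto(allow_soft_placement=True,  # fall back to CPU if an op lacks a GPU kernel
                            intra_op_parallelism_threads=1,
                            inter_op_parallelism_threads=1)
    # Grow the GPU memory pool as needed so several processes
    # (or several models) can share one GPU.
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)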

View File

@@ -10,7 +10,8 @@ if sys.version_info.major != 3:
 extras = {
     'test': [
         'filelock',
-        'pytest'
+        'pytest',
+        'atari-py'
     ],
     'bullet': [
         'pybullet',
@@ -27,7 +28,7 @@ setup(name='baselines',
       packages=[package for package in find_packages()
                 if package.startswith('baselines')],
       install_requires=[
-          'gym[mujoco,atari,classic_control,robotics]',
+          'gym',
           'scipy',
           'tqdm',
           'joblib',
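
Note: with `install_requires` trimmed to plain `gym`, the heavyweight environment sets are no longer forced on every install; users pull in the gym extras they need themselves (for example `pip install 'gym[atari]'`), and the `test` extra gains `atari-py` so the test suite still has Atari environments available.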