Merge branch 'master' of github.com:openai/baselines into internal
@@ -53,7 +53,7 @@ class PolicyWithValue(object):
 
         # Calculate the neg log of our probability
         self.neglogp = self.pd.neglogp(self.action)
-        self.sess = sess
+        self.sess = sess or tf.get_default_session()
 
         if estimate_q:
             assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -64,7 +64,7 @@ class PolicyWithValue(object):
             self.vf = self.vf[:,0]
 
     def _evaluate(self, variables, observation, **extra_feed):
-        sess = self.sess or tf.get_default_session()
+        sess = self.sess
         feed_dict = {self.X: adjust_shape(self.X, observation)}
         for inpt_name, data in extra_feed.items():
             if inpt_name in self.__dict__.keys():
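Together, these two hunks move the `tf.get_default_session()` fallback from evaluation time into the constructor: the policy resolves a concrete session once when it is built, and `_evaluate` simply reuses it. A minimal TF1-style sketch of the pattern; the `Evaluator` class and its members are illustrative, not the actual baselines code:

```python
import tensorflow as tf

class Evaluator(object):
    def __init__(self, inputs, outputs, sess=None):
        # Resolve the session once, at construction time: either the one passed
        # in explicitly or whatever session is the default right now.
        self.sess = sess or tf.get_default_session()
        self.inputs = inputs
        self.outputs = outputs

    def run(self, values):
        # Later calls reuse the captured session, so they keep working even if
        # another graph/session has become the default in the meantime.
        feed_dict = dict(zip(self.inputs, values))
        return self.sess.run(self.outputs, feed_dict=feed_dict)
```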
@@ -1,7 +1,6 @@
 import os.path as osp
 import numpy as np
 import tempfile
-import filelock
 from gym import Env
 from gym.spaces import Discrete, Box
 
@@ -14,6 +13,7 @@ class MnistEnv(Env):
             episode_len=None,
             no_images=None
     ):
+        import filelock
         from tensorflow.examples.tutorials.mnist import input_data
         # we could use temporary directory for this with a context manager and
         # TemporaryDirecotry, but then each test that uses mnist would re-download the data
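`filelock` becomes a lazy import: it is only needed inside `MnistEnv.__init__`, so tests that never touch the MNIST environment no longer need the package at module-import time. A sketch of the download pattern the import supports; the `_load_mnist` helper and the paths are illustrative, while `filelock.FileLock` and `input_data.read_data_sets` are the real APIs:

```python
import os.path as osp
import tempfile

def _load_mnist():
    # Lazy imports: only pulled in when the MNIST environment is actually used.
    import filelock
    from tensorflow.examples.tutorials.mnist import input_data

    # Keep the dataset in a fixed temp directory so repeated test runs reuse the
    # download, and hold a lock file so parallel workers don't race on it.
    mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data')
    with filelock.FileLock(mnist_path + '.lock'):
        return input_data.read_data_sets(mnist_path)
```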
@@ -1,4 +1,5 @@
 import os
+import gym
 import tempfile
 import pytest
 import tensorflow as tf
@@ -39,7 +40,7 @@ def test_serialization(learn_fn, network_fn):
     if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
-        # github issue: https://github.com/openai/baselines/issues/194
+        # github issue: https://github.com/openai/baselines/issues/660
         return
 
     env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
@@ -75,6 +76,41 @@ def test_serialization(learn_fn, network_fn):
     np.testing.assert_allclose(std1, std2, atol=0.5)
 
 
+@pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
+@pytest.mark.parametrize("network_fn", ['mlp'])
+def test_coexistence(learn_fn, network_fn):
+    '''
+    Test if more than one model can exist at a time
+    '''
+
+    if learn_fn == 'deepq':
+        # TODO enable multiple DQN models to be useable at the same time
+        # github issue https://github.com/openai/baselines/issues/656
+        return
+
+    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
+        # TODO make acktr work with recurrent policies
+        # and test
+        # github issue: https://github.com/openai/baselines/issues/660
+        return
+
+    env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
+    learn = get_learn_function(learn_fn)
+
+    kwargs = {}
+    kwargs.update(network_kwargs[network_fn])
+    kwargs.update(learn_kwargs[learn_fn])
+
+    learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
+    make_session(make_default=True, graph=tf.Graph());
+    model1 = learn(seed=1)
+    make_session(make_default=True, graph=tf.Graph());
+    model2 = learn(seed=2)
+
+    model1.step(env.observation_space.sample())
+    model2.step(env.observation_space.sample())
+
+
 
 def _serialize_variables():
     sess = get_session()
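The new `test_coexistence` builds two models in two fresh graphs, each with its own default session, and then steps both. That only works if each model holds on to the session it was built with (the policies change above) rather than looking up the default session at call time. The same property in a bare TF1 sketch; `build_model` is illustrative, not the baselines helpers:

```python
import numpy as np
import tensorflow as tf

def build_model(graph):
    # Each model lives in its own graph and keeps a handle to its own session.
    with graph.as_default():
        x = tf.placeholder(tf.float32, [None, 4], name='x')
        y = tf.layers.dense(x, 2)
        sess = tf.Session(graph=graph)
        sess.run(tf.global_variables_initializer())
    return sess, x, y

sess1, x1, y1 = build_model(tf.Graph())
sess2, x2, y2 = build_model(tf.Graph())

# Both models stay usable: each run targets the session captured at build time,
# independent of whichever graph happens to be the current default.
out1 = sess1.run(y1, {x1: np.zeros((1, 4))})
out2 = sess2.run(y2, {x2: np.zeros((1, 4))})
```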
@@ -2,4 +2,4 @@
 
 - Original paper: https://arxiv.org/abs/1509.02971
 - Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/
-- `python -m baselines.ddpg.main` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
+- `python -m baselines.run --alg=ddpg --env=HalfCheetah-v2 --num_timesteps=1e6` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
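With this change DDPG is launched through the common `baselines.run` entry point like the other algorithms; for instance, `python -m baselines.run --alg=ppo2 --env=HalfCheetah-v2` uses the same interface with a different `--alg`.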
@@ -47,6 +47,9 @@ class ActWrapper(object):
         return self._act(*args, **kwargs)
 
     def step(self, observation, **kwargs):
+        # DQN doesn't use RNNs so we ignore states and masks
+        kwargs.pop('S', None)
+        kwargs.pop('M', None)
         return self._act([observation], **kwargs), None, None, None
 
     def save_act(self, path=None):
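`step` gives DQN the call signature the shared code expects from every algorithm: observations in, `(actions, values, states, neglogp)` out, with optional recurrent state `S` and episode-done mask `M` keywords. DQN is not recurrent and its act function does not expose values or log-probabilities, so the keywords are popped and the extra outputs come back as `None`. Roughly how a caller drives that interface; the `rollout` helper below is a sketch of the calling convention, not any specific runner:

```python
def rollout(model, env, nsteps=1000):
    # `model` is anything exposing the unified step() interface; `env` is a
    # (vectorized) gym environment. Both are assumptions of this sketch.
    state = getattr(model, 'initial_state', None)  # None for non-recurrent models like DQN
    obs = env.reset()
    dones = [False] * getattr(env, 'num_envs', 1)
    for _ in range(nsteps):
        actions, values, state, neglogps = model.step(obs, S=state, M=dones)
        obs, rewards, dones, infos = env.step(actions)
    return obs
```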
@@ -121,9 +121,11 @@ def build_env(args):
             env = retro_wrappers.wrap_deepmind_retro(env)
 
     else:
-        get_session(tf.ConfigProto(allow_soft_placement=True,
+        config = tf.ConfigProto(allow_soft_placement=True,
                                 intra_op_parallelism_threads=1,
-                                inter_op_parallelism_threads=1))
+                                inter_op_parallelism_threads=1)
+        config.gpu_options.allow_growth = True
+        get_session(config=config)
 
         env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
 
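The functional change here is `config.gpu_options.allow_growth = True`: TensorFlow then allocates GPU memory on demand instead of reserving the whole device up front, which makes it practical to run several models or worker processes on one GPU.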
setup.py
@@ -10,7 +10,8 @@ if sys.version_info.major != 3:
 extras = {
     'test': [
         'filelock',
-        'pytest'
+        'pytest',
+        'atari-py'
     ],
     'bullet': [
         'pybullet',
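`atari-py` joins the `test` extra, so a development install that should be able to run the test suite would typically use the extras syntax, e.g. `pip install -e .[test]`.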
@@ -27,7 +28,7 @@ setup(name='baselines',
       packages=[package for package in find_packages()
                 if package.startswith('baselines')],
       install_requires=[
-          'gym[mujoco,atari,classic_control,robotics]',
+          'gym',
           'scipy',
           'tqdm',
           'joblib',
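The hard dependency is relaxed from `gym[mujoco,atari,classic_control,robotics]` to plain `gym`; the heavyweight environment families are no longer pulled in automatically and can be installed separately when needed, for example `pip install 'gym[atari]'`.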