diff --git a/README.md b/README.md
index b8214ee..197f01a 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,8 @@ Install baselines package
 ```bash
 pip install -e .
 ```
+### MuJoCo
+Some of the baselines examples use the [MuJoCo](http://www.mujoco.org) (multi-joint dynamics in contact) physics simulator, which is proprietary and requires binaries and a license (a temporary 30-day license can be obtained from [www.mujoco.org](http://www.mujoco.org)). Instructions on setting up MuJoCo can be found [here](https://github.com/openai/mujoco-py).
 
 ## Testing the installation
 All unit tests in baselines can be run using pytest runner:
diff --git a/baselines/a2c/a2c.py b/baselines/a2c/a2c.py
index c28fe65..f1de88a 100644
--- a/baselines/a2c/a2c.py
+++ b/baselines/a2c/a2c.py
@@ -131,7 +131,6 @@ class Runner(AbstractEnvRunner):
         return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values
 
 def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01, max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=100):
-    tf.reset_default_graph()
     set_global_seeds(seed)
 
     nenvs = env.num_envs
@@ -158,3 +157,4 @@ def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, e
             logger.record_tabular("explained_variance", float(ev))
             logger.dump_tabular()
     env.close()
+    return model
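With `a2c.learn` now returning the trained model (and the stray `tf.reset_default_graph()` call gone), a caller can keep using the policy after training. A minimal sketch, assuming an already-constructed vectorized `env`; the `model.step(obs)[0]` indexing mirrors the identity tests added later in this change:

```python
# Sketch only: `env` is assumed to be an existing VecEnv wrapping Atari-style frames.
from baselines.a2c.a2c import learn
from baselines.a2c.policies import CnnPolicy

model = learn(policy=CnnPolicy, env=env, seed=0, total_timesteps=int(1e6))

obs = env.reset()
actions = model.step(obs)[0]            # first element of step() is the sampled actions
obs, rewards, dones, infos = env.step(actions)
```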
diff --git a/baselines/a2c/policies.py b/baselines/a2c/policies.py
index 172a3ee..6fbbb14 100644
--- a/baselines/a2c/policies.py
+++ b/baselines/a2c/policies.py
@@ -2,6 +2,7 @@ import numpy as np
 import tensorflow as tf
 from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm
 from baselines.common.distributions import make_pdtype
+from baselines.common.input import observation_input
 
 def nature_cnn(unscaled_images, **conv_kwargs):
     """
@@ -19,14 +20,12 @@ def nature_cnn(unscaled_images, **conv_kwargs):
 class LnLstmPolicy(object):
     def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
         nenv = nbatch // nsteps
-        nh, nw, nc = ob_space.shape
-        ob_shape = (nbatch, nh, nw, nc)
-        X = tf.placeholder(tf.uint8, ob_shape) #obs
+        X, processed_x = observation_input(ob_space, nbatch)
         M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
         S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
         self.pdtype = make_pdtype(ac_space)
         with tf.variable_scope("model", reuse=reuse):
-            h = nature_cnn(X)
+            h = nature_cnn(processed_x)
             xs = batch_to_seq(h, nenv, nsteps)
             ms = batch_to_seq(M, nenv, nsteps)
             h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm)
@@ -56,11 +55,9 @@ class LstmPolicy(object):
 
     def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
         nenv = nbatch // nsteps
-
-        nh, nw, nc = ob_space.shape
-        ob_shape = (nbatch, nh, nw, nc)
         self.pdtype = make_pdtype(ac_space)
-        X = tf.placeholder(tf.uint8, ob_shape) #obs
+        X, processed_x = observation_input(ob_space, nbatch)
+
         M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
         S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
         with tf.variable_scope("model", reuse=reuse):
@@ -93,12 +90,10 @@ class LstmPolicy(object):
 
 class CnnPolicy(object):
 
     def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613
-        nh, nw, nc = ob_space.shape
-        ob_shape = (nbatch, nh, nw, nc)
         self.pdtype = make_pdtype(ac_space)
-        X = tf.placeholder(tf.uint8, ob_shape) #obs
+        X, processed_x = observation_input(ob_space, nbatch)
         with tf.variable_scope("model", reuse=reuse):
-            h = nature_cnn(X, **conv_kwargs)
+            h = nature_cnn(processed_x, **conv_kwargs)
             vf = fc(h, 'v', 1)[:,0]
             self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01)
 
@@ -120,15 +115,14 @@ class CnnPolicy(object):
 
 class MlpPolicy(object):
     def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613
-        ob_shape = (nbatch,) + ob_space.shape
         self.pdtype = make_pdtype(ac_space)
-        X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs
         with tf.variable_scope("model", reuse=reuse):
+            X, processed_x = observation_input(ob_space, nbatch)
             activ = tf.tanh
-            flatten = tf.layers.flatten
-            pi_h1 = activ(fc(flatten(X), 'pi_fc1', nh=64, init_scale=np.sqrt(2)))
+            processed_x = tf.layers.flatten(processed_x)
+            pi_h1 = activ(fc(processed_x, 'pi_fc1', nh=64, init_scale=np.sqrt(2)))
             pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2)))
-            vf_h1 = activ(fc(flatten(X), 'vf_fc1', nh=64, init_scale=np.sqrt(2)))
+            vf_h1 = activ(fc(processed_x, 'vf_fc1', nh=64, init_scale=np.sqrt(2)))
             vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2)))
             vf = fc(vf_h2, 'vf', 1)[:,0]
diff --git a/baselines/bench/benchmarks.py b/baselines/bench/benchmarks.py
index 69454a9..a5a35f8 100644
--- a/baselines/bench/benchmarks.py
+++ b/baselines/bench/benchmarks.py
@@ -9,6 +9,8 @@ _atariexpl7 = ['Freeway', 'Gravitar', 'MontezumaRevenge', 'Pitfall', 'PrivateEye
 _BENCHMARKS = []
 
 remove_version_re = re.compile(r'-v\d+$')
+
+
 def register_benchmark(benchmark):
     for b in _BENCHMARKS:
         if b['name'] == benchmark['name']:
@@ -138,3 +140,11 @@ register_benchmark({
     'tasks': [{'desc': _game, 'env_id': _game + _ATARI_SUFFIX, 'trials': 2, 'num_timesteps': int(10e6)} for _game in _atari50]
 })
 
+# HER DDPG
+
+register_benchmark({
+    'name': 'HerDdpg',
+    'description': 'Smoke-test only benchmark of HER',
+    'tasks': [{'trials': 1, 'env_id': 'FetchReach-v1'}]
+})
+
diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py
index b88d529..64df993 100644
--- a/baselines/common/cmd_util.py
+++ b/baselines/common/cmd_util.py
@@ -74,6 +74,7 @@ def mujoco_arg_parser():
     parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
     parser.add_argument('--seed', help='RNG seed', type=int, default=0)
     parser.add_argument('--num-timesteps', type=int, default=int(1e6))
+    parser.add_argument('--play', default=False, action='store_true')
     return parser
 
 def robotics_arg_parser():
diff --git a/baselines/common/identity_env.py b/baselines/common/identity_env.py
new file mode 100644
index 0000000..f07cd5b
--- /dev/null
+++ b/baselines/common/identity_env.py
@@ -0,0 +1,30 @@
+from gym import Env
+from gym.spaces import Discrete
+
+
+class IdentityEnv(Env):
+    def __init__(
+            self,
+            dim,
+            ep_length=100,
+    ):
+
+        self.action_space = Discrete(dim)
+        self.reset()
+
+    def reset(self):
+        self._choose_next_state()
+        self.observation_space = self.action_space
+
+        return self.state
+
+    def step(self, actions):
+        rew = self._get_reward(actions)
+        self._choose_next_state()
+        return self.state, rew, False, {}
+
+    def _choose_next_state(self):
+        self.state = self.action_space.sample()
+
+    def _get_reward(self, actions):
+        return 1 if self.state == actions else 0
diff --git a/baselines/common/input.py b/baselines/common/input.py
new file mode 100644
index 0000000..7fbf9fc
--- /dev/null
+++ b/baselines/common/input.py
@@ -0,0 +1,30 @@
+import tensorflow as tf
+from gym.spaces import Discrete, Box
+
+def observation_input(ob_space, batch_size=None, name='Ob'):
+    '''
+    Build observation input with encoding depending on the
+    observation space type
+    Params:
+
+    ob_space: observation space (should be one of gym.spaces)
+    batch_size: batch size for input (default is None, so that the resulting input placeholder can take tensors with any batch size)
+    name: tensorflow variable name for input placeholder
+
+    returns: tuple (input_placeholder, processed_input_tensor)
+    '''
+    if isinstance(ob_space, Discrete):
+        input_x = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name)
+        processed_x = tf.to_float(tf.one_hot(input_x, ob_space.n))
+        return input_x, processed_x
+
+    elif isinstance(ob_space, Box):
+        input_shape = (batch_size,) + ob_space.shape
+        input_x = tf.placeholder(shape=input_shape, dtype=ob_space.dtype, name=name)
+        processed_x = tf.to_float(input_x)
+        return input_x, processed_x
+
+    else:
+        raise NotImplementedError
+
+
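The new `observation_input` helper is what lets the shared policies accept non-image observation spaces. A rough usage sketch (variable names are placeholders, and a gym version exposing `Box.dtype` is assumed):

```python
import tensorflow as tf
from gym.spaces import Box, Discrete
from baselines.common.input import observation_input

# Discrete space: int32 placeholder, processed into a float32 one-hot encoding.
X_disc, processed_disc = observation_input(Discrete(10), batch_size=None)

# Box space: placeholder keeps the space's dtype; processed tensor is cast to float32.
X_box, processed_box = observation_input(Box(low=-1.0, high=1.0, shape=(4,)), batch_size=32)

with tf.Session() as sess:
    one_hot = sess.run(processed_disc, feed_dict={X_disc: [3, 7]})
    print(one_hot.shape)  # (2, 10)
```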
diff --git a/baselines/common/runners.py b/baselines/common/runners.py
index 33b4365..0a4b221 100644
--- a/baselines/common/runners.py
+++ b/baselines/common/runners.py
@@ -7,7 +7,7 @@ class AbstractEnvRunner(ABC):
         self.model = model
         nenv = env.num_envs
         self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape
-        self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=model.train_model.X.dtype.name)
+        self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=env.observation_space.dtype.name)
         self.obs[:] = env.reset()
         self.nsteps = nsteps
         self.states = model.initial_state
diff --git a/baselines/common/test_identity.py b/baselines/common/test_identity.py
new file mode 100644
index 0000000..a429e0c
--- /dev/null
+++ b/baselines/common/test_identity.py
@@ -0,0 +1,44 @@
+import pytest
+import tensorflow as tf
+import random
+import numpy as np
+from gym.spaces import np_random
+
+from baselines.a2c import a2c
+from baselines.ppo2 import ppo2
+from baselines.common.identity_env import IdentityEnv
+from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.ppo2.policies import MlpPolicy
+
+
+learn_func_list = [
+    lambda e: a2c.learn(policy=MlpPolicy, env=e, seed=0, total_timesteps=50000),
+    lambda e: ppo2.learn(policy=MlpPolicy, env=e, total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.01)
+]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("learn_func", learn_func_list)
+def test_identity(learn_func):
+    '''
+    Test if the algorithm (with a given policy)
+    can learn an identity transformation (i.e. return observation as an action)
+    '''
+    np.random.seed(0)
+    np_random.seed(0)
+    random.seed(0)
+
+    env = DummyVecEnv([lambda: IdentityEnv(10)])
+
+    with tf.Graph().as_default(), tf.Session().as_default():
+        tf.set_random_seed(0)
+        model = learn_func(env)
+
+        N_TRIALS = 1000
+        sum_rew = 0
+        obs = env.reset()
+        for i in range(N_TRIALS):
+            obs, rew, done, _ = env.step(model.step(obs)[0])
+            sum_rew += rew
+
+        assert sum_rew > 0.9 * N_TRIALS
diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py
index 9b2822b..fbc9fae 100644
--- a/baselines/common/tf_util.py
+++ b/baselines/common/tf_util.py
@@ -279,3 +279,27 @@ def display_var_info(vars):
         logger.info("   %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape)))
 
     logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))
+
+
+def get_available_gpus():
+    # recipe from here:
+    # https://stackoverflow.com/questions/38559755/how-to-get-current-available-gpus-in-tensorflow?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
+
+    from tensorflow.python.client import device_lib
+    local_device_protos = device_lib.list_local_devices()
+    return [x.name for x in local_device_protos if x.device_type == 'GPU']
+
+# ================================================================
+# Saving variables
+# ================================================================
+
+def load_state(fname):
+    saver = tf.train.Saver()
+    saver.restore(tf.get_default_session(), fname)
+
+def save_state(fname):
+    os.makedirs(os.path.dirname(fname), exist_ok=True)
+    saver = tf.train.Saver()
+    saver.save(tf.get_default_session(), fname)
+
+
diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py
index f966d29..d5851e1 100644
--- a/baselines/common/vec_env/dummy_vec_env.py
+++ b/baselines/common/vec_env/dummy_vec_env.py
@@ -11,18 +11,18 @@ class DummyVecEnv(VecEnv):
         shapes, dtypes = {}, {}
         self.keys = []
         obs_space = env.observation_space
+
         if isinstance(obs_space, spaces.Dict):
             assert isinstance(obs_space.spaces, OrderedDict)
-            for key, box in obs_space.spaces.items():
-                assert isinstance(box, spaces.Box)
-                shapes[key] = box.shape
-                dtypes[key] = box.dtype
-                self.keys.append(key)
+            subspaces = obs_space.spaces
         else:
-            box = obs_space
-            assert isinstance(box, spaces.Box)
-            self.keys = [None]
-            shapes, dtypes = { None: box.shape }, { None: box.dtype }
+            subspaces = {None: obs_space}
+
+        for key, box in subspaces.items():
+            shapes[key] = box.shape
+            dtypes[key] = box.dtype
+            self.keys.append(key)
+
         self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
         self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
@@ -50,6 +50,9 @@ class DummyVecEnv(VecEnv):
     def close(self):
         return
 
+    def render(self):
+        return [e.render() for e in self.envs]
+
     def _save_obs(self, e, obs):
         for k in self.keys:
             if k is None:
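Note that `load_state`/`save_state` now live in `baselines.common.tf_util` (and are removed from `baselines.deepq.utils` further down). Downstream code would update its imports along these lines (sketch only):

```python
# Before (deepq-specific location, removed in this change):
# from baselines.deepq.utils import load_state, save_state

# After (shared location, as used by baselines/deepq/simple.py below):
from baselines.common.tf_util import load_state, save_state
```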
diff --git a/baselines/deepq/experiments/custom_cartpole.py b/baselines/deepq/experiments/custom_cartpole.py
index 4896904..b5a381a 100644
--- a/baselines/deepq/experiments/custom_cartpole.py
+++ b/baselines/deepq/experiments/custom_cartpole.py
@@ -9,7 +9,7 @@ import baselines.common.tf_util as U
 from baselines import logger
 from baselines import deepq
 from baselines.deepq.replay_buffer import ReplayBuffer
-from baselines.deepq.utils import BatchInput
+from baselines.deepq.utils import ObservationInput
 from baselines.common.schedules import LinearSchedule
 
 
@@ -28,7 +28,7 @@ if __name__ == '__main__':
     env = gym.make("CartPole-v0")
     # Create all the functions necessary to train the model
     act, train, update_target, debug = deepq.build_train(
-        make_obs_ph=lambda name: BatchInput(env.observation_space.shape, name=name),
+        make_obs_ph=lambda name: ObservationInput(env.observation_space, name=name),
         q_func=model,
         num_actions=env.action_space.n,
         optimizer=tf.train.AdamOptimizer(learning_rate=5e-4),
diff --git a/baselines/deepq/experiments/run_atari.py b/baselines/deepq/experiments/run_atari.py
index 9836268..b6b427b 100644
--- a/baselines/deepq/experiments/run_atari.py
+++ b/baselines/deepq/experiments/run_atari.py
@@ -14,6 +14,9 @@ def main():
     parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
     parser.add_argument('--dueling', type=int, default=1)
     parser.add_argument('--num-timesteps', type=int, default=int(10e6))
+    parser.add_argument('--checkpoint-freq', type=int, default=10000)
+    parser.add_argument('--checkpoint-path', type=str, default=None)
+
     args = parser.parse_args()
     logger.configure()
     set_global_seeds(args.seed)
@@ -39,7 +42,9 @@ def main():
         target_network_update_freq=1000,
         gamma=0.99,
         prioritized_replay=bool(args.prioritized),
-        prioritized_replay_alpha=args.prioritized_replay_alpha
+        prioritized_replay_alpha=args.prioritized_replay_alpha,
+        checkpoint_freq=args.checkpoint_freq,
+        checkpoint_path=args.checkpoint_path,
     )
     env.close()
diff --git a/baselines/deepq/simple.py b/baselines/deepq/simple.py
index f96cdc6..4bad145 100644
--- a/baselines/deepq/simple.py
+++ b/baselines/deepq/simple.py
@@ -6,13 +6,15 @@ import zipfile
 import cloudpickle
 import numpy as np
-import gym
 
 import baselines.common.tf_util as U
+from baselines.common.tf_util import load_state, save_state
 from baselines import logger
 from baselines.common.schedules import LinearSchedule
+from baselines.common.input import observation_input
+
 from baselines import deepq
 from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
-from baselines.deepq.utils import BatchInput, load_state, save_state
+from baselines.deepq.utils import ObservationInput
 
 
 class ActWrapper(object):
@@ -88,6 +90,7 @@ def learn(env,
           batch_size=32,
           print_freq=100,
          checkpoint_freq=10000,
+          checkpoint_path=None,
          learning_starts=1000,
          gamma=1.0,
          target_network_update_freq=500,
@@ -170,9 +173,9 @@ def learn(env,
     # capture the shape outside the closure so that the env object is not serialized
     # by cloudpickle when serializing make_obs_ph
-    observation_space_shape = env.observation_space.shape
+
     def make_obs_ph(name):
-        return BatchInput(observation_space_shape, name=name)
+        return ObservationInput(env.observation_space, name=name)
 
     act, train, update_target, debug = deepq.build_train(
         make_obs_ph=make_obs_ph,
@@ -216,9 +219,17 @@ def learn(env,
     saved_mean_reward = None
     obs = env.reset()
     reset = True
+
     with tempfile.TemporaryDirectory() as td:
-        model_saved = False
+        td = checkpoint_path or td
+        model_file = os.path.join(td, "model")
+        model_saved = False
+        if tf.train.latest_checkpoint(td) is not None:
+            load_state(model_file)
+            logger.log('Loaded model from {}'.format(model_file))
+            model_saved = True
+
         for t in range(max_timesteps):
             if callback is not None:
                 if callback(locals(), globals()):
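Together with the new `--checkpoint-freq`/`--checkpoint-path` flags in `run_atari.py`, this lets a DQN run resume from the latest checkpoint in a directory. An illustrative call (only the checkpoint arguments come from this change; `env`, `q_func`, and the path are placeholders):

```python
from baselines import deepq

# If /tmp/pong_model already contains a checkpoint, learn() reloads it before
# continuing training; otherwise it starts fresh and saves there periodically.
act = deepq.learn(
    env,
    q_func=q_func,
    checkpoint_freq=10000,
    checkpoint_path='/tmp/pong_model',
)
```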
diff --git a/baselines/deepq/test_identity.py b/baselines/deepq/test_identity.py
new file mode 100644
index 0000000..ef57e70
--- /dev/null
+++ b/baselines/deepq/test_identity.py
@@ -0,0 +1,43 @@
+import tensorflow as tf
+import random
+
+from baselines import deepq
+from baselines.common.identity_env import IdentityEnv
+
+
+def test_identity():
+
+    with tf.Graph().as_default():
+        env = IdentityEnv(10)
+        random.seed(0)
+
+        tf.set_random_seed(0)
+
+        param_noise = False
+        model = deepq.models.mlp([32])
+        act = deepq.learn(
+            env,
+            q_func=model,
+            lr=1e-3,
+            max_timesteps=10000,
+            buffer_size=50000,
+            exploration_fraction=0.1,
+            exploration_final_eps=0.02,
+            print_freq=10,
+            param_noise=param_noise,
+        )
+
+        tf.set_random_seed(0)
+
+        N_TRIALS = 1000
+        sum_rew = 0
+        obs = env.reset()
+        for i in range(N_TRIALS):
+            obs, rew, done, _ = env.step(act([obs]))
+            sum_rew += rew
+
+        assert sum_rew > 0.9 * N_TRIALS
+
+
+if __name__ == '__main__':
+    test_identity()
diff --git a/baselines/deepq/utils.py b/baselines/deepq/utils.py
index 4f0e0c3..90b932e 100644
--- a/baselines/deepq/utils.py
+++ b/baselines/deepq/utils.py
@@ -1,24 +1,12 @@
-import os
+from baselines.common.input import observation_input
 
 import tensorflow as tf
 
-# ================================================================
-# Saving variables
-# ================================================================
-
-def load_state(fname):
-    saver = tf.train.Saver()
-    saver.restore(tf.get_default_session(), fname)
-
-def save_state(fname):
-    os.makedirs(os.path.dirname(fname), exist_ok=True)
-    saver = tf.train.Saver()
-    saver.save(tf.get_default_session(), fname)
-
 # ================================================================
 # Placeholders
 # ================================================================
 
+
 class TfInput(object):
     def __init__(self, name="(unnamed)"):
         """Generalized Tensorflow placeholder. The main differences are:
@@ -50,20 +38,6 @@ class PlaceholderTfInput(TfInput):
     def make_feed_dict(self, data):
         return {self._placeholder: data}
 
-class BatchInput(PlaceholderTfInput):
-    def __init__(self, shape, dtype=tf.float32, name=None):
-        """Creates a placeholder for a batch of tensors of a given shape and dtype
-
-        Parameters
-        ----------
-        shape: [int]
-            shape of a single elemenet of the batch
-        dtype: tf.dtype
-            number representation used for tensor contents
-        name: str
-            name of the underlying placeholder
-        """
-        super().__init__(tf.placeholder(dtype, [None] + list(shape), name=name))
 
 class Uint8Input(PlaceholderTfInput):
     def __init__(self, shape, name=None):
@@ -85,4 +59,25 @@ class Uint8Input(PlaceholderTfInput):
         self._output = tf.cast(super().get(), tf.float32) / 255.0
 
     def get(self):
-        return self._output
\ No newline at end of file
+        return self._output
+
+
+class ObservationInput(PlaceholderTfInput):
+    def __init__(self, observation_space, name=None):
+        """Creates an input placeholder tailored to a specific observation space
+
+        Parameters
+        ----------
+
+        observation_space:
+                observation space of the environment. Should be one of the gym.spaces types
+        name: str
+                tensorflow name of the underlying placeholder
+        """
+        inpt, self.processed_inpt = observation_input(observation_space, name=name)
+        super().__init__(inpt)
+
+    def get(self):
+        return self.processed_inpt
+
+
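`ObservationInput` replaces `BatchInput` as the standard way to build the observation placeholder: instead of a raw shape it takes the observation space itself, so Discrete spaces work as well as Box images. The migration mirrors the `custom_cartpole.py` change above (sketch; `env` is assumed to exist):

```python
from baselines.deepq.utils import ObservationInput

# Before: make_obs_ph = lambda name: BatchInput(env.observation_space.shape, name=name)
# After: pass the space itself and let observation_input pick dtype and encoding.
make_obs_ph = lambda name: ObservationInput(env.observation_space, name=name)
```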
diff --git a/baselines/her/experiment/config.py b/baselines/her/experiment/config.py
index d64211b..cf29ca5 100644
--- a/baselines/her/experiment/config.py
+++ b/baselines/her/experiment/config.py
@@ -1,7 +1,4 @@
-from copy import deepcopy
 import numpy as np
-import json
-import os
 import gym
 
 from baselines import logger
@@ -10,7 +7,7 @@ from baselines.her.her import make_sample_her_transitions
 
 
 DEFAULT_ENV_PARAMS = {
-    'FetchReach-v0': {
+    'FetchReach-v1': {
         'n_cycles': 10,
     },
 }
@@ -51,6 +48,8 @@ DEFAULT_PARAMS = {
 
 CACHED_ENVS = {}
+
+
 def cached_make_env(make_env):
     """
     Only creates a new environment from the provided function if one has not yet already been
@@ -68,6 +67,7 @@ def prepare_params(kwargs):
     ddpg_params = dict()
 
     env_name = kwargs['env_name']
+
     def make_env():
         return gym.make(env_name)
     kwargs['make_env'] = make_env
@@ -75,7 +75,7 @@
     assert hasattr(tmp_env, '_max_episode_steps')
     kwargs['T'] = tmp_env._max_episode_steps
     tmp_env.reset()
-    kwargs['max_u'] = np.array(kwargs['max_u']) if type(kwargs['max_u']) == list else kwargs['max_u']
+    kwargs['max_u'] = np.array(kwargs['max_u']) if isinstance(kwargs['max_u'], list) else kwargs['max_u']
     kwargs['gamma'] = 1. - 1. / kwargs['T']
     if 'lr' in kwargs:
         kwargs['pi_lr'] = kwargs['lr']
@@ -83,7 +83,7 @@
         del kwargs['lr']
 
     for name in ['buffer_size', 'hidden', 'layers', 'network_class',
-                 'polyak', 
+                 'polyak',
                  'batch_size', 'Q_lr', 'pi_lr',
                  'norm_eps', 'norm_clip', 'max_u',
                  'action_l2', 'clip_obs', 'scope', 'relative_goals']:
@@ -103,6 +103,7 @@ def log_params(params, logger=logger):
 
 def configure_her(params):
     env = cached_make_env(params['make_env'])
     env.reset()
+
     def reward_fun(ag_2, g, info):  # vectorized
         return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info)
diff --git a/baselines/her/experiment/train.py b/baselines/her/experiment/train.py
index 60d8d1a..aeaf1c5 100644
--- a/baselines/her/experiment/train.py
+++ b/baselines/her/experiment/train.py
@@ -13,6 +13,8 @@ import baselines.her.experiment.config as config
 from baselines.her.rollout import RolloutWorker
 from baselines.her.util import mpi_fork
 
+from subprocess import CalledProcessError
+
 
 def mpi_average(value):
     if value == []:
@@ -81,12 +83,17 @@ def train(policy, rollout_worker, evaluator,
 
 
 def launch(
-    env_name, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return,
+    env, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return,
     override_params={}, save_policies=True
 ):
     # Fork for multi-CPU MPI implementation.
     if num_cpu > 1:
-        whoami = mpi_fork(num_cpu)
+        try:
+            whoami = mpi_fork(num_cpu, ['--bind-to', 'core'])
+        except CalledProcessError:
+            # fancy version of mpi call failed, try simple version
+            whoami = mpi_fork(num_cpu)
+
         if whoami == 'parent':
             sys.exit(0)
         import baselines.common.tf_util as U
@@ -109,10 +116,10 @@ def launch(
     # Prepare params.
     params = config.DEFAULT_PARAMS
-    params['env_name'] = env_name
+    params['env_name'] = env
     params['replay_strategy'] = replay_strategy
-    if env_name in config.DEFAULT_ENV_PARAMS:
-        params.update(config.DEFAULT_ENV_PARAMS[env_name])  # merge env-specific parameters in
+    if env in config.DEFAULT_ENV_PARAMS:
+        params.update(config.DEFAULT_ENV_PARAMS[env])  # merge env-specific parameters in
     params.update(**override_params)  # makes it possible to override any parameter
     with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
         json.dump(params, f)
@@ -126,7 +133,7 @@ def launch(
             'You are running HER with just a single MPI worker. This will work, but the ' +
             'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
             'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
-            'are looking to reproduce those results, be aware of this. Please also refer to ' + 
+            'are looking to reproduce those results, be aware of this. Please also refer to ' +
             'https://github.com/openai/baselines/issues/314 for further details.')
         logger.warn('****************')
         logger.warn()
@@ -168,7 +175,7 @@ def launch(
 
 
 @click.command()
-@click.option('--env_name', type=str, default='FetchReach-v0', help='the name of the OpenAI Gym environment that you want to train on')
+@click.option('--env', type=str, default='FetchReach-v1', help='the name of the OpenAI Gym environment that you want to train on')
 @click.option('--logdir', type=str, default=None, help='the path to where logs and policy pickles should go. If not specified, creates a folder in /tmp/')
 @click.option('--n_epochs', type=int, default=50, help='the number of training epochs to run')
 @click.option('--num_cpu', type=int, default=1, help='the number of CPU cores to use (using MPI)')
diff --git a/baselines/her/util.py b/baselines/her/util.py
index d79a776..d637aa6 100644
--- a/baselines/her/util.py
+++ b/baselines/her/util.py
@@ -58,12 +58,12 @@ def nn(input, layers_sizes, reuse=None, flatten=False, name=""):
     """Creates a simple neural network
     """
     for i, size in enumerate(layers_sizes):
-        activation = tf.nn.relu if i < len(layers_sizes)-1 else None
+        activation = tf.nn.relu if i < len(layers_sizes) - 1 else None
         input = tf.layers.dense(inputs=input,
                                 units=size,
                                 kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                 reuse=reuse,
-                                name=name+'_'+str(i))
+                                name=name + '_' + str(i))
         if activation:
             input = activation(input)
     if flatten:
@@ -85,7 +85,7 @@ def install_mpi_excepthook():
     sys.excepthook = new_hook
 
 
-def mpi_fork(n):
+def mpi_fork(n, extra_mpi_args=[]):
     """Re-launches the current script with workers
     Returns "parent" for original parent, "child" for MPI children
     """
@@ -99,14 +99,10 @@ def mpi_fork(n):
             IN_MPI="1"
         )
         # "-bind-to core" is crucial for good performance
-        args = [
-            "mpirun",
-            "-np",
-            str(n),
-            "-bind-to",
-            "core",
-            sys.executable
-        ]
+        args = ["mpirun", "-np", str(n)] + \
+            extra_mpi_args + \
+            [sys.executable]
+
         args += sys.argv
         subprocess.check_call(args, env=env)
         return "parent"
@@ -140,5 +136,5 @@ def reshape_for_broadcasting(source, target):
     before broadcasting it with MPI.
""" dim = len(target.get_shape()) - shape = ([1] * (dim-1)) + [-1] + shape = ([1] * (dim - 1)) + [-1] return tf.reshape(tf.cast(source, target.dtype), shape) diff --git a/baselines/ppo2/policies.py b/baselines/ppo2/policies.py index 172a3ee..6fbbb14 100644 --- a/baselines/ppo2/policies.py +++ b/baselines/ppo2/policies.py @@ -2,6 +2,7 @@ import numpy as np import tensorflow as tf from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm from baselines.common.distributions import make_pdtype +from baselines.common.input import observation_input def nature_cnn(unscaled_images, **conv_kwargs): """ @@ -19,14 +20,12 @@ def nature_cnn(unscaled_images, **conv_kwargs): class LnLstmPolicy(object): def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False): nenv = nbatch // nsteps - nh, nw, nc = ob_space.shape - ob_shape = (nbatch, nh, nw, nc) - X = tf.placeholder(tf.uint8, ob_shape) #obs + X, processed_x = observation_input(ob_space, nbatch) M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states self.pdtype = make_pdtype(ac_space) with tf.variable_scope("model", reuse=reuse): - h = nature_cnn(X) + h = nature_cnn(processed_x) xs = batch_to_seq(h, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps) h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) @@ -56,11 +55,9 @@ class LstmPolicy(object): def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False): nenv = nbatch // nsteps - - nh, nw, nc = ob_space.shape - ob_shape = (nbatch, nh, nw, nc) self.pdtype = make_pdtype(ac_space) - X = tf.placeholder(tf.uint8, ob_shape) #obs + X, processed_x = observation_input(ob_space, nbatch) + M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states with tf.variable_scope("model", reuse=reuse): @@ -93,12 +90,10 @@ class LstmPolicy(object): class CnnPolicy(object): def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613 - nh, nw, nc = ob_space.shape - ob_shape = (nbatch, nh, nw, nc) self.pdtype = make_pdtype(ac_space) - X = tf.placeholder(tf.uint8, ob_shape) #obs + X, processed_x = observation_input(ob_space, nbatch) with tf.variable_scope("model", reuse=reuse): - h = nature_cnn(X, **conv_kwargs) + h = nature_cnn(processed_x, **conv_kwargs) vf = fc(h, 'v', 1)[:,0] self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01) @@ -120,15 +115,14 @@ class CnnPolicy(object): class MlpPolicy(object): def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 - ob_shape = (nbatch,) + ob_space.shape self.pdtype = make_pdtype(ac_space) - X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs with tf.variable_scope("model", reuse=reuse): + X, processed_x = observation_input(ob_space, nbatch) activ = tf.tanh - flatten = tf.layers.flatten - pi_h1 = activ(fc(flatten(X), 'pi_fc1', nh=64, init_scale=np.sqrt(2))) + processed_x = tf.layers.flatten(processed_x) + pi_h1 = activ(fc(processed_x, 'pi_fc1', nh=64, init_scale=np.sqrt(2))) pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) - vf_h1 = activ(fc(flatten(X), 'vf_fc1', nh=64, init_scale=np.sqrt(2))) + vf_h1 = activ(fc(processed_x, 'vf_fc1', nh=64, init_scale=np.sqrt(2))) vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) vf = fc(vf_h2, 'vf', 1)[:,0] diff --git a/baselines/ppo2/ppo2.py b/baselines/ppo2/ppo2.py index 6acdae0..fd34f52 100644 --- 
+++ b/baselines/ppo2/ppo2.py
@@ -236,6 +236,7 @@ def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr,
             print('Saving to', savepath)
             model.save(savepath)
     env.close()
+    return model
 
 def safemean(xs):
     return np.nan if len(xs) == 0 else np.mean(xs)
diff --git a/baselines/ppo2/run_mujoco.py b/baselines/ppo2/run_mujoco.py
index 56fd4d9..282aa3f 100644
--- a/baselines/ppo2/run_mujoco.py
+++ b/baselines/ppo2/run_mujoco.py
@@ -1,8 +1,9 @@
 #!/usr/bin/env python3
-import argparse
+import numpy as np
 from baselines.common.cmd_util import mujoco_arg_parser
 from baselines import bench, logger
 
+
 def train(env_id, num_timesteps, seed):
     from baselines.common import set_global_seeds
     from baselines.common.vec_env.vec_normalize import VecNormalize
@@ -16,27 +17,40 @@ def train(env_id, num_timesteps, seed):
                             intra_op_parallelism_threads=ncpu,
                             inter_op_parallelism_threads=ncpu)
     tf.Session(config=config).__enter__()
+
     def make_env():
         env = gym.make(env_id)
-        env = bench.Monitor(env, logger.get_dir())
+        env = bench.Monitor(env, logger.get_dir(), allow_early_resets=True)
         return env
+
     env = DummyVecEnv([make_env])
     env = VecNormalize(env)
 
     set_global_seeds(seed)
     policy = MlpPolicy
-    ppo2.learn(policy=policy, env=env, nsteps=2048, nminibatches=32,
-        lam=0.95, gamma=0.99, noptepochs=10, log_interval=1,
-        ent_coef=0.0,
-        lr=3e-4,
-        cliprange=0.2,
-        total_timesteps=num_timesteps)
+    model = ppo2.learn(policy=policy, env=env, nsteps=2048, nminibatches=32,
+        lam=0.95, gamma=0.99, noptepochs=10, log_interval=1,
+        ent_coef=0.0,
+        lr=3e-4,
+        cliprange=0.2,
+        total_timesteps=num_timesteps)
+
+    return model, env
+
 
 def main():
     args = mujoco_arg_parser().parse_args()
     logger.configure()
-    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
+    model, env = train(args.env, num_timesteps=args.num_timesteps, seed=args.seed)
+
+    if args.play:
+        logger.log("Running trained model")
+        obs = np.zeros((env.num_envs,) + env.observation_space.shape)
+        obs[:] = env.reset()
+        while True:
+            actions = model.step(obs)[0]
+            obs[:] = env.step(actions)[0]
+            env.render()
 
 
 if __name__ == '__main__':