From 8e56ddeac296deab3cc0adc79c84a5abb59d7a3a Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Wed, 24 Oct 2018 11:01:59 -0700 Subject: [PATCH] Multidiscrete action space compatibility for policy gradient-based methods (#677) * multidiscrete space compatibility * flake8 and syntax --- baselines/acktr/acktr.py | 10 +++++----- baselines/common/distributions.py | 7 ++++++- baselines/common/input.py | 16 ++++++++++++---- baselines/common/tests/envs/identity_env.py | 15 ++++++++++++++- baselines/common/tests/test_identity.py | 20 ++++++++++++++++++-- baselines/common/vec_env/dummy_vec_env.py | 3 +++ 6 files changed, 58 insertions(+), 13 deletions(-) diff --git a/baselines/acktr/acktr.py b/baselines/acktr/acktr.py index dcbe612..10ab32b 100644 --- a/baselines/acktr/acktr.py +++ b/baselines/acktr/acktr.py @@ -21,16 +21,16 @@ class Model(object): self.sess = sess = get_session() nbatch = nenvs * nsteps - A = tf.placeholder(ac_space.dtype, [nbatch,] + list(ac_space.shape)) + with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE): + self.model = step_model = policy(nenvs, 1, sess=sess) + self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess) + + A = train_model.pdtype.sample_placeholder([None]) ADV = tf.placeholder(tf.float32, [nbatch]) R = tf.placeholder(tf.float32, [nbatch]) PG_LR = tf.placeholder(tf.float32, []) VF_LR = tf.placeholder(tf.float32, []) - with tf.variable_scope('acktr_model', reuse=tf.AUTO_REUSE): - self.model = step_model = policy(nenvs, 1, sess=sess) - self.model2 = train_model = policy(nenvs*nsteps, nsteps, sess=sess) - neglogpac = train_model.pd.neglogp(A) self.logits = train_model.pi diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index 491b9ff..5b3e7be 100644 --- a/baselines/common/distributions.py +++ b/baselines/common/distributions.py @@ -39,7 +39,7 @@ class PdType(object): raise NotImplementedError def pdfromflat(self, flat): return self.pdclass()(flat) - def pdfromlatent(self, latent_vector): + def pdfromlatent(self, latent_vector, init_scale, init_bias): raise NotImplementedError def param_shape(self): raise NotImplementedError @@ -80,6 +80,11 @@ class MultiCategoricalPdType(PdType): return MultiCategoricalPd def pdfromflat(self, flat): return MultiCategoricalPd(self.ncats, flat) + + def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0): + pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias) + return self.pdfromflat(pdparam), pdparam + def param_shape(self): return [sum(self.ncats)] def sample_shape(self): diff --git a/baselines/common/input.py b/baselines/common/input.py index 7d51008..ebaf30a 100644 --- a/baselines/common/input.py +++ b/baselines/common/input.py @@ -1,5 +1,6 @@ +import numpy as np import tensorflow as tf -from gym.spaces import Discrete, Box +from gym.spaces import Discrete, Box, MultiDiscrete def observation_placeholder(ob_space, batch_size=None, name='Ob'): ''' @@ -20,10 +21,14 @@ def observation_placeholder(ob_space, batch_size=None, name='Ob'): tensorflow placeholder tensor ''' - assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box), \ + assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \ 'Can only deal with Discrete and Box observation spaces for now' - return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=ob_space.dtype, name=name) + dtype = ob_space.dtype + if dtype == np.int8: + dtype = np.uint8 + + return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name) def observation_input(ob_space, batch_size=None, name='Ob'): @@ -48,9 +53,12 @@ def encode_observation(ob_space, placeholder): ''' if isinstance(ob_space, Discrete): return tf.to_float(tf.one_hot(placeholder, ob_space.n)) - elif isinstance(ob_space, Box): return tf.to_float(placeholder) + elif isinstance(ob_space, MultiDiscrete): + placeholder = tf.cast(placeholder, tf.int32) + one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])] + return tf.concat(one_hots, axis=-1) else: raise NotImplementedError diff --git a/baselines/common/tests/envs/identity_env.py b/baselines/common/tests/envs/identity_env.py index 005d3ff..4429f04 100644 --- a/baselines/common/tests/envs/identity_env.py +++ b/baselines/common/tests/envs/identity_env.py @@ -1,7 +1,7 @@ import numpy as np from abc import abstractmethod from gym import Env -from gym.spaces import Discrete, Box +from gym.spaces import MultiDiscrete, Discrete, Box class IdentityEnv(Env): @@ -53,6 +53,19 @@ class DiscreteIdentityEnv(IdentityEnv): def _get_reward(self, actions): return 1 if self.state == actions else 0 +class MultiDiscreteIdentityEnv(IdentityEnv): + def __init__( + self, + dims, + episode_len=None, + ): + + self.action_space = MultiDiscrete(dims) + super().__init__(episode_len=episode_len) + + def _get_reward(self, actions): + return 1 if all(self.state == actions) else 0 + class BoxIdentityEnv(IdentityEnv): def __init__( diff --git a/baselines/common/tests/test_identity.py b/baselines/common/tests/test_identity.py index 0b3c46e..c950e5a 100644 --- a/baselines/common/tests/test_identity.py +++ b/baselines/common/tests/test_identity.py @@ -1,5 +1,5 @@ import pytest -from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv +from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv from baselines.run import get_learn_function from baselines.common.tests.util import simple_test @@ -21,6 +21,7 @@ learn_kwargs = { algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi'] +algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi'] algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi'] @pytest.mark.slow @@ -38,6 +39,21 @@ def test_discrete_identity(alg): env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100) simple_test(env_fn, learn_fn, 0.9) +@pytest.mark.slow +@pytest.mark.parametrize("alg", algos_multidisc) +def test_multidiscrete_identity(alg): + ''' + Test if the algorithm (with an mlp policy) + can learn an identity transformation (i.e. return observation as an action) + ''' + + kwargs = learn_kwargs[alg] + kwargs.update(common_kwargs) + + learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs) + env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100) + simple_test(env_fn, learn_fn, 0.9) + @pytest.mark.slow @pytest.mark.parametrize("alg", algos_cont) def test_continuous_identity(alg): @@ -55,5 +71,5 @@ def test_continuous_identity(alg): simple_test(env_fn, learn_fn, -0.1) if __name__ == '__main__': - test_continuous_identity('ddpg') + test_multidiscrete_identity('acktr') diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index 60db11d..2b4d2ba 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -20,8 +20,11 @@ class DummyVecEnv(VecEnv): env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) obs_space = env.observation_space + if isinstance(obs_space, spaces.MultiDiscrete): + obs_space.shape = obs_space.shape[0] self.keys, shapes, dtypes = obs_space_info(obs_space) + self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)