import tensorflow as tf
from baselines.common import tf_util
from baselines.a2c.utils import fc
from baselines.common.distributions import make_pdtype
from baselines.common.input import observation_placeholder, encode_observation
from baselines.common.tf_util import adjust_shape
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common.models import get_network_builder

import gym

class PolicyWithValue(object):
    """
    Encapsulates fields and methods for RL policy and value function estimation with shared parameters
    """

    def __init__(self, env, observations, latent, estimate_q=False, vf_latent=None, sess=None, **tensors):
        """
        Parameters:
        ----------
        env             RL environment

        observations    tensorflow placeholder in which the observations will be fed

        latent          latent state from which policy distribution parameters should be inferred

        vf_latent       latent state from which value function should be inferred (if None, then latent is used)

        sess            tensorflow session to run calculations in (if None, default session is used)

        **tensors       tensorflow tensors for additional attributes such as state or mask

        """
        self.X = observations
        self.state = tf.constant([])
        self.initial_state = None
        self.__dict__.update(tensors)

        vf_latent = vf_latent if vf_latent is not None else latent

        vf_latent = tf.layers.flatten(vf_latent)
        latent = tf.layers.flatten(latent)

        # Based on the action space, select the appropriate probability distribution type
        self.pdtype = make_pdtype(env.action_space)

        self.pd, self.pi = self.pdtype.pdfromlatent(latent, init_scale=0.01)

        # Take an action
        self.action = self.pd.sample()

        # Calculate the neg log of our probability
        self.neglogp = self.pd.neglogp(self.action)
        self.sess = sess or tf.get_default_session()

        if estimate_q:
            assert isinstance(env.action_space, gym.spaces.Discrete)
            self.q = fc(vf_latent, 'q', env.action_space.n)
            self.vf = self.q
        else:
            self.vf = fc(vf_latent, 'vf', 1)
            self.vf = self.vf[:, 0]
    def _evaluate(self, variables, observation, **extra_feed):
        # Run the requested tensors in the policy's session, feeding the observation
        # and any extra placeholders (e.g. recurrent state or mask) passed by name.
        sess = self.sess
        feed_dict = {self.X: adjust_shape(self.X, observation)}
        for inpt_name, data in extra_feed.items():
            if inpt_name in self.__dict__.keys():
                inpt = self.__dict__[inpt_name]
                if isinstance(inpt, tf.Tensor) and inpt._op.type == 'Placeholder':
                    feed_dict[inpt] = adjust_shape(inpt, data)

        return sess.run(variables, feed_dict)
    def step(self, observation, **extra_feed):
        """
        Compute next action(s) given the observation(s)

        Parameters:
        ----------

        observation     observation data (either single or a batch)

        **extra_feed    additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)

        Returns:
        -------
        (action, value estimate, next state, negative log likelihood of the action under current policy parameters) tuple
        """

        a, v, state, neglogp = self._evaluate([self.action, self.vf, self.state, self.neglogp], observation, **extra_feed)
        if state.size == 0:
            state = None
        return a, v, state, neglogp
    def value(self, ob, *args, **kwargs):
        """
        Compute value estimate(s) given the observation(s)

        Parameters:
        ----------

        ob              observation data (either single or a batch)

        **kwargs        additional data such as state or mask (names of the arguments should match the ones in constructor, see __init__)

        Returns:
        -------
        value estimate
        """
        return self._evaluate(self.vf, ob, *args, **kwargs)
    def save(self, save_path):
        tf_util.save_state(save_path, sess=self.sess)

    def load(self, load_path):
        tf_util.load_state(load_path, sess=self.sess)

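# Usage sketch for an already-constructed policy (comment only, so module
# behaviour is unchanged). It assumes `policy` was built via build_policy below
# and `obs` is an observation (or batch of observations) from the environment;
# both names and the checkpoint path are illustrative.
#
#     actions, values, states, neglogps = policy.step(obs)
#     values_only = policy.value(obs)
#     policy.save('/tmp/policy_checkpoint')
#     policy.load('/tmp/policy_checkpoint')

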
def build_policy(env, policy_network, value_network=None, normalize_observations=False, estimate_q=False, **policy_kwargs):
    """
    Returns a function that builds a PolicyWithValue for the given environment.

    policy_network can be a registered network name (str) or a callable network
    builder; value_network can be None or 'shared' (value function reuses the
    policy latent), 'copy' (a separate network with the policy architecture),
    or a callable.
    """
    if isinstance(policy_network, str):
        network_type = policy_network
        policy_network = get_network_builder(network_type)(**policy_kwargs)
    def policy_fn(nbatch=None, nsteps=None, sess=None, observ_placeholder=None):
        ob_space = env.observation_space

        X = observ_placeholder if observ_placeholder is not None else observation_placeholder(ob_space, batch_size=nbatch)

        extra_tensors = {}

        if normalize_observations and X.dtype == tf.float32:
            encoded_x, rms = _normalize_clip_observation(X)
            extra_tensors['rms'] = rms
        else:
            encoded_x = X

        encoded_x = encode_observation(ob_space, encoded_x)

        with tf.variable_scope('pi', reuse=tf.AUTO_REUSE):
            policy_latent = policy_network(encoded_x)
            if isinstance(policy_latent, tuple):
                policy_latent, recurrent_tensors = policy_latent

                if recurrent_tensors is not None:
                    # recurrent architecture, need a few more steps
                    nenv = nbatch // nsteps
                    assert nenv > 0, 'Bad input for recurrent policy: batch size {} smaller than nsteps {}'.format(nbatch, nsteps)
                    policy_latent, recurrent_tensors = policy_network(encoded_x, nenv)
                    extra_tensors.update(recurrent_tensors)

        _v_net = value_network

        if _v_net is None or _v_net == 'shared':
            vf_latent = policy_latent
        else:
            if _v_net == 'copy':
                _v_net = policy_network
            else:
                assert callable(_v_net)

            with tf.variable_scope('vf', reuse=tf.AUTO_REUSE):
                # TODO recurrent architectures are not supported with value_network=copy yet
                vf_latent = _v_net(encoded_x)

        policy = PolicyWithValue(
            env=env,
            observations=X,
            latent=policy_latent,
            vf_latent=vf_latent,
            sess=sess,
            estimate_q=estimate_q,
            **extra_tensors
        )
        return policy

    return policy_fn


def _normalize_clip_observation(x, clip_range=[-5.0, 5.0]):
    rms = RunningMeanStd(shape=x.shape[1:])
    norm_x = tf.clip_by_value((x - rms.mean) / rms.std, min(clip_range), max(clip_range))
    return norm_x, rms
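

# A minimal end-to-end sketch (not part of the original module), guarded so it
# only runs when this file is executed directly. It assumes the TF1 graph/session
# API used throughout this file, the 'mlp' network registered in
# baselines.common.models, and an illustrative CartPole environment.
if __name__ == '__main__':
    env = gym.make('CartPole-v0')
    policy_fn = build_policy(env, 'mlp')
    with tf.Session() as sess:
        # Build a single-observation policy and initialize its variables.
        policy = policy_fn(nbatch=1, sess=sess)
        sess.run(tf.global_variables_initializer())
        ob = env.reset()
        action, value, _, neglogp = policy.step(ob)
        print(action, value, neglogp)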