baselines/baselines/common/distributions.py
pzhokhov fe06c6b4db continuous action spaces for codegen + some benchmarking (#82)
* add some docstrings

* start making big changes

* state machine redesign

* sampling seems to work

* some reorg

* fixed sampling of real vals

* json conversion

* made it possible to register new commands
got nontrivial version of Pred working

* consolidate command definitions

* add more macro blocks

* revived visualization

* rename Userdata -> CmdInterpreter
make AlgoSmInstance subclass of SmInstance that uses appropriate userdata argument

* replace userdata by ci when appropriate

* minor test fixes

* revamped handmade dir, can run ppo_metal

* seed to avoid random test failure

* implement AlgoAgent

* Autogenerated object that performs all ops and macros

* more CmdRecorder changes

* move files around

* move MatchProb and JtftProb

* remove obsolete

* fix tests involving AlgoAgent (pending the next commit on ppo_metal code)

* ppo_metal: reduce duplication in policy_gen, make sess an attribute of PpoAgent and StochasticPolicy instead of using get_default_session everywhere.

* maze_env reformatting, move algo_search script (but still broken)

* move agent.py

* fix test on handcrafted agents

* tuning/fixing ppo_metal baseline

* minor

* Fix ppo_metal baseline

* Don’t set epcount, tcount unless they’re being used

* get rid of old ppo_metal baseline

* fixes for handmade/run.py tuning

* fix codegen ppo

* fix handmade ppo hps

* fix test, go back to safe_div

* switch to more complex filtering

* make sure all handcrafted algos have finite probability

* train to maximize logprob of provided samples
Trex changes to avoid segfault

* AlgoSm also includes global hyperparams

* don’t duplicate global hyperparam defaults

* create generic_ob_ac_space function

* use sorted list of outkeys

* revive tsne

* todo changes

* determinism test

* todo + test fix

* remove a few deprecated files, rename other tests so they don’t run automatically, fix real test failure

* continuous control with codegen

* continuous control with codegen

* implement continuous action space algodistr

* ppo with trex RUN BENCHMARKS

* wrap trex in a monitor

* dummy commit to RUN BENCHMARKS

* adding monitor to trex env RUN BENCHMARKS

* adding monitor to trex RUN BENCHMARKS

* include monitor into trex env RUN BENCHMARKS

* generate nll and predmean using Distribution node

* dummy commit to RUN BENCHMARKS

* include pybullet into baselines optional dependencies

* dummy commit to RUN BENCHMARKS

* install games for cron rcall user RUN BENCHMARKS

* add --yes flag to install.py in rcall config for cron user RUN BENCHMARKS

* both continuous and discrete versions seem to run

* fixes to monitor to work with vecenv-like info and rewards RUN BENCHMARKS

* dummy commit to RUN BENCHMARKS

* removed shape check from one-hot encoding logic in distributions.CategoricalPd

* reset logger configuration in codegen/handmade/run.py to be in-line with baselines RUN BENCHMARKS

* merged peterz_codegen_benchmarks RUN BENCHMARKS

* skip tests RUN BENCHMARKS

* working on test failures

* save benchmark dicts RUN BENCHMARK

* merged peterz_codegen_benchmark RUN BENCHMARKS

* add get_git_commit_message to the baselines.common.console_util

* dummy commit to RUN BENCHMARKS

* merged fixes from peterz_codegen_benchmark RUN BENCHMARKS

* fixing failure in test_algo_nll WIP

* test_algo_nll passes with both ppo and softq

* re-enabled tests

* run trex on gpus for 100k total (horizon=100k / 16) RUN BENCHMARKS

* merged latest peterz_codegen_benchmarks RUN BENCHMARKS

* fixing codegen test failures (logging-related)

* fixed name collision in run-benchmarks-new.py RUN BENCHMARKS

* fixed name collision in run-benchmarks-new.py RUN BENCHMARKS

* fixed import in node_filters.py

* test_algo_search passes

* some cleanup

* dummy commit to RUN BENCHMARKS

* merge fast fail for subprocvecenv RUN BENCHMARKS

* use SubprocVecEnv in sonic_prob

* added deprecation note to shmem_vec_env

* allow indexing of distributions

* add timeout to pipeline.yaml

* typo in pipeline.yml

* run tests with --forked option

* resolved merge conflict in rl_algs.bench.benchmarks

* re-enable parallel tests

* fix remaining merge conflicts and syntax

* Update trex_prob.py

* fixes to ResultsWriter

* take baselines/run.py from peterz_codegen branch

* actually save stuff to file in VecMonitor RUN BENCHMARKS

* enable parallel tests

* merge stricter flake8

* merge peterz_codegen_benchmark, resolve conflicts

* autopep8

* remove traces of Monitor from trex env, check shapes before encoding in CategoricalPd

* asserts and warnings to make q -> distribution change more explicit

* fixed assert in CategoricalPd

* add header to vec_monitor output file RUN BENCHMARKS

* make VecMonitor write header to the output file

* remove deprecation message from shmem_vec_env RUN BENCHMARKS

* autopep8

* proper shape test in distributions.py

* ResultsWriter can take dict headers

* dummy commit to RUN BENCHMARKS

* replace assert len(qs)==1 with warning RUN BENCHMARKS

* removed pdb from ppo2 RUN BENCHMARKS
2018-09-14 15:43:49 -07:00

336 lines
12 KiB
Python

import tensorflow as tf
import numpy as np
import baselines.common.tf_util as U
from baselines.a2c.utils import fc
from tensorflow.python.ops import math_ops

class Pd(object):
    """
    A particular probability distribution
    """
    def flatparam(self):
        raise NotImplementedError
    def mode(self):
        raise NotImplementedError
    def neglogp(self, x):
        # Usually it's easier to define the negative logprob
        raise NotImplementedError
    def kl(self, other):
        raise NotImplementedError
    def entropy(self):
        raise NotImplementedError
    def sample(self):
        raise NotImplementedError
    def logp(self, x):
        return - self.neglogp(x)
    def get_shape(self):
        return self.flatparam().shape
    @property
    def shape(self):
        return self.get_shape()

class PdType(object):
    """
    Parametrized family of probability distributions
    """
    def pdclass(self):
        raise NotImplementedError
    def pdfromflat(self, flat):
        return self.pdclass()(flat)
    def pdfromlatent(self, latent_vector):
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
    def sample_shape(self):
        raise NotImplementedError
    def sample_dtype(self):
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
    def sample_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)

class CategoricalPdType(PdType):
    def __init__(self, ncat):
        self.ncat = ncat
    def pdclass(self):
        return CategoricalPd
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam
    def param_shape(self):
        return [self.ncat]
    def sample_shape(self):
        return []
    def sample_dtype(self):
        return tf.int32
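
# --- Illustrative usage sketch (not part of the original file; the `_demo_`
# name and the feature size 64 are hypothetical) ---
# The intended flow: a PdType builds distribution parameters from a latent
# feature vector via pdfromlatent, and the resulting Pd supplies the sampling
# and loss ops a policy needs.
def _demo_pdtype_flow():
    pdtype = CategoricalPdType(ncat=6)
    latent = tf.placeholder(tf.float32, [None, 64])
    pd, pdparam = pdtype.pdfromlatent(latent)   # fc layer 'pi' -> logits
    ac = pd.sample()                            # sampled action op
    neglogp = pd.neglogp(ac)                    # -log pi(ac | latent)
    return ac, neglogp, pd.entropy()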

class MultiCategoricalPdType(PdType):
    def __init__(self, nvec):
        self.ncats = nvec
    def pdclass(self):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        return MultiCategoricalPd(self.ncats, flat)
    def param_shape(self):
        return [sum(self.ncats)]
    def sample_shape(self):
        return [len(self.ncats)]
    def sample_dtype(self):
        return tf.int32

class DiagGaussianPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return DiagGaussianPd
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return self.pdfromflat(pdparam), mean
    def param_shape(self):
        return [2*self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32

class BernoulliPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return BernoulliPd
    def param_shape(self):
        return [self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.int32
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

# WRONG SECOND DERIVATIVES
# class CategoricalPd(Pd):
#     def __init__(self, logits):
#         self.logits = logits
#         self.ps = tf.nn.softmax(logits)
#     @classmethod
#     def fromflat(cls, flat):
#         return cls(flat)
#     def flatparam(self):
#         return self.logits
#     def mode(self):
#         return U.argmax(self.logits, axis=-1)
#     def logp(self, x):
#         return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
#     def kl(self, other):
#         return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
#             - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def entropy(self):
#         return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def sample(self):
#         u = tf.random_uniform(tf.shape(self.logits))
#         return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)

class CategoricalPd(Pd):
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.argmax(self.logits, axis=-1)
    def neglogp(self, x):
        # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
        # Note: we can't use sparse_softmax_cross_entropy_with_logits because
        # the implementation does not allow second-order derivatives...
        if x.dtype in {tf.uint8, tf.int32, tf.int64}:
            # one-hot encoding
            x_shape_list = x.shape.as_list()
            logits_shape_list = self.logits.get_shape().as_list()[:-1]
            for xs, ls in zip(x_shape_list, logits_shape_list):
                if xs is not None and ls is not None:
                    assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls)
            x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        else:
            # already encoded
            assert x.shape.as_list() == self.logits.shape.as_list()
        return tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.logits,
            labels=x)
    def kl(self, other):
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
    def entropy(self):
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
    def sample(self):
        u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
        return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
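
# --- Note on CategoricalPd.sample above (illustrative, not part of the
# original file) ---
# Sampling uses the Gumbel-max trick: for u_i ~ Uniform(0,1) i.i.d.,
# g_i = -log(-log(u_i)) is Gumbel(0,1) noise, and argmax_i(logits_i + g_i)
# is an exact draw from softmax(logits). A small numpy sanity check:
def _check_gumbel_max_trick(logits=(0.5, 1.0, -1.0), n=200000, seed=0):
    rng = np.random.RandomState(seed)
    logits = np.asarray(logits, dtype=np.float64)
    u = rng.uniform(size=(n, logits.size))
    samples = np.argmax(logits - np.log(-np.log(u)), axis=1)
    freqs = np.bincount(samples, minlength=logits.size) / float(n)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # empirical frequencies should match the softmax probabilities
    assert np.allclose(freqs, probs, atol=1e-2), (freqs, probs)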

class MultiCategoricalPd(Pd):
    def __init__(self, nvec, flat):
        self.flat = flat
        self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def neglogp(self, x):
        return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
    def kl(self, other):
        return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        raise NotImplementedError
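
# --- Note on MultiCategoricalPd above (illustrative, not part of the original
# file; the `_demo_` name and shapes are hypothetical) ---
# The flat logits are split per sub-action: with nvec = [2, 3], a length-5
# logit vector yields one CategoricalPd over 2 choices and one over 3;
# neglogp, kl, and entropy are sums over the independent components.
def _demo_multicategorical(nvec=(2, 3)):
    flat = tf.placeholder(tf.float32, [None, sum(nvec)])
    pd = MultiCategoricalPd(list(nvec), flat)
    return pd.sample()   # shape (None, len(nvec))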

class DiagGaussianPd(Pd):
    def __init__(self, flat):
        self.flat = flat
        mean, logstd = tf.split(axis=len(flat.shape)-1, num_or_size_splits=2, value=flat)
        self.mean = mean
        self.logstd = logstd
        self.std = tf.exp(logstd)
    def flatparam(self):
        return self.flat
    def mode(self):
        return self.mean
    def neglogp(self, x):
        return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \
               + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
               + tf.reduce_sum(self.logstd, axis=-1)
    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        return tf.reduce_sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
    def entropy(self):
        return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
    def sample(self):
        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
    def __getitem__(self, idx):
        return DiagGaussianPd(self.flat[idx])
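
# --- Note on DiagGaussianPd.neglogp above (illustrative, not part of the
# original file) ---
# For a diagonal Gaussian, -log p(x) = 0.5 * sum_i((x_i - mu_i)^2 / sigma_i^2)
# + 0.5 * d * log(2*pi) + sum_i(log sigma_i), which matches the three terms
# in neglogp. A numpy reference for cross-checking:
def _diag_gauss_neglogp_np(x, mean, logstd):
    std = np.exp(logstd)
    return (0.5 * np.sum(np.square((x - mean) / std), axis=-1)
            + 0.5 * np.log(2.0 * np.pi) * x.shape[-1]
            + np.sum(logstd, axis=-1))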

class BernoulliPd(Pd):
    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)
    def flatparam(self):
        return self.logits
    @property
    def mean(self):
        return self.ps
    def mode(self):
        return tf.round(self.ps)
    def neglogp(self, x):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1)
    def kl(self, other):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def entropy(self):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def sample(self):
        u = tf.random_uniform(tf.shape(self.ps))
        return tf.to_float(math_ops.less(u, self.ps))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
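
# --- Note on BernoulliPd.kl and BernoulliPd.entropy above (illustrative, not
# part of the original file) ---
# Both rely on the identity KL(p || q) = H(p, q) - H(p):
# sigmoid_cross_entropy_with_logits with labels self.ps computes the
# cross-entropy H(p, q) against the distribution whose logits are passed in,
# so kl = H(p, q) - H(p, p) and entropy = H(p, p), summed over the
# independent Bernoulli dimensions.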

def make_pdtype(ac_space):
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError
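
# --- Illustrative mapping (not part of the original file; the `_demo_` name
# is hypothetical) ---
# make_pdtype selects the distribution family from the gym action space:
#   Box(shape=(k,))     -> DiagGaussianPdType(k)    (continuous control)
#   Discrete(n)         -> CategoricalPdType(n)
#   MultiDiscrete(nvec) -> MultiCategoricalPdType(nvec)
#   MultiBinary(n)      -> BernoulliPdType(n)
def _demo_make_pdtype():
    from gym import spaces
    pdtype = make_pdtype(spaces.Discrete(4))
    assert isinstance(pdtype, CategoricalPdType)
    return pdtype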

def shape_el(v, i):
    maybe = v.get_shape()[i]
    if maybe is not None:
        return maybe
    else:
        return tf.shape(v)[i]

@U.in_session
def test_probtypes():
    np.random.seed(0)

    pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
    diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
    validate_probtype(diag_gauss, pdparam_diag_gauss)

    pdparam_categorical = np.array([-.2, .3, .5])
    categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
    validate_probtype(categorical, pdparam_categorical)

    nvec = [1, 2, 3]
    pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
    multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
    validate_probtype(multicategorical, pdparam_multicategorical)

    pdparam_bernoulli = np.array([-.2, .3, .5])
    bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
    validate_probtype(bernoulli, pdparam_bernoulli)

def validate_probtype(probtype, pdparam):
    N = 100000
    # Check to see if mean negative log likelihood == differential entropy
    Mval = np.repeat(pdparam[None, :], N, axis=0)
    M = probtype.param_placeholder([N])
    X = probtype.sample_placeholder([N])
    pd = probtype.pdfromflat(M)
    calcloglik = U.function([X, M], pd.logp(X))
    calcent = U.function([M], pd.entropy())
    Xval = tf.get_default_session().run(pd.sample(), feed_dict={M: Mval})
    logliks = calcloglik(Xval, Mval)
    entval_ll = - logliks.mean() #pylint: disable=E1101
    entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    entval = calcent(Mval).mean() #pylint: disable=E1101
    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas

    # Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
    M2 = probtype.param_placeholder([N])
    pd2 = probtype.pdfromflat(M2)
    q = pdparam + np.random.randn(pdparam.size) * 0.1
    Mval2 = np.repeat(q[None, :], N, axis=0)
    calckl = U.function([M, M2], pd.kl(pd2))
    klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
    logliks = calcloglik(Xval, Mval2)
    klval_ll = - entval - logliks.mean() #pylint: disable=E1101
    klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
    print('ok on', probtype, pdparam)
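
# Illustrative entry point (hypothetical; not in the original file): run the
# self-checks directly with `python -m baselines.common.distributions`.
if __name__ == '__main__':
    test_probtypes()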