baselines/baselines/common/distributions.py
pzhokhov b875fb7b5e release Internal changes (#800)
2019-02-27 15:35:31 -08:00

356 lines
13 KiB
Python

import tensorflow as tf
import numpy as np
import baselines.common.tf_util as U
from baselines.a2c.utils import fc
from tensorflow.python.ops import math_ops

class Pd(object):
    """
    A particular probability distribution
    """
    def flatparam(self):
        raise NotImplementedError
    def mode(self):
        raise NotImplementedError
    def neglogp(self, x):
        # Usually it's easier to define the negative logprob
        raise NotImplementedError
    def kl(self, other):
        raise NotImplementedError
    def entropy(self):
        raise NotImplementedError
    def sample(self):
        raise NotImplementedError
    def logp(self, x):
        return - self.neglogp(x)
    def get_shape(self):
        return self.flatparam().shape
    @property
    def shape(self):
        return self.get_shape()
    def __getitem__(self, idx):
        return self.__class__(self.flatparam()[idx])

class PdType(object):
    """
    Parametrized family of probability distributions
    """
    def pdclass(self):
        raise NotImplementedError
    def pdfromflat(self, flat):
        return self.pdclass()(flat)
    def pdfromlatent(self, latent_vector, init_scale, init_bias):
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
    def sample_shape(self):
        raise NotImplementedError
    def sample_dtype(self):
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
    def sample_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)

    def __eq__(self, other):
        return (type(self) == type(other)) and (self.__dict__ == other.__dict__)

class CategoricalPdType(PdType):
    def __init__(self, ncat):
        self.ncat = ncat
    def pdclass(self):
        return CategoricalPd
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [self.ncat]
    def sample_shape(self):
        return []
    def sample_dtype(self):
        return tf.int32

class MultiCategoricalPdType(PdType):
    def __init__(self, nvec):
        self.ncats = nvec.astype('int32')
        assert (self.ncats > 0).all()
    def pdclass(self):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        return MultiCategoricalPd(self.ncats, flat)

    def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
        pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [sum(self.ncats)]
    def sample_shape(self):
        return [len(self.ncats)]
    def sample_dtype(self):
        return tf.int32

class DiagGaussianPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return DiagGaussianPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return self.pdfromflat(pdparam), mean

    def param_shape(self):
        return [2*self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32

class BernoulliPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return BernoulliPd
    def param_shape(self):
        return [self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.int32
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
        pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

# WRONG SECOND DERIVATIVES
# class CategoricalPd(Pd):
#     def __init__(self, logits):
#         self.logits = logits
#         self.ps = tf.nn.softmax(logits)
#     @classmethod
#     def fromflat(cls, flat):
#         return cls(flat)
#     def flatparam(self):
#         return self.logits
#     def mode(self):
#         return U.argmax(self.logits, axis=-1)
#     def logp(self, x):
#         return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
#     def kl(self, other):
#         return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
#                 - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def entropy(self):
#         return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def sample(self):
#         u = tf.random_uniform(tf.shape(self.logits))
#         return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)

class CategoricalPd(Pd):
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.argmax(self.logits, axis=-1)

    @property
    def mean(self):
        return tf.nn.softmax(self.logits)
    def neglogp(self, x):
        # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
        # Note: we can't use sparse_softmax_cross_entropy_with_logits because
        #       the implementation does not allow second-order derivatives...
        if x.dtype in {tf.uint8, tf.int32, tf.int64}:
            # one-hot encoding
            x_shape_list = x.shape.as_list()
            logits_shape_list = self.logits.get_shape().as_list()[:-1]
            for xs, ls in zip(x_shape_list, logits_shape_list):
                if xs is not None and ls is not None:
                    assert xs == ls, 'shape mismatch: {} in x vs {} in logits'.format(xs, ls)

            x = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        else:
            # already encoded
            assert x.shape.as_list() == self.logits.shape.as_list()

        return tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=self.logits,
            labels=x)
    def kl(self, other):
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        z1 = tf.reduce_sum(ea1, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
    def entropy(self):
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
    def sample(self):
        u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
        return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
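
```python
# Illustrative sketch only (pure Python, not part of this module): the same
# max-shifted arithmetic that CategoricalPd.kl() and CategoricalPd.entropy()
# perform on tensors, written for plain lists of logits. The helper names
# below are hypothetical and exist only for this example.
import math

def _shifted_logits(logits):
    # Subtract the max logit so that exp() cannot overflow.
    m = max(logits)
    a = [l - m for l in logits]
    z = sum(math.exp(v) for v in a)
    return a, z

def categorical_entropy(logits):
    # H[p] = sum_i p_i * (log z - a_i), with p_i = exp(a_i) / z
    a, z = _shifted_logits(logits)
    return sum(math.exp(v) / z * (math.log(z) - v) for v in a)

def categorical_kl(logits_p, logits_q):
    # KL[p || q] = sum_i p_i * (log p_i - log q_i)
    a0, z0 = _shifted_logits(logits_p)
    a1, z1 = _shifted_logits(logits_q)
    return sum(
        math.exp(v0) / z0 * (v0 - math.log(z0) - v1 + math.log(z1))
        for v0, v1 in zip(a0, a1)
    )
```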

class MultiCategoricalPd(Pd):
    def __init__(self, nvec, flat):
        self.flat = flat
        self.categoricals = list(map(CategoricalPd,
            tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def neglogp(self, x):
        return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
    def kl(self, other):
        return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        raise NotImplementedError
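
```python
# Illustrative sketch only (pure Python, not part of this module): how
# MultiCategoricalPd treats one flat logit vector as len(nvec) independent
# categoricals and adds up their per-component entropies. All helper names
# here are hypothetical.
import math

def split_by_sizes(flat, nvec):
    """Cut `flat` into consecutive chunks whose lengths are given by `nvec`."""
    assert sum(nvec) == len(flat), 'flat must hold sum(nvec) logits'
    chunks, start = [], 0
    for n in nvec:
        chunks.append(flat[start:start + n])
        start += n
    return chunks

def softmax_entropy(logits):
    # Entropy of softmax(logits), using the max shift for stability.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return sum(e / z * (math.log(z) - (l - m)) for e, l in zip(exps, logits))

def multicategorical_entropy(flat, nvec):
    # The components are independent, so the joint entropy is the sum.
    return sum(softmax_entropy(chunk) for chunk in split_by_sizes(flat, nvec))
```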

class DiagGaussianPd(Pd):
    def __init__(self, flat):
        self.flat = flat
        mean, logstd = tf.split(axis=len(flat.shape)-1, num_or_size_splits=2, value=flat)
        self.mean = mean
        self.logstd = logstd
        self.std = tf.exp(logstd)
    def flatparam(self):
        return self.flat
    def mode(self):
        return self.mean
    def neglogp(self, x):
        return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \
               + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
               + tf.reduce_sum(self.logstd, axis=-1)
    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        return tf.reduce_sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
    def entropy(self):
        return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
    def sample(self):
        return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
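
```python
# Illustrative sketch only (pure Python, not part of this module): the
# closed-form expressions used by DiagGaussianPd.neglogp(), .entropy() and
# .kl(), evaluated on plain lists of per-dimension means and log standard
# deviations. Function names are hypothetical.
import math

def gaussian_neglogp(x, mean, logstd):
    # 0.5 * sum(((x - mean) / std)^2) + 0.5 * d * log(2*pi) + sum(logstd)
    quad = sum(((xi - mi) / math.exp(ls)) ** 2 for xi, mi, ls in zip(x, mean, logstd))
    return 0.5 * quad + 0.5 * math.log(2.0 * math.pi) * len(x) + sum(logstd)

def gaussian_entropy(logstd):
    # sum(logstd + 0.5 * log(2*pi*e))
    return sum(ls + 0.5 * math.log(2.0 * math.pi * math.e) for ls in logstd)

def gaussian_kl(mean0, logstd0, mean1, logstd1):
    # KL[N0 || N1], summed over independent dimensions.
    total = 0.0
    for m0, l0, m1, l1 in zip(mean0, logstd0, mean1, logstd1):
        var0, var1 = math.exp(2.0 * l0), math.exp(2.0 * l1)
        total += l1 - l0 + (var0 + (m0 - m1) ** 2) / (2.0 * var1) - 0.5
    return total
```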

class BernoulliPd(Pd):
    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)
    def flatparam(self):
        return self.logits
    @property
    def mean(self):
        return self.ps
    def mode(self):
        return tf.round(self.ps)
    def neglogp(self, x):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1)
    def kl(self, other):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def entropy(self):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def sample(self):
        u = tf.random_uniform(tf.shape(self.ps))
        return tf.to_float(math_ops.less(u, self.ps))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
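
```python
# Illustrative sketch only (pure Python, not part of this module): why
# BernoulliPd can express entropy and KL through sigmoid cross-entropy.
# With p = sigmoid(logit), the cross-entropy of p against itself is the
# entropy, and KL is "cross-entropy minus entropy". The helper names are
# hypothetical; sigmoid_xent uses the usual numerically stable form
# max(x, 0) - x*z + log(1 + exp(-|x|)).
import math

def sigmoid(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def sigmoid_xent(logit, label):
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def bernoulli_entropy(logits):
    return sum(sigmoid_xent(l, sigmoid(l)) for l in logits)

def bernoulli_kl(logits_p, logits_q):
    return sum(
        sigmoid_xent(q, sigmoid(p)) - sigmoid_xent(p, sigmoid(p))
        for p, q in zip(logits_p, logits_q)
    )
```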

def make_pdtype(ac_space):
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError

def shape_el(v, i):
    maybe = v.get_shape()[i]
    if maybe is not None:
        return maybe
    else:
        return tf.shape(v)[i]

@U.in_session
def test_probtypes():
    np.random.seed(0)

    pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
    diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
    validate_probtype(diag_gauss, pdparam_diag_gauss)

    pdparam_categorical = np.array([-.2, .3, .5])
    categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
    validate_probtype(categorical, pdparam_categorical)

    # MultiCategoricalPdType calls nvec.astype(), so nvec must be an ndarray, not a list
    nvec = np.array([1, 2, 3])
    pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
    multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
    validate_probtype(multicategorical, pdparam_multicategorical)

    pdparam_bernoulli = np.array([-.2, .3, .5])
    bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
    validate_probtype(bernoulli, pdparam_bernoulli)

def validate_probtype(probtype, pdparam):
    N = 100000
    # Check to see if mean negative log likelihood == differential entropy
    Mval = np.repeat(pdparam[None, :], N, axis=0)
    M = probtype.param_placeholder([N])
    X = probtype.sample_placeholder([N])
    pd = probtype.pdfromflat(M)
    calcloglik = U.function([X, M], pd.logp(X))
    calcent = U.function([M], pd.entropy())
    Xval = tf.get_default_session().run(pd.sample(), feed_dict={M:Mval})
    logliks = calcloglik(Xval, Mval)
    entval_ll = - logliks.mean() #pylint: disable=E1101
    entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    entval = calcent(Mval).mean() #pylint: disable=E1101
    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas

    # Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
    M2 = probtype.param_placeholder([N])
    pd2 = probtype.pdfromflat(M2)
    q = pdparam + np.random.randn(pdparam.size) * 0.1
    Mval2 = np.repeat(q[None, :], N, axis=0)
    calckl = U.function([M, M2], pd.kl(pd2))
    klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
    logliks = calcloglik(Xval, Mval2)
    klval_ll = - entval - logliks.mean() #pylint: disable=E1101
    klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
    print('ok on', probtype, pdparam)
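
```python
# Illustrative sketch only (pure Python, not part of this module): the first
# check in validate_probtype(), reduced to a single categorical distribution.
# Samples are drawn with the Gumbel-max trick, as in CategoricalPd.sample(),
# and the Monte Carlo mean of -log p(x) is compared with the closed-form
# entropy. Function names, the sample count and the seed are assumptions
# made for this example.
import math
import random

def gumbel_max_sample(logits, rng):
    # argmax_i (logit_i - log(-log(u_i))), with u_i ~ Uniform(0, 1)
    best_i, best_v = 0, -math.inf
    for i, l in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # random() can return exactly 0.0; log(0) is undefined
            u = rng.random()
        v = l - math.log(-math.log(u))
        if v > best_v:
            best_i, best_v = i, v
    return best_i

def log_softmax(logits):
    m = max(logits)
    log_z = math.log(sum(math.exp(l - m) for l in logits))
    return [l - m - log_z for l in logits]

def check_entropy(logits, n=20000, seed=0):
    rng = random.Random(seed)
    logp = log_softmax(logits)
    entropy = -sum(math.exp(lp) * lp for lp in logp)
    nll = [-logp[gumbel_max_sample(logits, rng)] for _ in range(n)]
    mean = sum(nll) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in nll) / n)
    # Returns the exact entropy, its sampled estimate, and the standard error.
    return entropy, mean, std / math.sqrt(n)
```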

def _matching_fc(tensor, name, size, init_scale, init_bias):
    if tensor.shape[-1] == size:
        return tensor
    else:
        return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)