* joshim5 changes (width and height to WarpFrame wrapper) * match network output with action distribution via a linear layer only if necessary (#167) * support color vs. grayscale option in WarpFrame wrapper (#166) * support color vs. grayscale option in WarpFrame wrapper * Support color in other wrappers * Updated per Peters suggestions * fixing test failures * ppo2 with microbatches (#168) * pass microbatch_size to the model during construction * microbatch fixes and test (#169) * microbatch fixes and test * tiny cleanup * added assertions to the test * vpg-related fix * Peterz joshim5 subclass ppo2 model (#170) * microbatch fixes and test * tiny cleanup * added assertions to the test * vpg-related fix * subclassing the model to make microbatched version of model WIP * made microbatched model a subclass of ppo2 Model * flake8 complaint * mpi-less ppo2 (resolving merge conflict) * flake8 and mpi4py imports in ppo2/model.py * more un-mpying * merge master * updates to the benchmark viewer code + autopep8 (#184) * viz docs and syntactic sugar wip * update viewer yaml to use persistent volume claims * move plot_util to baselines.common, update links * use 1Tb hard drive for results viewer * small updates to benchmark vizualizer code * autopep8 * autopep8 * any folder can be a benchmark * massage games image a little bit * fixed --preload option in app.py * remove preload from run_viewer.sh * remove pdb breakpoints * update bench-viewer.yaml * fixed bug (#185) * fixed bug it's wrong to do the else statement, because no other nodes would start. * changed the fix slightly * Refactor her phase 1 (#194) * add monitor to the rollout envs in her RUN BENCHMARKS her * Slice -> Slide in her benchmarks RUN BENCHMARKS her * run her benchmark for 200 epochs * dummy commit to RUN BENCHMARKS her * her benchmark for 500 epochs RUN BENCHMARKS her * add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her * add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her * add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her * disable saving of policies in her benchmark RUN BENCHMARKS her * run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch * run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch * launcher refactor wip * wip * her works on FetchReach * her runner refactor RUN BENCHMARKS Fetch1M * unit test for her * fixing warnings in mpi_average in her, skip test_fetchreach if mujoco is not present * pickle-based serialization in her * remove extra import from subproc_vec_env.py * investigating differences in rollout.py * try with old rollout code RUN BENCHMARKS her * temporarily use DummyVecEnv in cmd_util.py RUN BENCHMARKS her * dummy commit to RUN BENCHMARKS her * set info_values in rollout worker in her RUN BENCHMARKS her * bug in rollout_new.py RUN BENCHMARKS her * fixed bug in rollout_new.py RUN BENCHMARKS her * do not use last step because vecenv calls reset and returns obs after reset RUN BENCHMARKS her * updated buffer sizes RUN BENCHMARKS her * fixed loading/saving via joblib * dust off learning from demonstrations in HER, docs, refactor * add deprecation notice on her play and plot files * address comments by Matthias * 1.5 months of codegen changes (#196) * play with resnet * feed_dict version * coinrun prob and more stats * fixes to get_choices_specs & hp search * minor prob fixes * minor fixes * minor * alternative version of rl_algo stuff * pylint fixes * fix bugs, move node_filters to soup * changed how get_algo works * change how get_algo works, probably broke all tests * continue previous refactor * get eval_agent running again * fixing tests * fix tests * fix more tests * clean up cma stuff * fix experiment * minor changes to eval_agent to make ppo_metal use gpu * make dict space work * modify mac makefile to use conda * recurrent layers * play with bn and resnets * minor hp changes * minor * got rid of use_fb argument and jtft (joint-train-fine-tune) functionality built test phase directly into AlgoProb * make new rl algos generateable * pylint; start fixing tests * fixing tests * more test fixes * pylint * fix search * work on search * hack around infinite loop caused by scan * algo search fixes * misc changes for search expt * enable annealing, overriding options of Op * pylint fixes * identity op * achieve use_last_output through masking so it automatically works in other distributions * fix tests * minor * discrete * use_last_output to be just a preference, not a hard constraint * pred delay, pruning * require nontrivial inputs * aliases for get_sm * add probname to probs * fixes * small fixes * fix tests * fix tests * fix tests * minor * test scripts * dualgru network improvements * minor * work on mysterious bugs * rcall gpu-usage command for kube * use cache dir that’s not in code folder, so that it doesn’t get removed by rcall code rsync * add power mode to gpu usage * make sure train/test actually different * remove VR for now * minor fixes * simplify soln_db * minor * big refactor of mpi eda * improve mpieda for multitask * - get rid of timelimit hack - add __del__ to cleanup SubprocVecEnv * get multitask working better * fixes * working on atari, various * annotate ops with whether they’re parametrized * minor * gym version * rand atari prob * minor * SolnDb bugfix and name change * pyspy script * switch conv layers * fix roboschool/bullet3 * nenvs assertion * fix rand atari * get rid of blanket exception catching fix soln_db bug * fix rand_atari * dynamic routing as cmdline arg * slight modifications to test_mpi_map and pyspy-all * max_tries argument for run_until_successs * dedup option in train_mle * simplify soln_db * increase atari horizon for 1 experiment * start implementing reward increment * ent multiplier * create cc dsl other misc fixes * cc ops * q_func -> qs in rl_algos_cc.py * fix PredictDistr * rl_ops_cc fixes, MakeAction op * augment algo agent to support cc stuff * work on ddpg experiments * fix blocking temporarily change logger * allow layer scaling * pylint fixes * spawn_method * isolate ddpg hacks * improve pruning * use spawn for subproc * remove use of python -c in rcall * fix pylint warning * fix static * maybe fix local backend * switch to DummyVecEnv * making some fixes via pylint * pylint fixes * fixing tests * fix tests * fix tests * write scaffolding for SSL in Codegen * logger fix * fix error * add EMA op to sl_ops * save many changes * save * add upsampler * add sl ops, enhance state machine * get ssl search working — some gross hacking * fix session/graph issue * fix importing * work on mle * - scale embeddings in gru model - better exception handling in sl_prob - use emas for test/val - use non-contrib batch_norm layer * improve logging * option to average before dumping in logger * default arguments, etc * new ddpg and identity test * concat fix * minor * move realistic ssl stuff to third-party (underscore to dash) * fixes * remove realistic_ssl_evaluation * pylint fixes * use gym master * try again * pass around args without gin * fix tests * separate line to install gym * rename failing tests that should be ignored * add data aug * ssl improvements * use fixed time limit * try to fix baselines tests * add score_floor, max_walltime, fiddle with lr decay * realistic_ssl * autopep8 * various ssl - enable blocking grad for simplification - kl - multiple final prediction * fix pruning * misc ssl stuff * bring back linear schedule, don’t use allgather for collecting stats (i’ve been getting nondeterministic errors from the old code) * save/load weights in SSL, big stepsize * cleanup SslProb * fix * get rid of kl coef * fix simplification, lower lr * search over hps * minor fixes * minor * static analysis * move files and rename things for improved consistency. still broken, and just saving before making nontrivial changes * various * make tests pass * move coinrun_train to codegen since it depends on codegen * fixes * pylint fixes * improve tests fix some things * improve tests * lint * fix up db_info.py, tests * mostly restore master version of envs directory, except for makefile changes * fix tests * improve printing * minor fixes * fix fixmes * pruning test * fixes * lint * write new test that makes tf graphs of random algos; fix some bugs it caught * add —delete flag to rcall upload-code command * lint * get cifar10 lazily for testing purposes * disable codegen ci tests for now * clean up rl_ops * rename spec classes * td3 with identity test * identity tests without gin files * remove gin.configurable from AlgoAgent * comments about reduction in rl_ops_cc * address @pzhokhov comments * fix tests * more linting * better tests * clean up filtering a bit * fix concat * delayed logger configuration (#208) * delayed logger configuration * fix typo * setters and getters for Logger.DEFAULT as well * do away with fancy property stuff - unable to get it to work with class level methods * grammar and spaces * spaces * use get_current function instead of reading Logger.CURRENT * autopep8 * disable mpi in subprocesses (#213) * lazy_mpi load * cleanups * more lazy mpi * don't pretend that class is a module, just use it as a class * mass-replace mpi4py imports * flake8 * fix previous lazy_mpi imports * silly recursion * try os.environ hack * better prefix test, work with mpich * restored MPI imports * removed commented import in test_with_mpi * restored codegen from master * remove lazy mpi * restored changes from rl-algs * remove extra files * address Chris' comments * use spawn for shmem vec env as well (#2) (#219) * lazy_mpi load * cleanups * more lazy mpi * don't pretend that class is a module, just use it as a class * mass-replace mpi4py imports * flake8 * fix previous lazy_mpi imports * silly recursion * try os.environ hack * better prefix test, work with mpich * restored MPI imports * removed commented import in test_with_mpi * restored codegen from master * remove lazy mpi * restored changes from rl-algs * remove extra files * port mpi fix to shmem vec env * increase the mpi test default timeout * change humanoid hyperparameters, get rid of clip_Frac annealing, as it's apparently dangerous * remove clip_frac schedule from ppo2 * more timesteps in humanoid run * whitespace + RUN BENCHMARKS * baselines: export vecenvs from folder (#221) * baselines: export vecenvs from folder * put missing function back in * add missing imports * more imports * longer mpi timeout? * make default logger configuration the same as call to logger.configure() (#222) * Vecenv refactor (#223) * update karl util * restore pvi flag * change rcall auto cpu behavior, move gin.configurable, add os.makedirs * vecenv refactor * aux buf index fix * add num aux obs * reset level with enter * restore high difficulty flag * bugfix * restore train_coinrun.py * tweaks * renaming * renaming * better arguments handling * more options * options cleanup * game data refactor * more options * args for train_procgen * add close handler to interactive base class * use debug build if debug=True, fix range on aux_obs * add ProcGenEnv to __init__.py, add missing imports to procgen.py * export RemoveDictWrapper and build, update train_procgen.py, move assets download into env creation and replace init_assets_and_build with just build * fix formatting issues * only call global init once * fix path in setup.py * revert part of makefile * ignore IDE files and folders * vec remove dict * export VecRemoveDictObs * remove RemoveDictWrapper * remove IDE files * move shared .h and .cpp files to common folder, update build to use those, dedupe env.cpp * fix missing header * try unified build function * remove old scripts dir * add comment on build * upload libenv with render fixes * tell qthreads to die when we unload the library * pyglet.app.run is garbage * static fixes * whoops * actually vsync is on * cleanup * cleanup * extern C for libenv interface * parse util rcall arg * high difficulty fix * game type enums * ProcGenEnv subclasses * game type cleanup * unrecognized key * unrecognized game type * parse util reorg * args management * typo fix * GinParser * arg tweaks * tweak * restore start_level/num_levels setting * fix create_procgen_env interface * build fix * procgen args in init signature * fix * build fix * fix logger usage in ppo_metal/run_retro * removed unnecessary OrderedDict requirement in subproc_vec_env * flake8 fix * allow for non-mpi tests * mpi test fixes * flake8; removed special logic for discrete spaces in dummy_vec_env * remove forked argument in front of tests - does not play nicely with subprocvecenv in spawned processes; analog of forked in ddpg/test_smoke * Everyrl initial commit & a few minor baselines changes (#226) * everyrl initial commit * add keep_buf argument to VecMonitor * logger changes: set_comm and fix to mpi_mean functionality * if filename not provided, don't create ResultsWriter * change variable syncing function to simplify its usage. now you should initialize from all mpi processes * everyrl coinrun changes * tf_distr changes, bugfix * get_one * bring back get_next to temporarily restore code * lint fixes * fix test * rename profile function * rename gaussian * fix coinrun training script * change random seeding to work with new gym version (#231) * change random seeding to work with new gym version * move seeding to seed() method * fix mnistenv * actually try some of the tests before pushing * more deterministic fixed seq * misc changes to vecenvs and run.py for benchmarks (#236) * misc changes to vecenvs and run.py for benchmarks * dont seed global gen * update more references to assert_venvs_equal * Rl19 (#232) * everyrl initial commit * add keep_buf argument to VecMonitor * logger changes: set_comm and fix to mpi_mean functionality * if filename not provided, don't create ResultsWriter * change variable syncing function to simplify its usage. now you should initialize from all mpi processes * everyrl coinrun changes * tf_distr changes, bugfix * get_one * bring back get_next to temporarily restore code * lint fixes * fix test * rename profile function * rename gaussian * fix coinrun training script * rl19 * remove everyrl dir which appeared in the merge for some reason * readme * fiddle with ddpg * make ddpg work * steps_total argument * gpu count * clean up hyperparams and shape math * logging + saving * configuration stuff * fixes, smoke tests * fix stats * make load_results return dicts -- easier to create the same kind of objects with some other mechanism for passing to downstream functions * benchmarks * fix tests * add dqn to tests, fix it * minor * turned annotated transformer (pytorch) into a script * more refactoring * jax stuff * cluster * minor * copy & paste alec code * sign error * add huber, rename some parameters, snapshotting off by default * remove jax stuff * minor * move maze env * minor * remove trailing spaces * remove trailing space * lint * fix test breakage due to gym update * rename function * move maze back to codegen * get recurrent ppo working * enable both lstm and gru * script to print table of benchmark results * various * fix dqn * add fixup initializer, remove lastrew * organize logging stats * fix silly bug * refactor models * fix mpi usage * check sync * minor * change vf coef, hps * clean up slicing in ppo * minor fixes * caching transformer * docstrings * xf fixes * get rid of 'B' and 'BT' arguments * minor * transformer example * remove output_kind from base class until we have a better idea how to use it * add comments, revert maze stuff * flake8 * codegen lint * fix codegen tests * responded to peter's comments * lint fixes * minor changes to baselines (#243) * minor changes to baselines * fix spaces reference * remove flake8 disable comments and fix import * okay maybe don't add spec to vec_env * Merge branch 'master' of github.com:openai/games the commit. * flake8 complaints in baselines/her
439 lines
16 KiB
Python
439 lines
16 KiB
Python
import numpy as np
|
|
import tensorflow as tf # pylint: ignore-module
|
|
import copy
|
|
import os
|
|
import functools
|
|
import collections
|
|
import multiprocessing
|
|
|
|
def switch(condition, then_expression, else_expression):
|
|
"""Switches between two operations depending on a scalar value (int or bool).
|
|
Note that both `then_expression` and `else_expression`
|
|
should be symbolic tensors of the *same shape*.
|
|
|
|
# Arguments
|
|
condition: scalar tensor.
|
|
then_expression: TensorFlow operation.
|
|
else_expression: TensorFlow operation.
|
|
"""
|
|
x_shape = copy.copy(then_expression.get_shape())
|
|
x = tf.cond(tf.cast(condition, 'bool'),
|
|
lambda: then_expression,
|
|
lambda: else_expression)
|
|
x.set_shape(x_shape)
|
|
return x
|
|
|
|
# ================================================================
|
|
# Extras
|
|
# ================================================================
|
|
|
|
def lrelu(x, leak=0.2):
|
|
f1 = 0.5 * (1 + leak)
|
|
f2 = 0.5 * (1 - leak)
|
|
return f1 * x + f2 * abs(x)
|
|
|
|
# ================================================================
|
|
# Mathematical utils
|
|
# ================================================================
|
|
|
|
def huber_loss(x, delta=1.0):
|
|
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
|
|
return tf.where(
|
|
tf.abs(x) < delta,
|
|
tf.square(x) * 0.5,
|
|
delta * (tf.abs(x) - 0.5 * delta)
|
|
)
|
|
|
|
# ================================================================
|
|
# Global session
|
|
# ================================================================
|
|
|
|
def get_session(config=None):
|
|
"""Get default session or create one with a given config"""
|
|
sess = tf.get_default_session()
|
|
if sess is None:
|
|
sess = make_session(config=config, make_default=True)
|
|
return sess
|
|
|
|
def make_session(config=None, num_cpu=None, make_default=False, graph=None):
|
|
"""Returns a session that will use <num_cpu> CPU's only"""
|
|
if num_cpu is None:
|
|
num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))
|
|
if config is None:
|
|
config = tf.ConfigProto(
|
|
allow_soft_placement=True,
|
|
inter_op_parallelism_threads=num_cpu,
|
|
intra_op_parallelism_threads=num_cpu)
|
|
config.gpu_options.allow_growth = True
|
|
|
|
if make_default:
|
|
return tf.InteractiveSession(config=config, graph=graph)
|
|
else:
|
|
return tf.Session(config=config, graph=graph)
|
|
|
|
def single_threaded_session():
|
|
"""Returns a session which will only use a single CPU"""
|
|
return make_session(num_cpu=1)
|
|
|
|
def in_session(f):
|
|
@functools.wraps(f)
|
|
def newfunc(*args, **kwargs):
|
|
with tf.Session():
|
|
f(*args, **kwargs)
|
|
return newfunc
|
|
|
|
ALREADY_INITIALIZED = set()
|
|
|
|
def initialize():
|
|
"""Initialize all the uninitialized variables in the global scope."""
|
|
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
|
|
get_session().run(tf.variables_initializer(new_variables))
|
|
ALREADY_INITIALIZED.update(new_variables)
|
|
|
|
# ================================================================
|
|
# Model components
|
|
# ================================================================
|
|
|
|
def normc_initializer(std=1.0, axis=0):
|
|
def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613
|
|
out = np.random.randn(*shape).astype(dtype.as_numpy_dtype)
|
|
out *= std / np.sqrt(np.square(out).sum(axis=axis, keepdims=True))
|
|
return tf.constant(out)
|
|
return _initializer
|
|
|
|
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None,
|
|
summary_tag=None):
|
|
with tf.variable_scope(name):
|
|
stride_shape = [1, stride[0], stride[1], 1]
|
|
filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]
|
|
|
|
# there are "num input feature maps * filter height * filter width"
|
|
# inputs to each hidden unit
|
|
fan_in = intprod(filter_shape[:3])
|
|
# each unit in the lower layer receives a gradient from:
|
|
# "num output feature maps * filter height * filter width" /
|
|
# pooling size
|
|
fan_out = intprod(filter_shape[:2]) * num_filters
|
|
# initialize weights with random weights
|
|
w_bound = np.sqrt(6. / (fan_in + fan_out))
|
|
|
|
w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
|
|
collections=collections)
|
|
b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(),
|
|
collections=collections)
|
|
|
|
if summary_tag is not None:
|
|
tf.summary.image(summary_tag,
|
|
tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]),
|
|
[2, 0, 1, 3]),
|
|
max_images=10)
|
|
|
|
return tf.nn.conv2d(x, w, stride_shape, pad) + b
|
|
|
|
# ================================================================
|
|
# Theano-like Function
|
|
# ================================================================
|
|
|
|
def function(inputs, outputs, updates=None, givens=None):
|
|
"""Just like Theano function. Take a bunch of tensorflow placeholders and expressions
|
|
computed based on those placeholders and produces f(inputs) -> outputs. Function f takes
|
|
values to be fed to the input's placeholders and produces the values of the expressions
|
|
in outputs.
|
|
|
|
Input values can be passed in the same order as inputs or can be provided as kwargs based
|
|
on placeholder name (passed to constructor or accessible via placeholder.op.name).
|
|
|
|
Example:
|
|
x = tf.placeholder(tf.int32, (), name="x")
|
|
y = tf.placeholder(tf.int32, (), name="y")
|
|
z = 3 * x + 2 * y
|
|
lin = function([x, y], z, givens={y: 0})
|
|
|
|
with single_threaded_session():
|
|
initialize()
|
|
|
|
assert lin(2) == 6
|
|
assert lin(x=3) == 9
|
|
assert lin(2, 2) == 10
|
|
assert lin(x=2, y=3) == 12
|
|
|
|
Parameters
|
|
----------
|
|
inputs: [tf.placeholder, tf.constant, or object with make_feed_dict method]
|
|
list of input arguments
|
|
outputs: [tf.Variable] or tf.Variable
|
|
list of outputs or a single output to be returned from function. Returned
|
|
value will also have the same shape.
|
|
updates: [tf.Operation] or tf.Operation
|
|
list of update functions or single update function that will be run whenever
|
|
the function is called. The return is ignored.
|
|
|
|
"""
|
|
if isinstance(outputs, list):
|
|
return _Function(inputs, outputs, updates, givens=givens)
|
|
elif isinstance(outputs, (dict, collections.OrderedDict)):
|
|
f = _Function(inputs, outputs.values(), updates, givens=givens)
|
|
return lambda *args, **kwargs: type(outputs)(zip(outputs.keys(), f(*args, **kwargs)))
|
|
else:
|
|
f = _Function(inputs, [outputs], updates, givens=givens)
|
|
return lambda *args, **kwargs: f(*args, **kwargs)[0]
|
|
|
|
|
|
class _Function(object):
|
|
def __init__(self, inputs, outputs, updates, givens):
|
|
for inpt in inputs:
|
|
if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
|
|
assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
|
|
self.inputs = inputs
|
|
self.input_names = {inp.name.split("/")[-1].split(":")[0]: inp for inp in inputs}
|
|
updates = updates or []
|
|
self.update_group = tf.group(*updates)
|
|
self.outputs_update = list(outputs) + [self.update_group]
|
|
self.givens = {} if givens is None else givens
|
|
|
|
def _feed_input(self, feed_dict, inpt, value):
|
|
if hasattr(inpt, 'make_feed_dict'):
|
|
feed_dict.update(inpt.make_feed_dict(value))
|
|
else:
|
|
feed_dict[inpt] = adjust_shape(inpt, value)
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
assert len(args) + len(kwargs) <= len(self.inputs), "Too many arguments provided"
|
|
feed_dict = {}
|
|
# Update feed dict with givens.
|
|
for inpt in self.givens:
|
|
feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
|
|
# Update the args
|
|
for inpt, value in zip(self.inputs, args):
|
|
self._feed_input(feed_dict, inpt, value)
|
|
for inpt_name, value in kwargs.items():
|
|
self._feed_input(feed_dict, self.input_names[inpt_name], value)
|
|
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
|
|
return results
|
|
|
|
# ================================================================
|
|
# Flat vectors
|
|
# ================================================================
|
|
|
|
def var_shape(x):
|
|
out = x.get_shape().as_list()
|
|
assert all(isinstance(a, int) for a in out), \
|
|
"shape function assumes that shape is fully known"
|
|
return out
|
|
|
|
def numel(x):
|
|
return intprod(var_shape(x))
|
|
|
|
def intprod(x):
|
|
return int(np.prod(x))
|
|
|
|
def flatgrad(loss, var_list, clip_norm=None):
|
|
grads = tf.gradients(loss, var_list)
|
|
if clip_norm is not None:
|
|
grads = [tf.clip_by_norm(grad, clip_norm=clip_norm) for grad in grads]
|
|
return tf.concat(axis=0, values=[
|
|
tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)])
|
|
for (v, grad) in zip(var_list, grads)
|
|
])
|
|
|
|
class SetFromFlat(object):
|
|
def __init__(self, var_list, dtype=tf.float32):
|
|
assigns = []
|
|
shapes = list(map(var_shape, var_list))
|
|
total_size = np.sum([intprod(shape) for shape in shapes])
|
|
|
|
self.theta = theta = tf.placeholder(dtype, [total_size])
|
|
start = 0
|
|
assigns = []
|
|
for (shape, v) in zip(shapes, var_list):
|
|
size = intprod(shape)
|
|
assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], shape)))
|
|
start += size
|
|
self.op = tf.group(*assigns)
|
|
|
|
def __call__(self, theta):
|
|
tf.get_default_session().run(self.op, feed_dict={self.theta: theta})
|
|
|
|
class GetFlat(object):
|
|
def __init__(self, var_list):
|
|
self.op = tf.concat(axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list])
|
|
|
|
def __call__(self):
|
|
return tf.get_default_session().run(self.op)
|
|
|
|
def flattenallbut0(x):
|
|
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
|
|
|
|
# =============================================================
|
|
# TF placeholders management
|
|
# ============================================================
|
|
|
|
_PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape)
|
|
|
|
def get_placeholder(name, dtype, shape):
|
|
if name in _PLACEHOLDER_CACHE:
|
|
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
|
|
if out.graph == tf.get_default_graph():
|
|
assert dtype1 == dtype and shape1 == shape, \
|
|
'Placeholder with name {} has already been registered and has shape {}, different from requested {}'.format(name, shape1, shape)
|
|
return out
|
|
|
|
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
|
|
_PLACEHOLDER_CACHE[name] = (out, dtype, shape)
|
|
return out
|
|
|
|
def get_placeholder_cached(name):
|
|
return _PLACEHOLDER_CACHE[name][0]
|
|
|
|
|
|
|
|
# ================================================================
|
|
# Diagnostics
|
|
# ================================================================
|
|
|
|
def display_var_info(vars):
|
|
from baselines import logger
|
|
count_params = 0
|
|
for v in vars:
|
|
name = v.name
|
|
if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
|
|
v_params = np.prod(v.shape.as_list())
|
|
count_params += v_params
|
|
if "/b:" in name or "/bias" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print
|
|
logger.info(" %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape)))
|
|
|
|
logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))
|
|
|
|
|
|
def get_available_gpus():
|
|
# recipe from here:
|
|
# https://stackoverflow.com/questions/38559755/how-to-get-current-available-gpus-in-tensorflow?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
|
|
|
|
from tensorflow.python.client import device_lib
|
|
local_device_protos = device_lib.list_local_devices()
|
|
return [x.name for x in local_device_protos if x.device_type == 'GPU']
|
|
|
|
# ================================================================
|
|
# Saving variables
|
|
# ================================================================
|
|
|
|
def load_state(fname, sess=None):
|
|
from baselines import logger
|
|
logger.warn('load_state method is deprecated, please use load_variables instead')
|
|
sess = sess or get_session()
|
|
saver = tf.train.Saver()
|
|
saver.restore(tf.get_default_session(), fname)
|
|
|
|
def save_state(fname, sess=None):
|
|
from baselines import logger
|
|
logger.warn('save_state method is deprecated, please use save_variables instead')
|
|
sess = sess or get_session()
|
|
dirname = os.path.dirname(fname)
|
|
if any(dirname):
|
|
os.makedirs(dirname, exist_ok=True)
|
|
saver = tf.train.Saver()
|
|
saver.save(tf.get_default_session(), fname)
|
|
|
|
# The methods above and below are clearly doing the same thing, and in a rather similar way
|
|
# TODO: ensure there is no subtle differences and remove one
|
|
|
|
def save_variables(save_path, variables=None, sess=None):
|
|
import joblib
|
|
sess = sess or get_session()
|
|
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
|
|
|
ps = sess.run(variables)
|
|
save_dict = {v.name: value for v, value in zip(variables, ps)}
|
|
dirname = os.path.dirname(save_path)
|
|
if any(dirname):
|
|
os.makedirs(dirname, exist_ok=True)
|
|
joblib.dump(save_dict, save_path)
|
|
|
|
def load_variables(load_path, variables=None, sess=None):
|
|
import joblib
|
|
sess = sess or get_session()
|
|
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
|
|
|
loaded_params = joblib.load(os.path.expanduser(load_path))
|
|
restores = []
|
|
if isinstance(loaded_params, list):
|
|
assert len(loaded_params) == len(variables), 'number of variables loaded mismatches len(variables)'
|
|
for d, v in zip(loaded_params, variables):
|
|
restores.append(v.assign(d))
|
|
else:
|
|
for v in variables:
|
|
restores.append(v.assign(loaded_params[v.name]))
|
|
|
|
sess.run(restores)
|
|
|
|
# ================================================================
|
|
# Shape adjustment for feeding into tf placeholders
|
|
# ================================================================
|
|
def adjust_shape(placeholder, data):
|
|
'''
|
|
adjust shape of the data to the shape of the placeholder if possible.
|
|
If shape is incompatible, AssertionError is thrown
|
|
|
|
Parameters:
|
|
placeholder tensorflow input placeholder
|
|
|
|
data input data to be (potentially) reshaped to be fed into placeholder
|
|
|
|
Returns:
|
|
reshaped data
|
|
'''
|
|
|
|
if not isinstance(data, np.ndarray) and not isinstance(data, list):
|
|
return data
|
|
if isinstance(data, list):
|
|
data = np.array(data)
|
|
|
|
placeholder_shape = [x or -1 for x in placeholder.shape.as_list()]
|
|
|
|
assert _check_shape(placeholder_shape, data.shape), \
|
|
'Shape of data {} is not compatible with shape of the placeholder {}'.format(data.shape, placeholder_shape)
|
|
|
|
return np.reshape(data, placeholder_shape)
|
|
|
|
|
|
def _check_shape(placeholder_shape, data_shape):
|
|
''' check if two shapes are compatible (i.e. differ only by dimensions of size 1, or by the batch dimension)'''
|
|
|
|
return True
|
|
squeezed_placeholder_shape = _squeeze_shape(placeholder_shape)
|
|
squeezed_data_shape = _squeeze_shape(data_shape)
|
|
|
|
for i, s_data in enumerate(squeezed_data_shape):
|
|
s_placeholder = squeezed_placeholder_shape[i]
|
|
if s_placeholder != -1 and s_data != s_placeholder:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def _squeeze_shape(shape):
|
|
return [x for x in shape if x != 1]
|
|
|
|
# ================================================================
|
|
# Tensorboard interfacing
|
|
# ================================================================
|
|
|
|
def launch_tensorboard_in_background(log_dir):
|
|
'''
|
|
To log the Tensorflow graph when using rl-algs
|
|
algorithms, you can run the following code
|
|
in your main script:
|
|
import threading, time
|
|
def start_tensorboard(session):
|
|
time.sleep(10) # Wait until graph is setup
|
|
tb_path = osp.join(logger.get_dir(), 'tb')
|
|
summary_writer = tf.summary.FileWriter(tb_path, graph=session.graph)
|
|
summary_op = tf.summary.merge_all()
|
|
launch_tensorboard_in_background(tb_path)
|
|
session = tf.get_default_session()
|
|
t = threading.Thread(target=start_tensorboard, args=([session]))
|
|
t.start()
|
|
'''
|
|
import subprocess
|
|
subprocess.Popen(['tensorboard', '--logdir', log_dir])
|