Files
baselines/baselines/her/util.py
2018-05-16 12:00:23 -07:00

141 lines
3.9 KiB
Python

import os
import subprocess
import sys
import importlib
import inspect
import functools
import tensorflow as tf
import numpy as np
from baselines.common import tf_util as U
def store_args(method):
"""Stores provided method args as instance attributes.
"""
argspec = inspect.getfullargspec(method)
defaults = {}
if argspec.defaults is not None:
defaults = dict(
zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
if argspec.kwonlydefaults is not None:
defaults.update(argspec.kwonlydefaults)
arg_names = argspec.args[1:]
@functools.wraps(method)
def wrapper(*positional_args, **keyword_args):
self = positional_args[0]
# Get default arg values
args = defaults.copy()
# Add provided arg values
for name, value in zip(arg_names, positional_args[1:]):
args[name] = value
args.update(keyword_args)
self.__dict__.update(args)
return method(*positional_args, **keyword_args)
return wrapper
def import_function(spec):
"""Import a function identified by a string like "pkg.module:fn_name".
"""
mod_name, fn_name = spec.split(':')
module = importlib.import_module(mod_name)
fn = getattr(module, fn_name)
return fn
def flatten_grads(var_list, grads):
"""Flattens a variables and their gradients.
"""
return tf.concat([tf.reshape(grad, [U.numel(v)])
for (v, grad) in zip(var_list, grads)], 0)
def nn(input, layers_sizes, reuse=None, flatten=False, name=""):
"""Creates a simple neural network
"""
for i, size in enumerate(layers_sizes):
activation = tf.nn.relu if i < len(layers_sizes) - 1 else None
input = tf.layers.dense(inputs=input,
units=size,
kernel_initializer=tf.contrib.layers.xavier_initializer(),
reuse=reuse,
name=name + '_' + str(i))
if activation:
input = activation(input)
if flatten:
assert layers_sizes[-1] == 1
input = tf.reshape(input, [-1])
return input
def install_mpi_excepthook():
import sys
from mpi4py import MPI
old_hook = sys.excepthook
def new_hook(a, b, c):
old_hook(a, b, c)
sys.stdout.flush()
sys.stderr.flush()
MPI.COMM_WORLD.Abort()
sys.excepthook = new_hook
def mpi_fork(n, extra_mpi_args=[]):
"""Re-launches the current script with workers
Returns "parent" for original parent, "child" for MPI children
"""
if n <= 1:
return "child"
if os.getenv("IN_MPI") is None:
env = os.environ.copy()
env.update(
MKL_NUM_THREADS="1",
OMP_NUM_THREADS="1",
IN_MPI="1"
)
# "-bind-to core" is crucial for good performance
args = ["mpirun", "-np", str(n)] + \
extra_mpi_args + \
[sys.executable]
args += sys.argv
subprocess.check_call(args, env=env)
return "parent"
else:
install_mpi_excepthook()
return "child"
def convert_episode_to_batch_major(episode):
"""Converts an episode to have the batch dimension in the major (first)
dimension.
"""
episode_batch = {}
for key in episode.keys():
val = np.array(episode[key]).copy()
# make inputs batch-major instead of time-major
episode_batch[key] = val.swapaxes(0, 1)
return episode_batch
def transitions_in_episode_batch(episode_batch):
"""Number of transitions in a given episode batch.
"""
shape = episode_batch['u'].shape
return shape[0] * shape[1]
def reshape_for_broadcasting(source, target):
"""Reshapes a tensor (source) to have the correct shape and dtype of the target
before broadcasting it with MPI.
"""
dim = len(target.get_shape())
shape = ([1] * (dim - 1)) + [-1]
return tf.reshape(tf.cast(source, target.dtype), shape)