release Internal changes (#800)
* joshim5 changes (width and height to WarpFrame wrapper) * match network output with action distribution via a linear layer only if necessary (#167) * support color vs. grayscale option in WarpFrame wrapper (#166) * support color vs. grayscale option in WarpFrame wrapper * Support color in other wrappers * Updated per Peters suggestions * fixing test failures * ppo2 with microbatches (#168) * pass microbatch_size to the model during construction * microbatch fixes and test (#169) * microbatch fixes and test * tiny cleanup * added assertions to the test * vpg-related fix * Peterz joshim5 subclass ppo2 model (#170) * microbatch fixes and test * tiny cleanup * added assertions to the test * vpg-related fix * subclassing the model to make microbatched version of model WIP * made microbatched model a subclass of ppo2 Model * flake8 complaint * mpi-less ppo2 (resolving merge conflict) * flake8 and mpi4py imports in ppo2/model.py * more un-mpying * merge master * updates to the benchmark viewer code + autopep8 (#184) * viz docs and syntactic sugar wip * update viewer yaml to use persistent volume claims * move plot_util to baselines.common, update links * use 1Tb hard drive for results viewer * small updates to benchmark vizualizer code * autopep8 * autopep8 * any folder can be a benchmark * massage games image a little bit * fixed --preload option in app.py * remove preload from run_viewer.sh * remove pdb breakpoints * update bench-viewer.yaml * fixed bug (#185) * fixed bug it's wrong to do the else statement, because no other nodes would start. * changed the fix slightly * Refactor her phase 1 (#194) * add monitor to the rollout envs in her RUN BENCHMARKS her * Slice -> Slide in her benchmarks RUN BENCHMARKS her * run her benchmark for 200 epochs * dummy commit to RUN BENCHMARKS her * her benchmark for 500 epochs RUN BENCHMARKS her * add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her * add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her * add num_timesteps to her benchmark to be compatible with viewer RUN BENCHMARKS her * disable saving of policies in her benchmark RUN BENCHMARKS her * run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch * run fetch benchmarks with ppo2 and ddpg RUN BENCHMARKS Fetch * launcher refactor wip * wip * her works on FetchReach * her runner refactor RUN BENCHMARKS Fetch1M * unit test for her * fixing warnings in mpi_average in her, skip test_fetchreach if mujoco is not present * pickle-based serialization in her * remove extra import from subproc_vec_env.py * investigating differences in rollout.py * try with old rollout code RUN BENCHMARKS her * temporarily use DummyVecEnv in cmd_util.py RUN BENCHMARKS her * dummy commit to RUN BENCHMARKS her * set info_values in rollout worker in her RUN BENCHMARKS her * bug in rollout_new.py RUN BENCHMARKS her * fixed bug in rollout_new.py RUN BENCHMARKS her * do not use last step because vecenv calls reset and returns obs after reset RUN BENCHMARKS her * updated buffer sizes RUN BENCHMARKS her * fixed loading/saving via joblib * dust off learning from demonstrations in HER, docs, refactor * add deprecation notice on her play and plot files * address comments by Matthias * 1.5 months of codegen changes (#196) * play with resnet * feed_dict version * coinrun prob and more stats * fixes to get_choices_specs & hp search * minor prob fixes * minor fixes * minor * alternative version of rl_algo stuff * pylint fixes * fix bugs, move node_filters to soup * changed how 
get_algo works * change how get_algo works, probably broke all tests * continue previous refactor * get eval_agent running again * fixing tests * fix tests * fix more tests * clean up cma stuff * fix experiment * minor changes to eval_agent to make ppo_metal use gpu * make dict space work * modify mac makefile to use conda * recurrent layers * play with bn and resnets * minor hp changes * minor * got rid of use_fb argument and jtft (joint-train-fine-tune) functionality built test phase directly into AlgoProb * make new rl algos generateable * pylint; start fixing tests * fixing tests * more test fixes * pylint * fix search * work on search * hack around infinite loop caused by scan * algo search fixes * misc changes for search expt * enable annealing, overriding options of Op * pylint fixes * identity op * achieve use_last_output through masking so it automatically works in other distributions * fix tests * minor * discrete * use_last_output to be just a preference, not a hard constraint * pred delay, pruning * require nontrivial inputs * aliases for get_sm * add probname to probs * fixes * small fixes * fix tests * fix tests * fix tests * minor * test scripts * dualgru network improvements * minor * work on mysterious bugs * rcall gpu-usage command for kube * use cache dir that’s not in code folder, so that it doesn’t get removed by rcall code rsync * add power mode to gpu usage * make sure train/test actually different * remove VR for now * minor fixes * simplify soln_db * minor * big refactor of mpi eda * improve mpieda for multitask * - get rid of timelimit hack - add __del__ to cleanup SubprocVecEnv * get multitask working better * fixes * working on atari, various * annotate ops with whether they’re parametrized * minor * gym version * rand atari prob * minor * SolnDb bugfix and name change * pyspy script * switch conv layers * fix roboschool/bullet3 * nenvs assertion * fix rand atari * get rid of blanket exception catching fix soln_db bug * fix rand_atari * dynamic routing as cmdline arg * slight modifications to test_mpi_map and pyspy-all * max_tries argument for run_until_successs * dedup option in train_mle * simplify soln_db * increase atari horizon for 1 experiment * start implementing reward increment * ent multiplier * create cc dsl other misc fixes * cc ops * q_func -> qs in rl_algos_cc.py * fix PredictDistr * rl_ops_cc fixes, MakeAction op * augment algo agent to support cc stuff * work on ddpg experiments * fix blocking temporarily change logger * allow layer scaling * pylint fixes * spawn_method * isolate ddpg hacks * improve pruning * use spawn for subproc * remove use of python -c in rcall * fix pylint warning * fix static * maybe fix local backend * switch to DummyVecEnv * making some fixes via pylint * pylint fixes * fixing tests * fix tests * fix tests * write scaffolding for SSL in Codegen * logger fix * fix error * add EMA op to sl_ops * save many changes * save * add upsampler * add sl ops, enhance state machine * get ssl search working — some gross hacking * fix session/graph issue * fix importing * work on mle * - scale embeddings in gru model - better exception handling in sl_prob - use emas for test/val - use non-contrib batch_norm layer * improve logging * option to average before dumping in logger * default arguments, etc * new ddpg and identity test * concat fix * minor * move realistic ssl stuff to third-party (underscore to dash) * fixes * remove realistic_ssl_evaluation * pylint fixes * use gym master * try again * pass around args without gin * fix 
tests * separate line to install gym * rename failing tests that should be ignored * add data aug * ssl improvements * use fixed time limit * try to fix baselines tests * add score_floor, max_walltime, fiddle with lr decay * realistic_ssl * autopep8 * various ssl - enable blocking grad for simplification - kl - multiple final prediction * fix pruning * misc ssl stuff * bring back linear schedule, don’t use allgather for collecting stats (i’ve been getting nondeterministic errors from the old code) * save/load weights in SSL, big stepsize * cleanup SslProb * fix * get rid of kl coef * fix simplification, lower lr * search over hps * minor fixes * minor * static analysis * move files and rename things for improved consistency. still broken, and just saving before making nontrivial changes * various * make tests pass * move coinrun_train to codegen since it depends on codegen * fixes * pylint fixes * improve tests fix some things * improve tests * lint * fix up db_info.py, tests * mostly restore master version of envs directory, except for makefile changes * fix tests * improve printing * minor fixes * fix fixmes * pruning test * fixes * lint * write new test that makes tf graphs of random algos; fix some bugs it caught * add —delete flag to rcall upload-code command * lint * get cifar10 lazily for testing purposes * disable codegen ci tests for now * clean up rl_ops * rename spec classes * td3 with identity test * identity tests without gin files * remove gin.configurable from AlgoAgent * comments about reduction in rl_ops_cc * address @pzhokhov comments * fix tests * more linting * better tests * clean up filtering a bit * fix concat * delayed logger configuration (#208) * delayed logger configuration * fix typo * setters and getters for Logger.DEFAULT as well * do away with fancy property stuff - unable to get it to work with class level methods * grammar and spaces * spaces * use get_current function instead of reading Logger.CURRENT * autopep8 * disable mpi in subprocesses (#213) * lazy_mpi load * cleanups * more lazy mpi * don't pretend that class is a module, just use it as a class * mass-replace mpi4py imports * flake8 * fix previous lazy_mpi imports * silly recursion * try os.environ hack * better prefix test, work with mpich * restored MPI imports * removed commented import in test_with_mpi * restored codegen from master * remove lazy mpi * restored changes from rl-algs * remove extra files * address Chris' comments * use spawn for shmem vec env as well (#2) (#219) * lazy_mpi load * cleanups * more lazy mpi * don't pretend that class is a module, just use it as a class * mass-replace mpi4py imports * flake8 * fix previous lazy_mpi imports * silly recursion * try os.environ hack * better prefix test, work with mpich * restored MPI imports * removed commented import in test_with_mpi * restored codegen from master * remove lazy mpi * restored changes from rl-algs * remove extra files * port mpi fix to shmem vec env * increase the mpi test default timeout * change humanoid hyperparameters, get rid of clip_Frac annealing, as it's apparently dangerous * remove clip_frac schedule from ppo2 * more timesteps in humanoid run * whitespace + RUN BENCHMARKS * baselines: export vecenvs from folder (#221) * baselines: export vecenvs from folder * put missing function back in * add missing imports * more imports * longer mpi timeout? 
* make default logger configuration the same as call to logger.configure() (#222) * Vecenv refactor (#223) * update karl util * restore pvi flag * change rcall auto cpu behavior, move gin.configurable, add os.makedirs * vecenv refactor * aux buf index fix * add num aux obs * reset level with enter * restore high difficulty flag * bugfix * restore train_coinrun.py * tweaks * renaming * renaming * better arguments handling * more options * options cleanup * game data refactor * more options * args for train_procgen * add close handler to interactive base class * use debug build if debug=True, fix range on aux_obs * add ProcGenEnv to __init__.py, add missing imports to procgen.py * export RemoveDictWrapper and build, update train_procgen.py, move assets download into env creation and replace init_assets_and_build with just build * fix formatting issues * only call global init once * fix path in setup.py * revert part of makefile * ignore IDE files and folders * vec remove dict * export VecRemoveDictObs * remove RemoveDictWrapper * remove IDE files * move shared .h and .cpp files to common folder, update build to use those, dedupe env.cpp * fix missing header * try unified build function * remove old scripts dir * add comment on build * upload libenv with render fixes * tell qthreads to die when we unload the library * pyglet.app.run is garbage * static fixes * whoops * actually vsync is on * cleanup * cleanup * extern C for libenv interface * parse util rcall arg * high difficulty fix * game type enums * ProcGenEnv subclasses * game type cleanup * unrecognized key * unrecognized game type * parse util reorg * args management * typo fix * GinParser * arg tweaks * tweak * restore start_level/num_levels setting * fix create_procgen_env interface * build fix * procgen args in init signature * fix * build fix * fix logger usage in ppo_metal/run_retro * removed unnecessary OrderedDict requirement in subproc_vec_env * flake8 fix * allow for non-mpi tests * mpi test fixes * flake8; removed special logic for discrete spaces in dummy_vec_env * remove forked argument in front of tests - does not play nicely with subprocvecenv in spawned processes; analog of forked in ddpg/test_smoke * Everyrl initial commit & a few minor baselines changes (#226) * everyrl initial commit * add keep_buf argument to VecMonitor * logger changes: set_comm and fix to mpi_mean functionality * if filename not provided, don't create ResultsWriter * change variable syncing function to simplify its usage. now you should initialize from all mpi processes * everyrl coinrun changes * tf_distr changes, bugfix * get_one * bring back get_next to temporarily restore code * lint fixes * fix test * rename profile function * rename gaussian * fix coinrun training script * change random seeding to work with new gym version (#231) * change random seeding to work with new gym version * move seeding to seed() method * fix mnistenv * actually try some of the tests before pushing * more deterministic fixed seq * misc changes to vecenvs and run.py for benchmarks (#236) * misc changes to vecenvs and run.py for benchmarks * dont seed global gen * update more references to assert_venvs_equal * Rl19 (#232) * everyrl initial commit * add keep_buf argument to VecMonitor * logger changes: set_comm and fix to mpi_mean functionality * if filename not provided, don't create ResultsWriter * change variable syncing function to simplify its usage. 
now you should initialize from all mpi processes * everyrl coinrun changes * tf_distr changes, bugfix * get_one * bring back get_next to temporarily restore code * lint fixes * fix test * rename profile function * rename gaussian * fix coinrun training script * rl19 * remove everyrl dir which appeared in the merge for some reason * readme * fiddle with ddpg * make ddpg work * steps_total argument * gpu count * clean up hyperparams and shape math * logging + saving * configuration stuff * fixes, smoke tests * fix stats * make load_results return dicts -- easier to create the same kind of objects with some other mechanism for passing to downstream functions * benchmarks * fix tests * add dqn to tests, fix it * minor * turned annotated transformer (pytorch) into a script * more refactoring * jax stuff * cluster * minor * copy & paste alec code * sign error * add huber, rename some parameters, snapshotting off by default * remove jax stuff * minor * move maze env * minor * remove trailing spaces * remove trailing space * lint * fix test breakage due to gym update * rename function * move maze back to codegen * get recurrent ppo working * enable both lstm and gru * script to print table of benchmark results * various * fix dqn * add fixup initializer, remove lastrew * organize logging stats * fix silly bug * refactor models * fix mpi usage * check sync * minor * change vf coef, hps * clean up slicing in ppo * minor fixes * caching transformer * docstrings * xf fixes * get rid of 'B' and 'BT' arguments * minor * transformer example * remove output_kind from base class until we have a better idea how to use it * add comments, revert maze stuff * flake8 * codegen lint * fix codegen tests * responded to peter's comments * lint fixes * minor changes to baselines (#243) * minor changes to baselines * fix spaces reference * remove flake8 disable comments and fix import * okay maybe don't add spec to vec_env * Merge branch 'master' of github.com:openai/games the commit. * flake8 complaints in baselines/her
@@ -11,4 +11,4 @@ install:
 script:
   - flake8 . --show-source --statistics
-  - docker run baselines-test pytest -v --forked .
+  - docker run baselines-test pytest -v .
@@ -20,7 +20,7 @@ def register_benchmark(benchmark):
     if 'tasks' in benchmark:
         for t in benchmark['tasks']:
             if 'desc' not in t:
-                t['desc'] = remove_version_re.sub('', t['env_id'])
+                t['desc'] = remove_version_re.sub('', t.get('env_id', t.get('id')))
     _BENCHMARKS.append(benchmark)
@@ -16,11 +16,13 @@ class Monitor(Wrapper):
     def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
         Wrapper.__init__(self, env=env)
         self.tstart = time.time()
-        self.results_writer = ResultsWriter(
-            filename,
+        if filename:
+            self.results_writer = ResultsWriter(filename,
                 header={"t_start": time.time(), 'env_id' : env.spec and env.spec.id},
                 extra_keys=reset_keywords + info_keywords
             )
+        else:
+            self.results_writer = None
         self.reset_keywords = reset_keywords
         self.info_keywords = info_keywords
         self.allow_early_resets = allow_early_resets
@@ -68,8 +70,9 @@ class Monitor(Wrapper):
             self.episode_lengths.append(eplen)
             self.episode_times.append(time.time() - self.tstart)
             epinfo.update(self.current_reset_info)
-            self.results_writer.write_row(epinfo)
+            if self.results_writer:
+                self.results_writer.write_row(epinfo)

-            assert isinstance(info, dict)
+            if isinstance(info, dict):
                 info['episode'] = epinfo
@@ -96,12 +99,9 @@ class LoadMonitorResultsError(Exception):


 class ResultsWriter(object):
-    def __init__(self, filename=None, header='', extra_keys=()):
+    def __init__(self, filename, header='', extra_keys=()):
         self.extra_keys = extra_keys
-        if filename is None:
-            self.f = None
-            self.logger = None
-        else:
+        assert filename is not None
         if not filename.endswith(Monitor.EXT):
             if osp.isdir(filename):
                 filename = osp.join(filename, Monitor.EXT)
@@ -121,7 +121,6 @@ class ResultsWriter(object):
             self.f.flush()


 def get_monitor_files(dir):
     return glob(osp.join(dir, "*" + Monitor.EXT))
@@ -6,6 +6,8 @@ import gym
 from gym import spaces
 import cv2
 cv2.ocl.setUseOpenCL(False)
+from .wrappers import TimeLimit


 class NoopResetEnv(gym.Wrapper):
     def __init__(self, env, noop_max=30):
@@ -221,14 +223,13 @@ class LazyFrames(object):
     def __getitem__(self, i):
         return self._force()[i]

-def make_atari(env_id, timelimit=True):
-    # XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
+def make_atari(env_id, max_episode_steps=None):
     env = gym.make(env_id)
-    if not timelimit:
-        env = env.env
     assert 'NoFrameskip' in env.spec.id
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
+    if max_episode_steps is not None:
+        env = TimeLimit(env, max_episode_steps=max_episode_steps)
     return env

 def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
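make_atari now takes max_episode_steps instead of the old timelimit flag and only wraps the env in the new TimeLimit wrapper when a limit is given. A small sketch of the new call (the env id and step limit are just examples):

# Sketch: build an Atari env with the new signature; wrap_deepmind is unchanged.
from baselines.common.atari_wrappers import make_atari, wrap_deepmind

env = make_atari('PongNoFrameskip-v4', max_episode_steps=10000)  # TimeLimit applied
env = wrap_deepmind(env, frame_stack=True)                       # preprocessing as before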
@@ -30,16 +30,19 @@ def make_vec_env(env_id, env_type, num_env, seed,
     wrapper_kwargs = wrapper_kwargs or {}
     mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     seed = seed + 10000 * mpi_rank if seed is not None else None
+    logger_dir = logger.get_dir()
     def make_thunk(rank):
         return lambda: make_env(
             env_id=env_id,
             env_type=env_type,
-            subrank = rank,
+            mpi_rank=mpi_rank,
+            subrank=rank,
             seed=seed,
             reward_scale=reward_scale,
             gamestate=gamestate,
             flatten_dict_observations=flatten_dict_observations,
-            wrapper_kwargs=wrapper_kwargs
+            wrapper_kwargs=wrapper_kwargs,
+            logger_dir=logger_dir
         )

     set_global_seeds(seed)
@@ -49,8 +52,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
         return DummyVecEnv([make_thunk(start_index)])


-def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None):
-    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
+def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
     wrapper_kwargs = wrapper_kwargs or {}
     if env_type == 'atari':
         env = make_atari(env_id)
@@ -67,7 +69,7 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate

     env.seed(seed + subrank if seed is not None else None)
     env = Monitor(env,
-                  logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
+                  logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
                   allow_early_resets=True)

     if env_type == 'atari':
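With this change make_env receives mpi_rank and logger_dir as arguments rather than querying MPI and the logger itself; only make_vec_env touches those globals, and the per-env thunks just forward the values. A rough usage sketch of the public entry point (argument values are illustrative):

# Sketch: the caller only deals with make_vec_env; mpi_rank/logger_dir are threaded through internally.
from baselines.common.cmd_util import make_vec_env

venv = make_vec_env('PongNoFrameskip-v4', 'atari', num_env=4, seed=0)
obs = venv.reset()   # shape: (4,) + observation_space.shape
venv.close()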
@@ -134,6 +136,7 @@ def common_arg_parser():
     """
     parser = arg_parser()
     parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
+    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
     parser.add_argument('--seed', help='RNG seed', type=int, default=None)
     parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
     parser.add_argument('--num_timesteps', type=float, default=1e6),
@@ -206,7 +206,8 @@ class CategoricalPd(Pd):
 class MultiCategoricalPd(Pd):
     def __init__(self, nvec, flat):
         self.flat = flat
-        self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
+        self.categoricals = list(map(CategoricalPd,
+                                     tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
     def flatparam(self):
         return self.flat
     def mode(self):
@@ -17,9 +17,16 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         num_tasks = self.comm.Get_size()
         buf = np.zeros(sum(sizes), np.float32)

+        sess = tf.get_default_session()
+        assert sess is not None
+        countholder = [0]  # Counts how many times _collect_grads has been called
+        stat = tf.reduce_sum(grads_and_vars[0][1])  # sum of first variable
         def _collect_grads(flat_grad):
             self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
             np.divide(buf, float(num_tasks), out=buf)
+            if countholder[0] % 100 == 0:
+                check_synced(sess, self.comm, stat)
+            countholder[0] += 1
             return buf

         avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32)
@@ -27,5 +34,13 @@ class MpiAdamOptimizer(tf.train.AdamOptimizer):
         avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
         avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                               for g, (_, v) in zip(avg_grads, grads_and_vars)]

         return avg_grads_and_vars

+def check_synced(sess, comm, tfstat):
+    """
+    Check that 'tfstat' evaluates to the same thing on every MPI worker
+    """
+    localval = sess.run(tfstat)
+    vals = comm.gather(localval)
+    if comm.rank == 0:
+        assert all(val==vals[0] for val in vals[1:])
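The countholder/stat machinery makes every 100th gradient-averaging call verify that a summary statistic of the first variable is identical on all workers, via the module-level check_synced helper added at the end of the hunk. A standalone sketch of what that check does, assuming check_synced lives in baselines.common.mpi_adam_optimizer alongside the class (toy variable, TF1-style session):

# Sketch: check_synced gathers a scalar from every rank and asserts they all match on rank 0.
import tensorflow as tf
from mpi4py import MPI
from baselines.common.mpi_adam_optimizer import check_synced

with tf.Session() as sess:
    w = tf.Variable([1.0, 2.0, 3.0])
    sess.run(tf.global_variables_initializer())
    stat = tf.reduce_sum(w)
    check_synced(sess, MPI.COMM_WORLD, stat)  # AssertionError on rank 0 if any worker diverged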
@@ -1,9 +1,16 @@
 from collections import defaultdict
-from mpi4py import MPI
 import os, numpy as np
 import platform
 import shutil
 import subprocess
+import warnings
+import sys
+
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None


 def sync_from_root(sess, variables, comm=None):
     """
@@ -13,15 +20,10 @@ def sync_from_root(sess, variables, comm=None):
         variables: all parameter variables including optimizer's
     """
     if comm is None: comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    for var in variables:
-        if rank == 0:
-            comm.Bcast(sess.run(var))
-        else:
-            import tensorflow as tf
-            returned_var = np.empty(var.shape, dtype='float32')
-            comm.Bcast(returned_var)
-            sess.run(tf.assign(var, returned_var))
+    values = comm.bcast(sess.run(variables))
+    sess.run([tf.assign(var, val)
+              for (var, val) in zip(variables, values)])

 def gpu_count():
     """
@@ -34,13 +36,15 @@ def gpu_count():

 def setup_mpi_gpus():
     """
-    Set CUDA_VISIBLE_DEVICES using MPI.
+    Set CUDA_VISIBLE_DEVICES to MPI rank if not already set
     """
-    num_gpus = gpu_count()
-    if num_gpus == 0:
-        return
-    local_rank, _ = get_local_rank_size(MPI.COMM_WORLD)
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank % num_gpus)
+    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
+        if sys.platform == 'darwin':  # This Assumes if you're on OSX you're just
+            ids = []                  # doing a smoke test and don't want GPUs
+        else:
+            lrank, _lsize = get_local_rank_size(MPI.COMM_WORLD)
+            ids = [lrank]
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, ids))

 def get_local_rank_size(comm):
     """
@@ -81,6 +85,9 @@ def share_file(comm, path):
     comm.Barrier()


 def dict_gather(comm, d, op='mean', assert_all_have_data=True):
+    """
+    Perform a reduction operation over dicts
+    """
     if comm is None: return d
     alldicts = comm.allgather(d)
     size = comm.size
@@ -99,3 +106,28 @@ def dict_gather(comm, d, op='mean', assert_all_have_data=True):
     else:
         assert 0, op
     return result

+def mpi_weighted_mean(comm, local_name2valcount):
+    """
+    Perform a weighted average over dicts that are each on a different node
+    Input: local_name2valcount: dict mapping key -> (value, count)
+    Returns: key -> mean
+    """
+    all_name2valcount = comm.gather(local_name2valcount)
+    if comm.rank == 0:
+        name2sum = defaultdict(float)
+        name2count = defaultdict(float)
+        for n2vc in all_name2valcount:
+            for (name, (val, count)) in n2vc.items():
+                try:
+                    val = float(val)
+                except ValueError:
+                    if comm.rank == 0:
+                        warnings.warn(f'WARNING: tried to compute mean on non-float {name}={val}')
+                else:
+                    name2sum[name] += val * count
+                    name2count[name] += count
+        return {name : name2sum[name] / name2count[name] for name in name2sum}
+    else:
+        return {}
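mpi_weighted_mean gathers per-worker {key: (value, count)} dicts onto rank 0 and returns count-weighted means there (an empty dict on other ranks); non-float values are skipped with a warning. The arithmetic for the 'a' key in the test added further down is simply:

# Worked example of the weighted mean, matching the values used in test_mpi_weighted_mean below.
rank0 = {'a': (10, 2)}          # value 10 observed 2 times on rank 0
rank1 = {'a': (19, 1)}          # value 19 observed once on rank 1
mean_a = (10 * 2 + 19 * 1) / (2 + 1)
assert mean_a == 13.0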
@@ -1,25 +1,11 @@
-# flake8: noqa F403, F405
-from .atari_wrappers import *
 from collections import deque
 import cv2
 cv2.ocl.setUseOpenCL(False)
+from .atari_wrappers import WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame
+from .wrappers import TimeLimit
 import numpy as np
 import gym

-class TimeLimit(gym.Wrapper):
-    def __init__(self, env, max_episode_steps=None):
-        super(TimeLimit, self).__init__(env)
-        self._max_episode_steps = max_episode_steps
-        self._elapsed_steps = 0
-
-    def step(self, ac):
-        observation, reward, done, info = self.env.step(ac)
-        self._elapsed_steps += 1
-        if self._elapsed_steps >= self._max_episode_steps:
-            done = True
-            info['TimeLimit.truncated'] = True
-        return observation, reward, done, info
-
-    def reset(self, **kwargs):
-        self._elapsed_steps = 0
-        return self.env.reset(**kwargs)

 class StochasticFrameSkip(gym.Wrapper):
     def __init__(self, env, n, stickprob):
@@ -99,7 +85,7 @@ class Downsample(gym.ObservationWrapper):
         gym.ObservationWrapper.__init__(self, env)
         (oldh, oldw, oldc) = env.observation_space.shape
         newshape = (oldh//ratio, oldw//ratio, oldc)
-        self.observation_space = spaces.Box(low=0, high=255,
+        self.observation_space = gym.spaces.Box(low=0, high=255,
             shape=newshape, dtype=np.uint8)

     def observation(self, frame):
@@ -116,7 +102,7 @@ class Rgb2gray(gym.ObservationWrapper):
         """
         gym.ObservationWrapper.__init__(self, env)
         (oldh, oldw, _oldc) = env.observation_space.shape
-        self.observation_space = spaces.Box(low=0, high=255,
+        self.observation_space = gym.spaces.Box(low=0, high=255,
             shape=(oldh, oldw, 1), dtype=np.uint8)

     def observation(self, frame):
@@ -213,8 +199,10 @@ class StartDoingRandomActionsWrapper(gym.Wrapper):
             self.some_random_steps()
         return self.last_obs, rew, done, info

-def make_retro(*, game, state, max_episode_steps, **kwargs):
+def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
     import retro
+    if state is None:
+        state = retro.State.DEFAULT
     env = retro.make(game, state, **kwargs)
     env = StochasticFrameSkip(env, n=4, stickprob=0.25)
     if max_episode_steps is not None:
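make_retro now defaults state to retro.State.DEFAULT and max_episode_steps to 4500, so the common case needs only a game name. A sketch, assuming gym-retro is installed and the ROM has been imported (the game name is just an example):

# Sketch: the keyword-only defaults mean a bare game name is enough.
from baselines.common.retro_wrappers import make_retro

env = make_retro(game='Airstriker-Genesis')   # state=retro.State.DEFAULT, max_episode_steps=4500
obs = env.reset()
env.close()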
baselines/common/test_mpi_util.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+from baselines import logger
+from baselines.common.tests.test_with_mpi import with_mpi
+from baselines.common import mpi_util
+
+@with_mpi()
+def test_mpi_weighted_mean():
+    from mpi4py import MPI
+    comm = MPI.COMM_WORLD
+    with logger.scoped_configure(comm=comm):
+        if comm.rank == 0:
+            name2valcount = {'a' : (10, 2), 'b' : (20,3)}
+        elif comm.rank == 1:
+            name2valcount = {'a' : (19, 1), 'c' : (42,3)}
+        else:
+            raise NotImplementedError
+
+        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
+        correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
+        if comm.rank == 0:
+            assert d == correctval, f'{d} != {correctval}'
+
+        for name, (val, count) in name2valcount.items():
+            for _ in range(count):
+                logger.logkv_mean(name, val)
+        d2 = logger.dumpkvs()
+        if comm.rank == 0:
+            assert d2 == correctval
@@ -7,21 +7,20 @@ class FixedSequenceEnv(Env):
|
||||
def __init__(
|
||||
self,
|
||||
n_actions=10,
|
||||
seed=0,
|
||||
episode_len=100
|
||||
):
|
||||
self.np_random = np.random.RandomState()
|
||||
self.np_random.seed(seed)
|
||||
self.sequence = [self.np_random.randint(0, n_actions-1) for _ in range(episode_len)]
|
||||
self.sequence = None
|
||||
|
||||
self.action_space = Discrete(n_actions)
|
||||
self.observation_space = Discrete(1)
|
||||
|
||||
self.episode_len = episode_len
|
||||
self.time = 0
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
if self.sequence is None:
|
||||
self.sequence = [self.np_random.randint(0, self.action_space.n-1) for _ in range(self.episode_len)]
|
||||
self.time = 0
|
||||
return 0
|
||||
|
||||
@@ -35,6 +34,9 @@ class FixedSequenceEnv(Env):
|
||||
|
||||
return 0, rew, done, {}
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random.seed(seed)
|
||||
|
||||
def _choose_next_state(self):
|
||||
self.time += 1
|
||||
|
||||
|
@@ -10,6 +10,7 @@ class IdentityEnv(Env):
|
||||
episode_len=None
|
||||
):
|
||||
|
||||
self.observation_space = self.action_space
|
||||
self.episode_len = episode_len
|
||||
self.time = 0
|
||||
self.reset()
|
||||
@@ -17,7 +18,6 @@ class IdentityEnv(Env):
|
||||
def reset(self):
|
||||
self._choose_next_state()
|
||||
self.time = 0
|
||||
self.observation_space = self.action_space
|
||||
|
||||
return self.state
|
||||
|
||||
@@ -26,11 +26,13 @@ class IdentityEnv(Env):
|
||||
self._choose_next_state()
|
||||
done = False
|
||||
if self.episode_len and self.time >= self.episode_len:
|
||||
rew = 0
|
||||
done = True
|
||||
|
||||
return self.state, rew, done, {}
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.action_space.seed(seed)
|
||||
|
||||
def _choose_next_state(self):
|
||||
self.state = self.action_space.sample()
|
||||
self.time += 1
|
||||
@@ -74,7 +76,7 @@ class BoxIdentityEnv(IdentityEnv):
|
||||
episode_len=None,
|
||||
):
|
||||
|
||||
self.action_space = Box(low=-1.0, high=1.0, shape=shape)
|
||||
self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
|
||||
super().__init__(episode_len=episode_len)
|
||||
|
||||
def _get_reward(self, actions):
|
||||
|
@@ -9,7 +9,6 @@ from gym.spaces import Discrete, Box
|
||||
class MnistEnv(Env):
|
||||
def __init__(
|
||||
self,
|
||||
seed=0,
|
||||
episode_len=None,
|
||||
no_images=None
|
||||
):
|
||||
@@ -23,7 +22,6 @@ class MnistEnv(Env):
|
||||
self.mnist = input_data.read_data_sets(mnist_path)
|
||||
|
||||
self.np_random = np.random.RandomState()
|
||||
self.np_random.seed(seed)
|
||||
|
||||
self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
|
||||
self.action_space = Discrete(10)
|
||||
@@ -50,6 +48,9 @@ class MnistEnv(Env):
|
||||
|
||||
return self.state[0], rew, done, {}
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random.seed(seed)
|
||||
|
||||
def train_mode(self):
|
||||
self.dataset = self.mnist.train
|
||||
|
||||
|
@@ -33,8 +33,7 @@ def test_fixed_sequence(alg, rnn):
|
||||
kwargs = learn_kwargs[alg]
|
||||
kwargs.update(common_kwargs)
|
||||
|
||||
episode_len = 5
|
||||
env_fn = lambda: FixedSequenceEnv(10, episode_len=episode_len)
|
||||
env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
|
||||
learn = lambda e: get_learn_function(alg)(
|
||||
env=e,
|
||||
network=rnn,
|
||||
|
@@ -41,7 +41,7 @@ def test_mnist(alg):
|
||||
|
||||
learn = get_learn_function(alg)
|
||||
learn_fn = lambda e: learn(env=e, **learn_kwargs)
|
||||
env_fn = lambda: MnistEnv(seed=0, episode_len=100)
|
||||
env_fn = lambda: MnistEnv(episode_len=100)
|
||||
|
||||
simple_test(env_fn, learn_fn, 0.6)
|
||||
|
||||
|
@@ -44,7 +44,12 @@ def test_serialization(learn_fn, network_fn):
|
||||
# github issue: https://github.com/openai/baselines/issues/660
|
||||
return
|
||||
|
||||
env = DummyVecEnv([lambda: MnistEnv(10, episode_len=100)])
|
||||
def make_env():
|
||||
env = MnistEnv(episode_len=100)
|
||||
env.seed(10)
|
||||
return env
|
||||
|
||||
env = DummyVecEnv([make_env])
|
||||
ob = env.reset().copy()
|
||||
learn = get_learn_function(learn_fn)
|
||||
|
||||
|
baselines/common/tests/test_with_mpi.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+import os
+import sys
+import subprocess
+import cloudpickle
+import base64
+import pytest
+
+try:
+    from mpi4py import MPI
+except ImportError:
+    MPI = None
+
+def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
+    def outer_thunk(fn):
+        def thunk(*args, **kwargs):
+            serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))
+            subprocess.check_call([
+                'mpiexec','-n', str(nproc),
+                sys.executable,
+                '-m', 'baselines.common.tests.test_with_mpi',
+                serialized_fn
+            ], env=os.environ, timeout=timeout)
+
+        if skip_if_no_mpi:
+            return pytest.mark.skipif(MPI is None, reason="MPI not present")(thunk)
+        else:
+            return thunk
+
+    return outer_thunk
+
+
+if __name__ == '__main__':
+    if len(sys.argv) > 1:
+        fn = cloudpickle.loads(base64.b64decode(sys.argv[1]))
+        assert callable(fn)
+        fn()
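with_mpi re-launches the decorated function under mpiexec by cloudpickling it, shipping it base64-encoded through argv, and unpickling it in this module's __main__ block; when mpi4py is missing the test is skipped rather than failed. A sketch of decorating one's own test (the test body is illustrative, not part of this commit):

# Sketch: any argument-free test body can be run across 2 MPI processes this way.
from baselines.common.tests.test_with_mpi import with_mpi

@with_mpi(nproc=2, timeout=30)
def test_ranks_are_distinct():
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    ranks = comm.allgather(comm.rank)
    assert sorted(ranks) == [0, 1]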
@@ -6,48 +6,39 @@ N_TRIALS = 10000
|
||||
N_EPISODES = 100
|
||||
|
||||
def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||
def seeded_env_fn():
|
||||
env = env_fn()
|
||||
env.seed(0)
|
||||
return env
|
||||
|
||||
np.random.seed(0)
|
||||
env = DummyVecEnv([env_fn])
|
||||
|
||||
|
||||
env = DummyVecEnv([seeded_env_fn])
|
||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||
tf.set_random_seed(0)
|
||||
|
||||
model = learn_fn(env)
|
||||
|
||||
sum_rew = 0
|
||||
done = True
|
||||
|
||||
for i in range(n_trials):
|
||||
if done:
|
||||
obs = env.reset()
|
||||
state = model.initial_state
|
||||
|
||||
if state is not None:
|
||||
a, v, state, _ = model.step(obs, S=state, M=[False])
|
||||
else:
|
||||
a, v, _, _ = model.step(obs)
|
||||
|
||||
obs, rew, done, _ = env.step(a)
|
||||
sum_rew += float(rew)
|
||||
|
||||
print("Reward in {} trials is {}".format(n_trials, sum_rew))
|
||||
assert sum_rew > min_reward_fraction * n_trials, \
|
||||
'sum of rewards {} is less than {} of the total number of trials {}'.format(sum_rew, min_reward_fraction, n_trials)
|
||||
|
||||
|
||||
|
||||
def reward_per_episode_test(env_fn, learn_fn, min_avg_reward, n_trials=N_EPISODES):
|
||||
env = DummyVecEnv([env_fn])
|
||||
|
||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||
model = learn_fn(env)
|
||||
|
||||
N_TRIALS = 100
|
||||
|
||||
observations, actions, rewards = rollout(env, model, N_TRIALS)
|
||||
rewards = [sum(r) for r in rewards]
|
||||
|
||||
avg_rew = sum(rewards) / N_TRIALS
|
||||
print("Average reward in {} episodes is {}".format(n_trials, avg_rew))
|
||||
assert avg_rew > min_avg_reward, \
|
||||
@@ -57,14 +48,12 @@ def rollout(env, model, n_trials):
|
||||
rewards = []
|
||||
actions = []
|
||||
observations = []
|
||||
|
||||
for i in range(n_trials):
|
||||
obs = env.reset()
|
||||
state = model.initial_state if hasattr(model, 'initial_state') else None
|
||||
episode_rew = []
|
||||
episode_actions = []
|
||||
episode_obs = []
|
||||
|
||||
while True:
|
||||
if state is not None:
|
||||
a, v, state, _ = model.step(obs, S=state, M=[False])
|
||||
@@ -72,17 +61,13 @@ def rollout(env, model, n_trials):
|
||||
a,v, _, _ = model.step(obs)
|
||||
|
||||
obs, rew, done, _ = env.step(a)
|
||||
|
||||
episode_rew.append(rew)
|
||||
episode_actions.append(a)
|
||||
episode_obs.append(obs)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
rewards.append(episode_rew)
|
||||
actions.append(episode_actions)
|
||||
observations.append(episode_obs)
|
||||
|
||||
return observations, actions, rewards
|
||||
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import joblib
|
||||
import numpy as np
|
||||
import tensorflow as tf # pylint: ignore-module
|
||||
import copy
|
||||
@@ -339,6 +338,7 @@ def save_state(fname, sess=None):
|
||||
# TODO: ensure there is no subtle differences and remove one
|
||||
|
||||
def save_variables(save_path, variables=None, sess=None):
|
||||
import joblib
|
||||
sess = sess or get_session()
|
||||
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
||||
|
||||
@@ -350,6 +350,7 @@ def save_variables(save_path, variables=None, sess=None):
|
||||
joblib.dump(save_dict, save_path)
|
||||
|
||||
def load_variables(load_path, variables=None, sess=None):
|
||||
import joblib
|
||||
sess = sess or get_session()
|
||||
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
||||
|
||||
|
@@ -1,185 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from baselines.common.tile_images import tile_images
|
||||
from .vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, VecEnvObservationWrapper, CloudpickleWrapper
|
||||
from .dummy_vec_env import DummyVecEnv
|
||||
from .shmem_vec_env import ShmemVecEnv
|
||||
from .subproc_vec_env import SubprocVecEnv
|
||||
from .vec_frame_stack import VecFrameStack
|
||||
from .vec_monitor import VecMonitor
|
||||
from .vec_normalize import VecNormalize
|
||||
from .vec_remove_dict_obs import VecExtractDictObs
|
||||
|
||||
class AlreadySteppingError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous step is running while
|
||||
step_async() is called again.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
msg = 'already running an async step'
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
|
||||
class NotSteppingError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous step is not running but
|
||||
step_wait() is called.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
msg = 'not running an async step'
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
|
||||
class VecEnv(ABC):
|
||||
"""
|
||||
An abstract asynchronous, vectorized environment.
|
||||
Used to batch data from multiple copies of an environment, so that
|
||||
each observation becomes an batch of observations, and expected action is a batch of actions to
|
||||
be applied per-environment.
|
||||
"""
|
||||
closed = False
|
||||
viewer = None
|
||||
|
||||
metadata = {
|
||||
'render.modes': ['human', 'rgb_array']
|
||||
}
|
||||
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""
|
||||
Reset all the environments and return an array of
|
||||
observations, or a dict of observation arrays.
|
||||
|
||||
If step_async is still doing work, that work will
|
||||
be cancelled and step_wait() should not be called
|
||||
until step_async() is invoked again.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step_async(self, actions):
|
||||
"""
|
||||
Tell all the environments to start taking a step
|
||||
with the given actions.
|
||||
Call step_wait() to get the results of the step.
|
||||
|
||||
You should not call this if a step_async run is
|
||||
already pending.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step_wait(self):
|
||||
"""
|
||||
Wait for the step taken with step_async().
|
||||
|
||||
Returns (obs, rews, dones, infos):
|
||||
- obs: an array of observations, or a dict of
|
||||
arrays of observations.
|
||||
- rews: an array of rewards
|
||||
- dones: an array of "episode done" booleans
|
||||
- infos: a sequence of info objects
|
||||
"""
|
||||
pass
|
||||
|
||||
def close_extras(self):
|
||||
"""
|
||||
Clean up the extra resources, beyond what's in this base class.
|
||||
Only runs when not self.closed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
if self.closed:
|
||||
return
|
||||
if self.viewer is not None:
|
||||
self.viewer.close()
|
||||
self.close_extras()
|
||||
self.closed = True
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Step the environments synchronously.
|
||||
|
||||
This is available for backwards compatibility.
|
||||
"""
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def render(self, mode='human'):
|
||||
imgs = self.get_images()
|
||||
bigimg = tile_images(imgs)
|
||||
if mode == 'human':
|
||||
self.get_viewer().imshow(bigimg)
|
||||
return self.get_viewer().isopen
|
||||
elif mode == 'rgb_array':
|
||||
return bigimg
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_images(self):
|
||||
"""
|
||||
Return RGB images from each environment
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
if isinstance(self, VecEnvWrapper):
|
||||
return self.venv.unwrapped
|
||||
else:
|
||||
return self
|
||||
|
||||
def get_viewer(self):
|
||||
if self.viewer is None:
|
||||
from gym.envs.classic_control import rendering
|
||||
self.viewer = rendering.SimpleImageViewer()
|
||||
return self.viewer
|
||||
|
||||
|
||||
class VecEnvWrapper(VecEnv):
|
||||
"""
|
||||
An environment wrapper that applies to an entire batch
|
||||
of environments at once.
|
||||
"""
|
||||
|
||||
def __init__(self, venv, observation_space=None, action_space=None):
|
||||
self.venv = venv
|
||||
VecEnv.__init__(self,
|
||||
num_envs=venv.num_envs,
|
||||
observation_space=observation_space or venv.observation_space,
|
||||
action_space=action_space or venv.action_space)
|
||||
|
||||
def step_async(self, actions):
|
||||
self.venv.step_async(actions)
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step_wait(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
return self.venv.close()
|
||||
|
||||
def render(self, mode='human'):
|
||||
return self.venv.render(mode=mode)
|
||||
|
||||
def get_images(self):
|
||||
return self.venv.get_images()
|
||||
|
||||
class CloudpickleWrapper(object):
|
||||
"""
|
||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||
"""
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def __getstate__(self):
|
||||
import cloudpickle
|
||||
return cloudpickle.dumps(self.x)
|
||||
|
||||
def __setstate__(self, ob):
|
||||
import pickle
|
||||
self.x = pickle.loads(ob)
|
||||
__all__ = ['AlreadySteppingError', 'NotSteppingError', 'VecEnv', 'VecEnvWrapper', 'VecEnvObservationWrapper', 'CloudpickleWrapper', 'DummyVecEnv', 'ShmemVecEnv', 'SubprocVecEnv', 'VecFrameStack', 'VecMonitor', 'VecNormalize', 'VecExtractDictObs']
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from . import VecEnv
|
||||
from .vec_env import VecEnv
|
||||
from .util import copy_obs_dict, dict_to_obs, obs_space_info
|
||||
|
||||
class DummyVecEnv(VecEnv):
|
||||
@@ -27,7 +26,7 @@ class DummyVecEnv(VecEnv):
|
||||
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
||||
self.buf_infos = [{} for _ in range(self.num_envs)]
|
||||
self.actions = None
|
||||
self.specs = [e.spec for e in self.envs]
|
||||
self.spec = self.envs[0].spec
|
||||
|
||||
def step_async(self, actions):
|
||||
listify = True
|
||||
@@ -46,8 +45,8 @@ class DummyVecEnv(VecEnv):
|
||||
def step_wait(self):
|
||||
for e in range(self.num_envs):
|
||||
action = self.actions[e]
|
||||
if isinstance(self.envs[e].action_space, spaces.Discrete):
|
||||
action = int(action)
|
||||
# if isinstance(self.envs[e].action_space, spaces.Discrete):
|
||||
# action = int(action)
|
||||
|
||||
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
|
||||
if self.buf_dones[e]:
|
||||
|
@@ -2,9 +2,9 @@
|
||||
An interface for asynchronous vectorized environments.
|
||||
"""
|
||||
|
||||
from multiprocessing import Pipe, Array, Process
|
||||
import multiprocessing as mp
|
||||
import numpy as np
|
||||
from . import VecEnv, CloudpickleWrapper
|
||||
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
|
||||
import ctypes
|
||||
from baselines import logger
|
||||
|
||||
@@ -22,11 +22,12 @@ class ShmemVecEnv(VecEnv):
|
||||
Optimized version of SubprocVecEnv that uses shared variables to communicate observations.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns, spaces=None):
|
||||
def __init__(self, env_fns, spaces=None, context='spawn'):
|
||||
"""
|
||||
If you don't specify observation_space, we'll have to create a dummy
|
||||
environment to get it.
|
||||
"""
|
||||
ctx = mp.get_context(context)
|
||||
if spaces:
|
||||
observation_space, action_space = spaces
|
||||
else:
|
||||
@@ -39,14 +40,15 @@ class ShmemVecEnv(VecEnv):
|
||||
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
||||
self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space)
|
||||
self.obs_bufs = [
|
||||
{k: Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
|
||||
{k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
|
||||
for _ in env_fns]
|
||||
self.parent_pipes = []
|
||||
self.procs = []
|
||||
with clear_mpi_env_vars():
|
||||
for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
|
||||
wrapped_fn = CloudpickleWrapper(env_fn)
|
||||
parent_pipe, child_pipe = Pipe()
|
||||
proc = Process(target=_subproc_worker,
|
||||
parent_pipe, child_pipe = ctx.Pipe()
|
||||
proc = ctx.Process(target=_subproc_worker,
|
||||
args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
|
||||
proc.daemon = True
|
||||
self.procs.append(proc)
|
||||
@@ -54,7 +56,6 @@ class ShmemVecEnv(VecEnv):
|
||||
proc.start()
|
||||
child_pipe.close()
|
||||
self.waiting_step = False
|
||||
self.specs = [f().spec for f in env_fns]
|
||||
self.viewer = None
|
||||
|
||||
def reset(self):
|
||||
|
@@ -1,6 +1,8 @@
|
||||
import multiprocessing as mp
|
||||
|
||||
import numpy as np
|
||||
from multiprocessing import Process, Pipe
|
||||
from . import VecEnv, CloudpickleWrapper
|
||||
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
|
||||
|
||||
|
||||
def worker(remote, parent_remote, env_fn_wrapper):
|
||||
parent_remote.close()
|
||||
@@ -21,8 +23,8 @@ def worker(remote, parent_remote, env_fn_wrapper):
|
||||
elif cmd == 'close':
|
||||
remote.close()
|
||||
break
|
||||
elif cmd == 'get_spaces':
|
||||
remote.send((env.observation_space, env.action_space))
|
||||
elif cmd == 'get_spaces_spec':
|
||||
remote.send((env.observation_space, env.action_space, env.spec))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
except KeyboardInterrupt:
|
||||
@@ -36,7 +38,7 @@ class SubprocVecEnv(VecEnv):
|
||||
VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes.
|
||||
Recommended to use when num_envs > 1 and step() can be a bottleneck.
|
||||
"""
|
||||
def __init__(self, env_fns, spaces=None):
|
||||
def __init__(self, env_fns, spaces=None, context='spawn'):
|
||||
"""
|
||||
Arguments:
|
||||
|
||||
@@ -45,19 +47,20 @@ class SubprocVecEnv(VecEnv):
|
||||
self.waiting = False
|
||||
self.closed = False
|
||||
nenvs = len(env_fns)
|
||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
||||
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||
ctx = mp.get_context(context)
|
||||
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
|
||||
self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
|
||||
for p in self.ps:
|
||||
p.daemon = True # if the main process crashes, we should not cause things to hang
|
||||
with clear_mpi_env_vars():
|
||||
p.start()
|
||||
for remote in self.work_remotes:
|
||||
remote.close()
|
||||
|
||||
self.remotes[0].send(('get_spaces', None))
|
||||
observation_space, action_space = self.remotes[0].recv()
|
||||
self.remotes[0].send(('get_spaces_spec', None))
|
||||
observation_space, action_space, self.spec = self.remotes[0].recv()
|
||||
self.viewer = None
|
||||
self.specs = [f().spec for f in env_fns]
|
||||
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
||||
|
||||
def step_async(self, actions):
|
||||
@@ -99,16 +102,16 @@ class SubprocVecEnv(VecEnv):
|
||||
def _assert_not_closed(self):
|
||||
assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
|
||||
|
||||
def __del__(self):
|
||||
if not self.closed:
|
||||
self.close()
|
||||
|
||||
def _flatten_obs(obs):
|
||||
assert isinstance(obs, list) or isinstance(obs, tuple)
|
||||
assert isinstance(obs, (list, tuple))
|
||||
assert len(obs) > 0
|
||||
|
||||
if isinstance(obs[0], dict):
|
||||
import collections
|
||||
assert isinstance(obs, collections.OrderedDict)
|
||||
keys = obs[0].keys()
|
||||
return {k: np.stack([o[k] for o in obs]) for k in keys}
|
||||
else:
|
||||
return np.stack(obs)
|
||||
|
||||
|
@@ -8,39 +8,40 @@ import pytest
|
||||
from .dummy_vec_env import DummyVecEnv
|
||||
from .shmem_vec_env import ShmemVecEnv
|
||||
from .subproc_vec_env import SubprocVecEnv
|
||||
from baselines.common.tests.test_with_mpi import with_mpi
|
||||
|
||||
|
||||
def assert_envs_equal(env1, env2, num_steps):
|
||||
def assert_venvs_equal(venv1, venv2, num_steps):
|
||||
"""
|
||||
Compare two environments over num_steps steps and make sure
|
||||
that the observations produced by each are the same when given
|
||||
the same actions.
|
||||
"""
|
||||
assert env1.num_envs == env2.num_envs
|
||||
assert env1.action_space.shape == env2.action_space.shape
|
||||
assert env1.action_space.dtype == env2.action_space.dtype
|
||||
joint_shape = (env1.num_envs,) + env1.action_space.shape
|
||||
assert venv1.num_envs == venv2.num_envs
|
||||
assert venv1.observation_space.shape == venv2.observation_space.shape
|
||||
assert venv1.observation_space.dtype == venv2.observation_space.dtype
|
||||
assert venv1.action_space.shape == venv2.action_space.shape
|
||||
assert venv1.action_space.dtype == venv2.action_space.dtype
|
||||
|
||||
try:
|
||||
obs1, obs2 = env1.reset(), env2.reset()
|
||||
obs1, obs2 = venv1.reset(), venv2.reset()
|
||||
assert np.array(obs1).shape == np.array(obs2).shape
|
||||
assert np.array(obs1).shape == joint_shape
|
||||
assert np.array(obs1).shape == (venv1.num_envs,) + venv1.observation_space.shape
|
||||
assert np.allclose(obs1, obs2)
|
||||
np.random.seed(1337)
|
||||
venv1.action_space.seed(1337)
|
||||
for _ in range(num_steps):
|
||||
actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
|
||||
dtype=env1.action_space.dtype)
|
||||
for env in [env1, env2]:
|
||||
env.step_async(actions)
|
||||
outs1 = env1.step_wait()
|
||||
outs2 = env2.step_wait()
|
||||
actions = np.array([venv1.action_space.sample() for _ in range(venv1.num_envs)])
|
||||
for venv in [venv1, venv2]:
|
||||
venv.step_async(actions)
|
||||
outs1 = venv1.step_wait()
|
||||
outs2 = venv2.step_wait()
|
||||
for out1, out2 in zip(outs1[:3], outs2[:3]):
|
||||
assert np.array(out1).shape == np.array(out2).shape
|
||||
assert np.allclose(out1, out2)
|
||||
assert list(outs1[3]) == list(outs2[3])
|
||||
finally:
|
||||
env1.close()
|
||||
env2.close()
|
||||
venv1.close()
|
||||
venv2.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
|
||||
@@ -63,7 +64,7 @@ def test_vec_env(klass, dtype): # pylint: disable=R0914
|
||||
fns = [make_fn(i) for i in range(num_envs)]
|
||||
env1 = DummyVecEnv(fns)
|
||||
env2 = klass(fns)
|
||||
assert_envs_equal(env1, env2, num_steps=num_steps)
|
||||
assert_venvs_equal(env1, env2, num_steps=num_steps)
|
||||
|
||||
|
||||
class SimpleEnv(gym.Env):
|
||||
@@ -99,3 +100,15 @@ class SimpleEnv(gym.Env):
|
||||
|
||||
def render(self, mode=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@with_mpi()
|
||||
def test_mpi_with_subprocvecenv():
|
||||
shape = (2,3,4)
|
||||
nenv = 1
|
||||
venv = SubprocVecEnv([lambda: SimpleEnv(0, shape, 'float32')] * nenv)
|
||||
ob = venv.reset()
|
||||
venv.close()
|
||||
assert ob.shape == (nenv,) + shape
|
||||
|
||||
|
baselines/common/vec_env/vec_env.py (new file, 219 lines)
@@ -0,0 +1,219 @@
|
||||
import contextlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from baselines.common.tile_images import tile_images
|
||||
|
||||
class AlreadySteppingError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous step is running while
|
||||
step_async() is called again.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
msg = 'already running an async step'
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
|
||||
class NotSteppingError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous step is not running but
|
||||
step_wait() is called.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
msg = 'not running an async step'
|
||||
Exception.__init__(self, msg)
|
||||
|
||||
|
||||
class VecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.
    Used to batch data from multiple copies of an environment, so that
    each observation becomes a batch of observations, and expected action is a batch of actions
    to be applied per-environment.
    """
    closed = False
    viewer = None

    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a dict of observation arrays.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a dict of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        pass

    def close_extras(self):
        """
        Clean up the extra resources, beyond what's in this base class.
        Only runs when not self.closed.
        """
        pass

    def close(self):
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras()
        self.closed = True

    def step(self, actions):
        """
        Step the environments synchronously.

        This is available for backwards compatibility.
        """
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        imgs = self.get_images()
        bigimg = tile_images(imgs)
        if mode == 'human':
            self.get_viewer().imshow(bigimg)
            return self.get_viewer().isopen
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def get_viewer(self):
        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer

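As a rough illustration of the step_async/step_wait contract above, here is a minimal usage sketch (not part of the diff; it assumes DummyVecEnv and a registered gym env such as CartPole-v0 are available):

    # Sketch: drive a vectorized env through the async step API.
    import gym
    import numpy as np
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make('CartPole-v0')] * 2)  # num_envs = 2
    obs = venv.reset()                                         # batch of observations, one row per env
    actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
    venv.step_async(actions)                                   # start the step...
    obs, rews, dones, infos = venv.step_wait()                 # ...then collect the batched results
    venv.close()
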
class VecEnvWrapper(VecEnv):
    """
    An environment wrapper that applies to an entire batch
    of environments at once.
    """

    def __init__(self, venv, observation_space=None, action_space=None):
        self.venv = venv
        VecEnv.__init__(self,
                        num_envs=venv.num_envs,
                        observation_space=observation_space or venv.observation_space,
                        action_space=action_space or venv.action_space)

    def step_async(self, actions):
        self.venv.step_async(actions)

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def step_wait(self):
        pass

    def close(self):
        return self.venv.close()

    def render(self, mode='human'):
        return self.venv.render(mode=mode)

    def get_images(self):
        return self.venv.get_images()


class VecEnvObservationWrapper(VecEnvWrapper):
    @abstractmethod
    def process(self, obs):
        pass

    def reset(self):
        obs = self.venv.reset()
        return self.process(obs)

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        return self.process(obs), rews, dones, infos

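To make the VecEnvObservationWrapper contract concrete, a small hypothetical subclass is sketched below (ScaleObs is not part of this PR); it rescales the batched observations coming out of both reset() and step_wait() via process():

    import numpy as np

    class ScaleObs(VecEnvObservationWrapper):
        """Hypothetical wrapper: divide every batched observation by a constant."""
        def __init__(self, venv, scale=255.0):
            super().__init__(venv)
            self.scale = scale

        def process(self, obs):
            # called on the batched obs from both reset() and step_wait()
            return np.asarray(obs, dtype=np.float32) / self.scale
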
class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)

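A quick sketch of the motivation (assumed usage, not from the diff, and it assumes cloudpickle is installed): plain pickle cannot serialize a lambda env thunk, but routing it through CloudpickleWrapper works, which is what the subprocess-based vec envs rely on when sending env constructors to worker processes:

    import pickle

    env_fn = lambda: 'pretend this builds a gym env'   # lambdas are not plain-picklable
    wrapped = CloudpickleWrapper(env_fn)
    payload = pickle.dumps(wrapped)                    # __getstate__ serializes the callable with cloudpickle
    restored = pickle.loads(payload)                   # __setstate__ restores the callable on the other side
    assert restored.x() == 'pretend this builds a gym env'
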
@contextlib.contextmanager
def clear_mpi_env_vars():
    """
    from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment
    variables, MPI will think that the child process is an MPI process just like the parent and do
    bad things such as hang.
    This context manager is a hacky way to clear those environment variables temporarily, such as
    when we are starting multiprocessing processes.
    """
    removed_environment = {}
    for k, v in list(os.environ.items()):
        for prefix in ['OMPI_', 'PMI_']:
            if k.startswith(prefix):
                removed_environment[k] = v
                del os.environ[k]
    try:
        yield
    finally:
        os.environ.update(removed_environment)
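An illustrative sketch of how this guard is meant to be used (assumed usage; the _worker function here is hypothetical): the OMPI_/PMI_ variables are stripped only while the workers are forked, then restored for the parent:

    from multiprocessing import Process

    def _worker():
        # child process: no MPI env vars inherited, so mpi4py will not try to join the parent's MPI job
        pass

    with clear_mpi_env_vars():
        procs = [Process(target=_worker) for _ in range(4)]
        for p in procs:
            p.start()
    for p in procs:
        p.join()   # the parent's OMPI_/PMI_ variables are restored once the with-block exits
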
@@ -1,4 +1,4 @@
-from . import VecEnvWrapper
+from .vec_env import VecEnvWrapper
import numpy as np
from gym import spaces

@@ -2,15 +2,23 @@ from . import VecEnvWrapper
from baselines.bench.monitor import ResultsWriter
import numpy as np
import time
from collections import deque


class VecMonitor(VecEnvWrapper):
-    def __init__(self, venv, filename=None):
+    def __init__(self, venv, filename=None, keep_buf=0):
        VecEnvWrapper.__init__(self, venv)
        self.eprets = None
        self.eplens = None
        self.epcount = 0
        self.tstart = time.time()
        if filename:
            self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
        else:
            self.results_writer = None
+        self.keep_buf = keep_buf
+        if self.keep_buf:
+            self.epret_buf = deque([], maxlen=keep_buf)
+            self.eplen_buf = deque([], maxlen=keep_buf)

    def reset(self):
        obs = self.venv.reset()

@@ -28,10 +36,14 @@ class VecMonitor(VecEnvWrapper):
            if done:
                epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
                info['episode'] = epinfo
+                if self.keep_buf:
+                    self.epret_buf.append(ret)
+                    self.eplen_buf.append(eplen)
                self.epcount += 1
                self.eprets[i] = 0
                self.eplens[i] = 0
                if self.results_writer:
                    self.results_writer.write_row(epinfo)

            newinfos.append(info)

        return obs, rews, dones, newinfos
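A short usage sketch of the new keep_buf option (assumed usage; the env construction is only illustrative): keep_buf > 0 keeps the last N episode returns and lengths in memory, so callers can report running means without re-reading the monitor file:

    import gym
    import numpy as np
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
    from baselines.common.vec_env.vec_monitor import VecMonitor

    venv = VecMonitor(DummyVecEnv([lambda: gym.make('CartPole-v0')]), keep_buf=100)
    obs = venv.reset()
    for _ in range(1000):
        obs, rews, dones, infos = venv.step(np.array([venv.action_space.sample()]))
    if len(venv.epret_buf):
        print('mean return over recent episodes:', np.mean(venv.epret_buf))
    venv.close()
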
baselines/common/vec_env/vec_remove_dict_obs.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from .vec_env import VecEnvObservationWrapper


class VecExtractDictObs(VecEnvObservationWrapper):
    def __init__(self, venv, key):
        self.key = key
        super().__init__(venv=venv,
                         observation_space=venv.observation_space.spaces[self.key])

    def process(self, obs):
        return obs[self.key]
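For context, a sketch of when VecExtractDictObs helps (assumed usage; dict_obs_venv is hypothetical): goal-based envs such as the Fetch tasks return a Dict observation, and this wrapper picks out a single array so downstream code that expects a Box space keeps working:

    # dict_obs_venv.observation_space is assumed to be something like
    # spaces.Dict({'observation': Box(...), 'achieved_goal': Box(...), 'desired_goal': Box(...)})
    venv = VecExtractDictObs(dict_obs_venv, key='observation')
    obs = venv.reset()   # now just the 'observation' array, batched over num_envs
    assert venv.observation_space == dict_obs_venv.observation_space.spaces['observation']
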
baselines/common/wrappers.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import gym

class TimeLimit(gym.Wrapper):
    def __init__(self, env, max_episode_steps=None):
        super(TimeLimit, self).__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
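A small sketch of the intended behaviour (assumed usage): the wrapper ends the episode after max_episode_steps and flags the cut-off via info['TimeLimit.truncated'], so an algorithm can tell a time-out apart from a true terminal state:

    import gym

    # unwrapped strips gym's own time limit so only this wrapper decides when to truncate
    env = TimeLimit(gym.make('Pendulum-v0').unwrapped, max_episode_steps=50)
    obs = env.reset()
    for t in range(200):
        obs, rew, done, info = env.step(env.action_space.sample())
        if done:
            print('episode ended at step', t, 'truncated:', info.get('TimeLimit.truncated', False))
            obs = env.reset()
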
@@ -1,7 +1,10 @@
-from baselines.run import main as M
+from multiprocessing import Process
+import baselines.run

def _run(argstr):
-    M(('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
+    p = Process(target=baselines.run.main, args=('--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr).split(' '))
+    p.start()
+    p.join()

def test_popart():
    _run('--normalize_returns=True --popart=True')

@@ -108,7 +108,7 @@ def learn(*, network, env, total_timesteps,

    # Prepare params.
    params = config.DEFAULT_PARAMS
-    env_name = env.specs[0].id
+    env_name = env.spec.id
    params['env_name'] = env_name
    params['replay_strategy'] = replay_strategy
    if env_name in config.DEFAULT_ENV_PARAMS:

@@ -7,6 +7,7 @@ import time
import datetime
import tempfile
from collections import defaultdict
+from contextlib import contextmanager

DEBUG = 10
INFO = 20

@@ -68,7 +69,8 @@ class HumanOutputFormat(KVWriter, SeqWriter):
        self.file.flush()

    def _truncate(self, s):
-        return s[:20] + '...' if len(s) > 23 else s
+        maxlen = 30
+        return s[:maxlen-3] + '...' if len(s) > maxlen else s

    def writeseq(self, seq):
        seq = list(seq)

@@ -195,13 +197,13 @@ def logkv(key, val):
    Call this once for each diagnostic quantity, each iteration
    If called many times, last value will be used.
    """
-    Logger.CURRENT.logkv(key, val)
+    get_current().logkv(key, val)

def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times, values averaged.
    """
-    Logger.CURRENT.logkv_mean(key, val)
+    get_current().logkv_mean(key, val)

def logkvs(d):
    """
@@ -213,21 +215,18 @@ def logkvs(d):
def dumpkvs():
    """
    Write all of the diagnostics from the current iteration
-
-    level: int. (see logger.py docs) If the global logger level is higher than
-                   the level argument here, don't print to stdout.
    """
-    Logger.CURRENT.dumpkvs()
+    return get_current().dumpkvs()

def getkvs():
-    return Logger.CURRENT.name2val
+    return get_current().name2val


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
-    Logger.CURRENT.log(*args, level=level)
+    get_current().log(*args, level=level)

def debug(*args):
    log(*args, level=DEBUG)

@@ -246,30 +245,29 @@ def set_level(level):
    """
    Set logging threshold on current logger.
    """
-    Logger.CURRENT.set_level(level)
+    get_current().set_level(level)

+def set_comm(comm):
+    get_current().set_comm(comm)

def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
-    return Logger.CURRENT.get_dir()
+    return get_current().get_dir()

record_tabular = logkv
dump_tabular = dumpkvs

-class ProfileKV:
-    """
-    Usage:
-    with logger.ProfileKV("interesting_scope"):
-        code
-    """
-    def __init__(self, n):
-        self.n = "wait_" + n
-    def __enter__(self):
-        self.t1 = time.time()
-    def __exit__(self, type, value, traceback):
-        Logger.CURRENT.name2val[self.n] += time.time() - self.t1
+@contextmanager
+def profile_kv(scopename):
+    logkey = 'wait_' + scopename
+    tstart = time.time()
+    try:
+        yield
+    finally:
+        get_current().name2val[logkey] += time.time() - tstart

def profile(n):
    """
@@ -279,7 +277,7 @@ def profile(n):
    """
    def decorator_with_name(func):
        def func_wrapper(*args, **kwargs):
-            with ProfileKV(n):
+            with profile_kv(n):
                return func(*args, **kwargs)
        return func_wrapper
    return decorator_with_name
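As a rough illustration of the module-level API these hunks route through get_current() (a usage sketch only, assuming the logger is configured to stdout; the values are made up):

    from baselines import logger

    logger.configure(format_strs=['stdout'])
    with logger.profile_kv('rollout'):
        pass                        # timed work; elapsed time accumulates under 'wait_rollout'
    logger.logkv('eprewmean', 1.5)
    logger.logkv_mean('loss', 0.3)  # repeated calls are averaged before the dump
    logger.logkv_mean('loss', 0.5)
    logger.dumpkvs()                # writes wait_rollout, eprewmean and loss=0.4, then clears the buffer
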
@@ -289,17 +287,25 @@ def profile(n):
# Backend
# ================================================================

+def get_current():
+    if Logger.CURRENT is None:
+        _configure_default_logger()
+
+    return Logger.CURRENT
+
+
class Logger(object):
    DEFAULT = None  # A logger with no output files. (See right below class definition)
                    # So that you can still log to the terminal without setting up any output files
    CURRENT = None  # Current logger being used by the free functions above

-    def __init__(self, dir, output_formats):
+    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
+        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
@@ -307,20 +313,27 @@
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        if val is None:
            self.name2val[key] = None
            return
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        if self.level == DISABLED: return
+        if self.comm is None:
+            d = self.name2val
+        else:
+            from baselines.common import mpi_util
+            d = mpi_util.mpi_weighted_mean(self.comm,
+                                           {name: (val, self.name2cnt.get(name, 1))
+                                            for (name, val) in self.name2val.items()})
+            if self.comm.rank != 0:
+                d['dummy'] = 1  # so we don't get a warning about empty dict
+        out = d.copy()  # Return the dict for unit testing purposes
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
-                fmt.writekvs(self.name2val)
+                fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
+        return out

    def log(self, *args, level=INFO):
        if self.level <= level:
@@ -331,6 +344,9 @@
    def set_level(self, level):
        self.level = level

+    def set_comm(self, comm):
+        self.comm = comm

    def get_dir(self):
        return self.dir

@@ -345,7 +361,10 @@
        if isinstance(fmt, SeqWriter):
            fmt.writeseq(map(str, args))

-def configure(dir=None, format_strs=None):
+def configure(dir=None, format_strs=None, comm=None):
+    """
+    If comm is provided, average all numerical stats across that comm
+    """
    if dir is None:
        dir = os.getenv('OPENAI_LOGDIR')
    if dir is None:

@@ -372,15 +391,11 @@
    format_strs = filter(None, format_strs)
    output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]

-    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
+    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    log('Logging to %s'%dir)

def _configure_default_logger():
-    format_strs = None
-    # keep the old default of only writing to stdout
-    if 'OPENAI_LOG_FORMAT' not in os.environ:
-        format_strs = ['stdout']
-    configure(format_strs=format_strs)
+    configure()
    Logger.DEFAULT = Logger.CURRENT

def reset():
@@ -389,17 +404,15 @@ def reset():
    Logger.CURRENT = Logger.DEFAULT
    log('Reset logger')

-class scoped_configure(object):
-    def __init__(self, dir=None, format_strs=None):
-        self.dir = dir
-        self.format_strs = format_strs
-        self.prevlogger = None
-    def __enter__(self):
-        self.prevlogger = Logger.CURRENT
-        configure(dir=self.dir, format_strs=self.format_strs)
-    def __exit__(self, *args):
+@contextmanager
+def scoped_configure(dir=None, format_strs=None, comm=None):
+    prevlogger = Logger.CURRENT
+    configure(dir=dir, format_strs=format_strs, comm=comm)
+    try:
+        yield
+    finally:
        Logger.CURRENT.close()
-        Logger.CURRENT = self.prevlogger
+        Logger.CURRENT = prevlogger
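A sketch of how the rewritten scoped_configure is meant to be used (assumed usage; the directory and formats are illustrative): it now behaves as a plain context manager and restores the previous logger even if the body raises:

    with logger.scoped_configure(dir='/tmp/experiment_logs', format_strs=['csv', 'stdout']):
        logger.logkv('step', 0)
        logger.dumpkvs()
    # outside the block, the previous Logger.CURRENT is active again
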
# ================================================================

@@ -423,7 +436,7 @@ def _demo():
    logkv_mean("b", -44.4)
    logkv("a", 5.5)
    dumpkvs()
-    info("^^^ should see b = 33.3")
+    info("^^^ should see b = -33.3")

    logkv("b", -2.5)
    dumpkvs()

@@ -456,7 +469,6 @@ def read_tb(path):
    import pandas
    import numpy as np
    from glob import glob
-    from collections import defaultdict
    import tensorflow as tf
    if osp.isdir(path):
        fnames = glob(osp.join(path, "events.*"))

@@ -482,8 +494,5 @@ def read_tb(path):
            data[step-1, colidx] = value
    return pandas.DataFrame(data, columns=tags)

-# configure the default logger on import
-_configure_default_logger()
-
if __name__ == "__main__":
    _demo()

@@ -97,7 +97,6 @@ def learn(env, policy_fn, *,
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[])  # learning rate multiplier, updated with schedule
-    clip_param = clip_param * lrmult  # Annealed clipping parameter epsilon

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

@@ -19,16 +19,17 @@ def train(num_timesteps, seed, model_path=None):
    # these are good enough to make humanoid walk, but whether those are
    # an absolute best or not is not certain
    env = RewScale(env, 0.1)
+    logger.log("NOTE: reward will be scaled by a factor of 10 in logged stats. Check the monitor for unscaled reward.")
    pi = pposgd_simple.learn(env, policy_fn,
            max_timesteps=num_timesteps,
            timesteps_per_actorbatch=2048,
-            clip_param=0.2, entcoeff=0.0,
+            clip_param=0.1, entcoeff=0.0,
            optim_epochs=10,
-            optim_stepsize=3e-4,
+            optim_stepsize=1e-4,
            optim_batchsize=64,
            gamma=0.99,
            lam=0.95,
-            schedule='linear',
+            schedule='constant',
        )
    env.close()
    if model_path:

@@ -47,7 +48,7 @@ def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
-    parser.set_defaults(num_timesteps=int(2e7))
+    parser.set_defaults(num_timesteps=int(5e7))

    args = parser.parse_args()

@@ -68,8 +69,5 @@ def main():
            if done:
                ob = env.reset()


if __name__ == '__main__':
    main()

@@ -18,7 +18,7 @@ def atari():
        lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
        ent_coef=.01,
        lr=lambda f: f * 2.5e-4,
-        cliprange=lambda f: f * 0.1,
+        cliprange=0.1,
    )

def retro():

@@ -6,15 +6,13 @@ from collections import defaultdict
import tensorflow as tf
import numpy as np

+from baselines.common.vec_env import VecFrameStack, VecNormalize
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
-from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session
from baselines import logger
from importlib import import_module

-from baselines.common.vec_env.vec_normalize import VecNormalize
-
try:
    from mpi4py import MPI
except ImportError:
@@ -52,7 +50,7 @@ _game_envs['retro'] = {


def train(args, extra_args):
-    env_type, env_id = get_env_type(args.env)
+    env_type, env_id = get_env_type(args)
    print('env_type: {}'.format(env_type))

    total_timesteps = int(args.num_timesteps)

@@ -64,7 +62,7 @@ def train(args, extra_args):
    env = build_env(args)
    if args.save_video_interval != 0:
-        env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
+        env = VecVideoRecorder(env, osp.join(logger.get_dir(), "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)

    if args.network:
        alg_kwargs['network'] = args.network

@@ -91,7 +89,7 @@ def build_env(args):
    alg = args.alg
    seed = args.seed

-    env_type, env_id = get_env_type(args.env)
+    env_type, env_id = get_env_type(args)

    if env_type in {'atari', 'retro'}:
        if alg == 'deepq':

@@ -119,7 +117,12 @@ def build_env(args):
    return env


-def get_env_type(env_id):
+def get_env_type(args):
+    env_id = args.env
+
+    if args.env_type is not None:
+        return args.env_type, env_id
+
    # Re-parse the gym registry, since we could have new envs since last time.
    for env in gym.envs.registry.all():
        env_type = env._entry_point.split(':')[0].split('.')[-1]

@@ -205,7 +208,6 @@ def main(args):
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
-    env.close()

    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)

@@ -213,23 +215,25 @@ def main(args):

    if args.play:
        logger.log("Running trained model")
-        env = build_env(args)
        obs = env.reset()

        state = model.initial_state if hasattr(model, 'initial_state') else None
        dones = np.zeros((1,))

+        episode_rew = 0
        while True:
            if state is not None:
                actions, _, state, _ = model.step(obs, S=state, M=dones)
            else:
                actions, _, _, _ = model.step(obs)

-            obs, _, done, _ = env.step(actions)
+            obs, rew, done, _ = env.step(actions)
+            episode_rew += rew[0]
            env.render()
            done = done.any() if isinstance(done, np.ndarray) else done

            if done:
+                print(f'episode_rew={episode_rew}')
+                episode_rew = 0
                obs = env.reset()

    env.close()