1.5 months of codegen changes (#196)
* play with resnet
* feed_dict version
* coinrun prob and more stats
* fixes to get_choices_specs & hp search
* minor prob fixes
* minor fixes
* minor
* alternative version of rl_algo stuff
* pylint fixes
* fix bugs, move node_filters to soup
* changed how get_algo works
* change how get_algo works, probably broke all tests
* continue previous refactor
* get eval_agent running again
* fixing tests
* fix tests
* fix more tests
* clean up cma stuff
* fix experiment
* minor changes to eval_agent to make ppo_metal use gpu
* make dict space work
* modify mac makefile to use conda
* recurrent layers
* play with bn and resnets
* minor hp changes
* minor
* got rid of use_fb argument and jtft (joint-train-fine-tune) functionality built test phase directly into AlgoProb
* make new rl algos generateable
* pylint; start fixing tests
* fixing tests
* more test fixes
* pylint
* fix search
* work on search
* hack around infinite loop caused by scan
* algo search fixes
* misc changes for search expt
* enable annealing, overriding options of Op
* pylint fixes
* identity op
* achieve use_last_output through masking so it automatically works in other distributions
* fix tests
* minor
* discrete
* use_last_output to be just a preference, not a hard constraint
* pred delay, pruning
* require nontrivial inputs
* aliases for get_sm
* add probname to probs
* fixes
* small fixes
* fix tests
* fix tests
* fix tests
* minor
* test scripts
* dualgru network improvements
* minor
* work on mysterious bugs
* rcall gpu-usage command for kube
* use cache dir that’s not in code folder, so that it doesn’t get removed by rcall code rsync
* add power mode to gpu usage
* make sure train/test actually different
* remove VR for now
* minor fixes
* simplify soln_db
* minor
* big refactor of mpi eda
* improve mpieda for multitask
* - get rid of timelimit hack
  - add __del__ to cleanup SubprocVecEnv
* get multitask working better
* fixes
* working on atari, various
* annotate ops with whether they’re parametrized
* minor
* gym version
* rand atari prob
* minor
* SolnDb bugfix and name change
* pyspy script
* switch conv layers
* fix roboschool/bullet3
* nenvs assertion
* fix rand atari
* get rid of blanket exception catching fix soln_db bug
* fix rand_atari
* dynamic routing as cmdline arg
* slight modifications to test_mpi_map and pyspy-all
* max_tries argument for run_until_successs
* dedup option in train_mle
* simplify soln_db
* increase atari horizon for 1 experiment
* start implementing reward increment
* ent multiplier
* create cc dsl other misc fixes
* cc ops
* q_func -> qs in rl_algos_cc.py
* fix PredictDistr
* rl_ops_cc fixes, MakeAction op
* augment algo agent to support cc stuff
* work on ddpg experiments
* fix blocking temporarily change logger
* allow layer scaling
* pylint fixes
* spawn_method
* isolate ddpg hacks
* improve pruning
* use spawn for subproc
* remove use of python -c in rcall
* fix pylint warning
* fix static
* maybe fix local backend
* switch to DummyVecEnv
* making some fixes via pylint
* pylint fixes
* fixing tests
* fix tests
* fix tests
* write scaffolding for SSL in Codegen
* logger fix
* fix error
* add EMA op to sl_ops
* save many changes
* save
* add upsampler
* add sl ops, enhance state machine
* get ssl search working -- some gross hacking
* fix session/graph issue
* fix importing
* work on mle
* - scale embeddings in gru model
  - better exception handling in sl_prob
  - use emas for test/val
  - use non-contrib batch_norm layer
* improve logging
* option to average before dumping in logger
* default arguments, etc
* new ddpg and identity test
* concat fix
* minor
* move realistic ssl stuff to third-party (underscore to dash)
* fixes
* remove realistic_ssl_evaluation
* pylint fixes
* use gym master
* try again
* pass around args without gin
* fix tests
* separate line to install gym
* rename failing tests that should be ignored
* add data aug
* ssl improvements
* use fixed time limit
* try to fix baselines tests
* add score_floor, max_walltime, fiddle with lr decay
* realistic_ssl
* autopep8
* various ssl
  - enable blocking grad for simplification
  - kl
  - multiple final prediction
* fix pruning
* misc ssl stuff
* bring back linear schedule, don’t use allgather for collecting stats (i’ve been getting nondeterministic errors from the old code)
* save/load weights in SSL, big stepsize
* cleanup SslProb
* fix
* get rid of kl coef
* fix simplification, lower lr
* search over hps
* minor fixes
* minor
* static analysis
* move files and rename things for improved consistency. still broken, and just saving before making nontrivial changes
* various
* make tests pass
* move coinrun_train to codegen since it depends on codegen
* fixes
* pylint fixes
* improve tests fix some things
* improve tests
* lint
* fix up db_info.py, tests
* mostly restore master version of envs directory, except for makefile changes
* fix tests
* improve printing
* minor fixes
* fix fixmes
* pruning test
* fixes
* lint
* write new test that makes tf graphs of random algos; fix some bugs it caught
* add --delete flag to rcall upload-code command
* lint
* get cifar10 lazily for testing purposes
* disable codegen ci tests for now
* clean up rl_ops
* rename spec classes
* td3 with identity test
* identity tests without gin files
* remove gin.configurable from AlgoAgent
* comments about reduction in rl_ops_cc
* address @pzhokhov comments
* fix tests
* more linting
* better tests
* clean up filtering a bit
* fix concat
committed by Peter Zhokhov
parent 8fe79aa76d
commit 370ee27750
@@ -221,11 +221,8 @@ class LazyFrames(object):
|
|||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
return self._force()[i]
|
return self._force()[i]
|
||||||
|
|
||||||
def make_atari(env_id, timelimit=True):
|
def make_atari(env_id):
|
||||||
# XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
|
|
||||||
env = gym.make(env_id)
|
env = gym.make(env_id)
|
||||||
if not timelimit:
|
|
||||||
env = env.env
|
|
||||||
assert 'NoFrameskip' in env.spec.id
|
assert 'NoFrameskip' in env.spec.id
|
||||||
env = NoopResetEnv(env, noop_max=30)
|
env = NoopResetEnv(env, noop_max=30)
|
||||||
env = MaxAndSkipEnv(env, skip=4)
|
env = MaxAndSkipEnv(env, skip=4)
|
||||||
|
@@ -205,7 +205,8 @@ class CategoricalPd(Pd):
|
|||||||
class MultiCategoricalPd(Pd):
|
class MultiCategoricalPd(Pd):
|
||||||
def __init__(self, nvec, flat):
|
def __init__(self, nvec, flat):
|
||||||
self.flat = flat
|
self.flat = flat
|
||||||
self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
|
self.categoricals = list(map(CategoricalPd,
|
||||||
|
tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
|
||||||
def flatparam(self):
|
def flatparam(self):
|
||||||
return self.flat
|
return self.flat
|
||||||
def mode(self):
|
def mode(self):
|
||||||
|
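For readers following the `MultiCategoricalPd` change, here is a minimal sketch of what splitting a flat parameter vector by `nvec` means. It uses plain Python lists rather than `tf.split`, and `split_by_sizes` and the sample values are hypothetical names chosen for illustration, not part of the commit.

```python
def split_by_sizes(flat, nvec):
    """Cut a flat list into consecutive chunks whose lengths are given by nvec."""
    chunks, start = [], 0
    for n in nvec:
        chunks.append(flat[start:start + n])
        start += n
    # Every element of flat must be assigned to exactly one chunk.
    assert start == len(flat), "sizes in nvec must sum to len(flat)"
    return chunks

# Two categorical variables with 2 and 3 choices share one flat logit vector.
flat_logits = [0.1, 0.9, 0.2, 0.3, 0.5]
chunks = split_by_sizes(flat_logits, [2, 3])
print(chunks)
```

Each chunk would then parameterize one categorical distribution, which is the role `CategoricalPd` plays in the diff above.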
@@ -4,6 +4,7 @@ import os, numpy as np
|
|||||||
import platform
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import warnings
|
||||||
|
|
||||||
def sync_from_root(sess, variables, comm=None):
|
def sync_from_root(sess, variables, comm=None):
|
||||||
"""
|
"""
|
||||||
@@ -81,6 +82,9 @@ def share_file(comm, path):
|
|||||||
comm.Barrier()
|
comm.Barrier()
|
||||||
|
|
||||||
def dict_gather(comm, d, op='mean', assert_all_have_data=True):
|
def dict_gather(comm, d, op='mean', assert_all_have_data=True):
|
||||||
|
"""
|
||||||
|
Perform a reduction operation over dicts
|
||||||
|
"""
|
||||||
if comm is None: return d
|
if comm is None: return d
|
||||||
alldicts = comm.allgather(d)
|
alldicts = comm.allgather(d)
|
||||||
size = comm.size
|
size = comm.size
|
||||||
@@ -99,3 +103,27 @@ def dict_gather(comm, d, op='mean', assert_all_have_data=True):
|
|||||||
else:
|
else:
|
||||||
assert 0, op
|
assert 0, op
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def mpi_weighted_mean(comm, local_name2valcount):
|
||||||
|
"""
|
||||||
|
Perform a weighted average over dicts that are each on a different node
|
||||||
|
Input: local_name2valcount: dict mapping key -> (value, count)
|
||||||
|
Returns: key -> mean
|
||||||
|
"""
|
||||||
|
all_name2valcount = comm.gather(local_name2valcount)
|
||||||
|
if comm.rank == 0:
|
||||||
|
name2sum = defaultdict(float)
|
||||||
|
name2count = defaultdict(float)
|
||||||
|
for n2vc in all_name2valcount:
|
||||||
|
for (name, (val, count)) in n2vc.items():
|
||||||
|
try:
|
||||||
|
val = float(val)
|
||||||
|
except ValueError:
|
||||||
|
if comm.rank == 0:
|
||||||
|
warnings.warn(f'WARNING: tried to compute mean on non-float {name}={val}')
|
||||||
|
else:
|
||||||
|
name2sum[name] += val * count
|
||||||
|
name2count[name] += count
|
||||||
|
return {name : name2sum[name] / name2count[name] for name in name2sum}
|
||||||
|
else:
|
||||||
|
return {}
|
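The reduction performed on rank 0 by `mpi_weighted_mean` can be sketched without MPI. In the sketch below, a list of dicts stands in for what `comm.gather` would collect; `weighted_mean` is a hypothetical helper name, and the non-float warning path from the diff is left out for brevity.

```python
from collections import defaultdict

def weighted_mean(all_name2valcount):
    """Combine per-worker {name: (value, count)} dicts into {name: mean}."""
    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for name, (val, count) in n2vc.items():
            # Each worker reports a mean and how many samples produced it.
            name2sum[name] += float(val) * count
            name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}

# Stand-ins for the dicts held by two ranks (same values as the unit test below).
rank0 = {'a': (10, 2), 'b': (20, 3)}
rank1 = {'a': (19, 1), 'c': (42, 3)}
means = weighted_mean([rank0, rank1])
print(means)
```

Keys that appear on only one worker, such as `'b'` and `'c'` here, simply keep that worker's value, while `'a'` is averaged with weights 2 and 1.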
31
baselines/common/test_mpi_util.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from baselines.common import mpi_util
|
||||||
|
from mpi4py import MPI
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from baselines import logger
|
||||||
|
|
||||||
|
def helper_for_mpi_weighted_mean():
|
||||||
|
comm = MPI.COMM_WORLD
|
||||||
|
if comm.rank == 0:
|
||||||
|
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
|
||||||
|
elif comm.rank == 1:
|
||||||
|
name2valcount = {'a' : (19, 1), 'c' : (42,3)}
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
d = mpi_util.mpi_weighted_mean(comm, name2valcount)
|
||||||
|
correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
|
||||||
|
if comm.rank == 0:
|
||||||
|
assert d == correctval, f'{d} != {correctval}'
|
||||||
|
|
||||||
|
for name, (val, count) in name2valcount.items():
|
||||||
|
for _ in range(count):
|
||||||
|
logger.logkv_mean(name, val)
|
||||||
|
d2 = logger.dumpkvs(mpi_mean=True)
|
||||||
|
if comm.rank == 0:
|
||||||
|
assert d2 == correctval
|
||||||
|
|
||||||
|
|
||||||
|
def test_mpi_weighted_mean():
|
||||||
|
subprocess.check_call(['mpirun', '-n', '2', sys.executable, '-c',
|
||||||
|
'from baselines.common import test_mpi_util; test_mpi_util.helper_for_mpi_weighted_mean()'])
|
@@ -26,7 +26,6 @@ class IdentityEnv(Env):
|
|||||||
self._choose_next_state()
|
self._choose_next_state()
|
||||||
done = False
|
done = False
|
||||||
if self.episode_len and self.time >= self.episode_len:
|
if self.episode_len and self.time >= self.episode_len:
|
||||||
rew = 0
|
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
return self.state, rew, done, {}
|
return self.state, rew, done, {}
|
||||||
@@ -74,7 +73,7 @@ class BoxIdentityEnv(IdentityEnv):
|
|||||||
episode_len=None,
|
episode_len=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.action_space = Box(low=-1.0, high=1.0, shape=shape)
|
self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
|
||||||
super().__init__(episode_len=episode_len)
|
super().__init__(episode_len=episode_len)
|
||||||
|
|
||||||
def _get_reward(self, actions):
|
def _get_reward(self, actions):
|
||||||
|
@@ -168,6 +168,19 @@ class VecEnvWrapper(VecEnv):
|
|||||||
def get_images(self):
|
def get_images(self):
|
||||||
return self.venv.get_images()
|
return self.venv.get_images()
|
||||||
|
|
||||||
|
class VecEnvObservationWrapper(VecEnvWrapper):
|
||||||
|
@abstractmethod
|
||||||
|
def process(self, obs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
obs = self.venv.reset()
|
||||||
|
return self.process(obs)
|
||||||
|
|
||||||
|
def step_wait(self):
|
||||||
|
obs, rews, dones, infos = self.venv.step_wait()
|
||||||
|
return self.process(obs), rews, dones, infos
|
||||||
|
|
||||||
class CloudpickleWrapper(object):
|
class CloudpickleWrapper(object):
|
||||||
"""
|
"""
|
||||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||||
|
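A rough sketch of how a subclass might use the new `VecEnvObservationWrapper` pattern: only `process` is overridden, and `reset` and `step_wait` route observations through it. To stay self-contained, the classes below are simplified stand-ins, not the baselines classes themselves; `FakeVecEnv`, `ObservationWrapper`, and `ScaleObs` are hypothetical names.

```python
from abc import ABC, abstractmethod

class FakeVecEnv:
    """Hypothetical stand-in for a vectorized env with two sub-environments."""
    def reset(self):
        return [[0, 255], [128, 64]]

    def step_wait(self):
        return [[255, 0], [64, 128]], [1.0, 0.0], [False, True], [{}, {}]

class ObservationWrapper(ABC):
    """Simplified analogue of VecEnvObservationWrapper from the diff."""
    def __init__(self, venv):
        self.venv = venv

    @abstractmethod
    def process(self, obs):
        """Transform a batch of observations."""

    def reset(self):
        return self.process(self.venv.reset())

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        # Only observations are transformed; rewards, dones, infos pass through.
        return self.process(obs), rews, dones, infos

class ScaleObs(ObservationWrapper):
    """Example subclass: rescale byte-valued observations to [0, 1]."""
    def process(self, obs):
        return [[x / 255.0 for x in row] for row in obs]

env = ScaleObs(FakeVecEnv())
first_obs = env.reset()
obs, rews, dones, infos = env.step_wait()
print(first_obs[0], obs[0], rews, dones)
```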
@@ -27,7 +27,7 @@ class DummyVecEnv(VecEnv):
|
|||||||
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
||||||
self.buf_infos = [{} for _ in range(self.num_envs)]
|
self.buf_infos = [{} for _ in range(self.num_envs)]
|
||||||
self.actions = None
|
self.actions = None
|
||||||
self.specs = [e.spec for e in self.envs]
|
self.spec = self.envs[0].spec
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
listify = True
|
listify = True
|
||||||
|
@@ -54,7 +54,6 @@ class ShmemVecEnv(VecEnv):
|
|||||||
proc.start()
|
proc.start()
|
||||||
child_pipe.close()
|
child_pipe.close()
|
||||||
self.waiting_step = False
|
self.waiting_step = False
|
||||||
self.specs = [f().spec for f in env_fns]
|
|
||||||
self.viewer = None
|
self.viewer = None
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
@@ -1,7 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from multiprocessing import Process, Pipe
|
import multiprocessing as mp
|
||||||
from . import VecEnv, CloudpickleWrapper
|
from . import VecEnv, CloudpickleWrapper
|
||||||
|
|
||||||
|
ctx = mp.get_context('spawn')
|
||||||
|
|
||||||
def worker(remote, parent_remote, env_fn_wrapper):
|
def worker(remote, parent_remote, env_fn_wrapper):
|
||||||
parent_remote.close()
|
parent_remote.close()
|
||||||
env = env_fn_wrapper.x()
|
env = env_fn_wrapper.x()
|
||||||
@@ -21,8 +23,8 @@ def worker(remote, parent_remote, env_fn_wrapper):
|
|||||||
elif cmd == 'close':
|
elif cmd == 'close':
|
||||||
remote.close()
|
remote.close()
|
||||||
break
|
break
|
||||||
elif cmd == 'get_spaces':
|
elif cmd == 'get_spaces_spec':
|
||||||
remote.send((env.observation_space, env.action_space))
|
remote.send((env.observation_space, env.action_space, env.spec))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -45,8 +47,8 @@ class SubprocVecEnv(VecEnv):
|
|||||||
self.waiting = False
|
self.waiting = False
|
||||||
self.closed = False
|
self.closed = False
|
||||||
nenvs = len(env_fns)
|
nenvs = len(env_fns)
|
||||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
|
||||||
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||||
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
|
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
|
||||||
for p in self.ps:
|
for p in self.ps:
|
||||||
p.daemon = True # if the main process crashes, we should not cause things to hang
|
p.daemon = True # if the main process crashes, we should not cause things to hang
|
||||||
@@ -54,10 +56,9 @@ class SubprocVecEnv(VecEnv):
|
|||||||
for remote in self.work_remotes:
|
for remote in self.work_remotes:
|
||||||
remote.close()
|
remote.close()
|
||||||
|
|
||||||
self.remotes[0].send(('get_spaces', None))
|
self.remotes[0].send(('get_spaces_spec', None))
|
||||||
observation_space, action_space = self.remotes[0].recv()
|
observation_space, action_space, self.spec = self.remotes[0].recv()
|
||||||
self.viewer = None
|
self.viewer = None
|
||||||
self.specs = [f().spec for f in env_fns]
|
|
||||||
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
@@ -99,9 +100,12 @@ class SubprocVecEnv(VecEnv):
|
|||||||
def _assert_not_closed(self):
|
def _assert_not_closed(self):
|
||||||
assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
|
assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if not self.closed:
|
||||||
|
self.close()
|
||||||
|
|
||||||
def _flatten_obs(obs):
|
def _flatten_obs(obs):
|
||||||
assert isinstance(obs, list) or isinstance(obs, tuple)
|
assert isinstance(obs, (list, tuple))
|
||||||
assert len(obs) > 0
|
assert len(obs) > 0
|
||||||
|
|
||||||
if isinstance(obs[0], dict):
|
if isinstance(obs[0], dict):
|
||||||
@@ -111,4 +115,3 @@ def _flatten_obs(obs):
|
|||||||
return {k: np.stack([o[k] for o in obs]) for k in keys}
|
return {k: np.stack([o[k] for o in obs]) for k in keys}
|
||||||
else:
|
else:
|
||||||
return np.stack(obs)
|
return np.stack(obs)
|
||||||
|
|
||||||
|
@@ -108,7 +108,7 @@ def learn(*, network, env, total_timesteps,
|
|||||||
|
|
||||||
# Prepare params.
|
# Prepare params.
|
||||||
params = config.DEFAULT_PARAMS
|
params = config.DEFAULT_PARAMS
|
||||||
env_name = env.specs[0].id
|
env_name = env.spec.id
|
||||||
params['env_name'] = env_name
|
params['env_name'] = env_name
|
||||||
params['replay_strategy'] = replay_strategy
|
params['replay_strategy'] = replay_strategy
|
||||||
if env_name in config.DEFAULT_ENV_PARAMS:
|
if env_name in config.DEFAULT_ENV_PARAMS:
|
||||||
|
@@ -68,7 +68,8 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||||||
self.file.flush()
|
self.file.flush()
|
||||||
|
|
||||||
def _truncate(self, s):
|
def _truncate(self, s):
|
||||||
return s[:20] + '...' if len(s) > 23 else s
|
maxlen = 30
|
||||||
|
return s[:maxlen-3] + '...' if len(s) > maxlen else s
|
||||||
|
|
||||||
def writeseq(self, seq):
|
def writeseq(self, seq):
|
||||||
seq = list(seq)
|
seq = list(seq)
|
||||||
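The new `_truncate` rule can be checked in isolation. The sketch below copies the expression from the right-hand side of the diff into a free function; exposing `maxlen` as a parameter is an assumption made for illustration, since the commit hard-codes it to 30.

```python
def truncate(s, maxlen=30):
    """Shorten s to at most maxlen characters, marking the cut with '...'."""
    return s[:maxlen - 3] + '...' if len(s) > maxlen else s

short = truncate('a' * 30)   # exactly maxlen characters: returned unchanged
long = truncate('b' * 31)    # one over the limit: cut to 27 characters + '...'
print(len(short), len(long))
```

With this form the truncated string is never longer than `maxlen`, so the column width and the cut point are controlled by a single number.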
@@ -210,14 +211,17 @@ def logkvs(d):
|
|||||||
for (k, v) in d.items():
|
for (k, v) in d.items():
|
||||||
logkv(k, v)
|
logkv(k, v)
|
||||||
|
|
||||||
def dumpkvs():
|
def dumpkvs(mpi_mean=False):
|
||||||
"""
|
"""
|
||||||
Write all of the diagnostics from the current iteration
|
Write all of the diagnostics from the current iteration
|
||||||
|
|
||||||
level: int. (see logger.py docs) If the global logger level is higher than
|
mpi_mean: whether to average across MPI workers. mpi_mean=False just
|
||||||
the level argument here, don't print to stdout.
|
has each worker write its own stats (and under default settings
|
||||||
|
non-root workers don't write anything), whereas mpi_mean=True has
|
||||||
|
the root worker collect all of the stats and write the average,
|
||||||
|
and no one else writes anything.
|
||||||
"""
|
"""
|
||||||
Logger.CURRENT.dumpkvs()
|
return Logger.CURRENT.dumpkvs(mpi_mean=mpi_mean)
|
||||||
|
|
||||||
def getkvs():
|
def getkvs():
|
||||||
return Logger.CURRENT.name2val
|
return Logger.CURRENT.name2val
|
||||||
@@ -307,20 +311,30 @@ class Logger(object):
|
|||||||
self.name2val[key] = val
|
self.name2val[key] = val
|
||||||
|
|
||||||
def logkv_mean(self, key, val):
|
def logkv_mean(self, key, val):
|
||||||
if val is None:
|
|
||||||
self.name2val[key] = None
|
|
||||||
return
|
|
||||||
oldval, cnt = self.name2val[key], self.name2cnt[key]
|
oldval, cnt = self.name2val[key], self.name2cnt[key]
|
||||||
self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
|
self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
|
||||||
self.name2cnt[key] = cnt + 1
|
self.name2cnt[key] = cnt + 1
|
||||||
|
|
||||||
def dumpkvs(self):
|
def dumpkvs(self, mpi_mean=False):
|
||||||
if self.level == DISABLED: return
|
if self.level == DISABLED: return
|
||||||
|
if mpi_mean:
|
||||||
|
from baselines.common import mpi_util
|
||||||
|
from mpi4py import MPI
|
||||||
|
comm = MPI.COMM_WORLD
|
||||||
|
d = mpi_util.mpi_weighted_mean(comm,
|
||||||
|
{name : (val, self.name2cnt.get(name, 1))
|
||||||
|
for (name, val) in self.name2val.items()})
|
||||||
|
if comm.rank != 0:
|
||||||
|
d['dummy'] = 1 # so we don't get a warning about empty dict
|
||||||
|
else:
|
||||||
|
d = self.name2val
|
||||||
|
out = d.copy() # Return the dict for unit testing purposes
|
||||||
for fmt in self.output_formats:
|
for fmt in self.output_formats:
|
||||||
if isinstance(fmt, KVWriter):
|
if isinstance(fmt, KVWriter):
|
||||||
fmt.writekvs(self.name2val)
|
fmt.writekvs(d)
|
||||||
self.name2val.clear()
|
self.name2val.clear()
|
||||||
self.name2cnt.clear()
|
self.name2cnt.clear()
|
||||||
|
return out
|
||||||
|
|
||||||
def log(self, *args, level=INFO):
|
def log(self, *args, level=INFO):
|
||||||
if self.level <= level:
|
if self.level <= level:
|
||||||
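The incremental update in `logkv_mean` keeps a running mean and a count per key, and that (mean, count) pair is what `dumpkvs(mpi_mean=True)` passes to `mpi_weighted_mean`. A small worked sketch follows, assuming the value and the count both start at zero; `update_mean` is a hypothetical helper name.

```python
def update_mean(oldval, cnt, val):
    """One step of the running-mean update used in logkv_mean."""
    newval = oldval * cnt / (cnt + 1) + val / (cnt + 1)
    return newval, cnt + 1

samples = [10, 20, 60]
mean, cnt = 0.0, 0
for v in samples:
    mean, cnt = update_mean(mean, cnt, v)

# The running mean should agree with the ordinary batch mean.
batch_mean = sum(samples) / len(samples)
print(mean, cnt, batch_mean)
```

Because each worker carries its own count, the root can weight worker means by sample count instead of averaging them equally.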
@@ -456,7 +470,6 @@ def read_tb(path):
|
|||||||
import pandas
|
import pandas
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from collections import defaultdict
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
if osp.isdir(path):
|
if osp.isdir(path):
|
||||||
fnames = glob(osp.join(path, "events.*"))
|
fnames = glob(osp.join(path, "events.*"))
|
||||||
@@ -485,5 +498,7 @@ def read_tb(path):
|
|||||||
# configure the default logger on import
|
# configure the default logger on import
|
||||||
_configure_default_logger()
|
_configure_default_logger()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_demo()
|
_demo()
|
||||||
|