1.5 months of codegen changes (#196)

* play with resnet

* feed_dict version

* coinrun prob and more stats

* fixes to get_choices_specs & hp search

* minor prob fixes

* minor fixes

* minor

* alternative version of rl_algo stuff

* pylint fixes

* fix bugs, move node_filters to soup

* changed how get_algo works

* change how get_algo works, probably broke all tests

* continue previous refactor

* get eval_agent running again

* fixing tests

* fix tests

* fix more tests

* clean up cma stuff

* fix experiment

* minor changes to eval_agent to make ppo_metal use gpu

* make dict space work

* modify mac makefile to use conda

* recurrent layers

* play with bn and resnets

* minor hp changes

* minor

* got rid of use_fb argument and jtft (joint-train-fine-tune) functionality
built test phase directly into AlgoProb

* make new rl algos generateable

* pylint; start fixing tests

* fixing tests

* more test fixes

* pylint

* fix search

* work on search

* hack around infinite loop caused by scan

* algo search fixes

* misc changes for search expt

* enable annealing, overriding options of Op

* pylint fixes

* identity op

* achieve use_last_output through masking so it automatically works in other distributions

* fix tests

* minor

* discrete

* use_last_output to be just a preference, not a hard constraint

* pred delay, pruning

* require nontrivial inputs

* aliases for get_sm

* add probname to probs

* fixes

* small fixes

* fix tests

* fix tests

* fix tests

* minor

* test scripts

* dualgru network improvements

* minor

* work on mysterious bugs

* rcall gpu-usage command for kube

* use cache dir that’s not in code folder, so that it doesn’t get removed by rcall code rsync

* add power mode to gpu usage

* make sure train/test actually different

* remove VR for now

* minor fixes

* simplify soln_db

* minor

* big refactor of mpi eda

* improve mpieda for multitask

* - get rid of timelimit hack
- add __del__ to cleanup SubprocVecEnv

* get multitask working better

* fixes

* working on atari, various

* annotate ops with whether they’re parametrized

* minor

* gym version

* rand atari prob

* minor

* SolnDb bugfix and name change

* pyspy script

* switch conv layers

* fix roboschool/bullet3

* nenvs assertion

* fix rand atari

* get rid of blanket exception catching
fix soln_db bug

* fix rand_atari

* dynamic routing as cmdline arg

* slight modifications to test_mpi_map and pyspy-all

* max_tries argument for run_until_successs

* dedup option in train_mle

* simplify soln_db

* increase atari horizon for 1 experiment

* start implementing reward increment

* ent multiplier

* create cc dsl
other misc fixes

* cc ops

* q_func -> qs in rl_algos_cc.py

* fix PredictDistr

* rl_ops_cc fixes, MakeAction op

* augment algo agent to support cc stuff

* work on ddpg experiments

* fix blocking
temporarily change logger

* allow layer scaling

* pylint fixes

* spawn_method

* isolate ddpg hacks

* improve pruning

* use spawn for subproc

* remove use of python -c in rcall

* fix pylint warning

* fix static

* maybe fix local backend

* switch to DummyVecEnv

* making some fixes via pylint

* pylint fixes

* fixing tests

* fix tests

* fix tests

* write scaffolding for SSL in Codegen

* logger fix

* fix error

* add EMA op to sl_ops

* save many changes

* save

* add upsampler

* add sl ops, enhance state machine

* get ssl search working — some gross hacking

* fix session/graph issue

* fix importing

* work on mle

* - scale embeddings in gru model
- better exception handling in sl_prob
- use emas for test/val
- use non-contrib batch_norm layer

* improve logging

* option to average before dumping in logger

* default arguments, etc

* new ddpg and identity test

* concat fix

* minor

* move realistic ssl stuff to third-party (underscore to dash)

* fixes

* remove realistic_ssl_evaluation

* pylint fixes

* use gym master

* try again

* pass around args without gin

* fix tests

* separate line to install gym

* rename failing tests that should be ignored

* add data aug

* ssl improvements

* use fixed time limit

* try to fix baselines tests

* add score_floor, max_walltime, fiddle with lr decay

* realistic_ssl

* autopep8

* various ssl
- enable blocking grad for simplification
- kl
- multiple final prediction

* fix pruning

* misc ssl stuff

* bring back linear schedule, don’t use allgather for collecting stats
(i’ve been getting nondeterministic errors from the old code)

* save/load weights in SSL, big stepsize

* cleanup SslProb

* fix

* get rid of kl coef

* fix simplification, lower lr

* search over hps

* minor fixes

* minor

* static analysis

* move files and rename things for improved consistency.
still broken, and just saving before making nontrivial changes

* various

* make tests pass

* move coinrun_train to codegen since it depends on codegen

* fixes

* pylint fixes

* improve tests
fix some things

* improve tests

* lint

* fix up db_info.py, tests

* mostly restore master version of envs directory, except for makefile changes

* fix tests

* improve printing

* minor fixes

* fix fixmes

* pruning test

* fixes

* lint

* write new test that makes tf graphs of random algos; fix some bugs it caught

* add --delete flag to rcall upload-code command

* lint

* get cifar10 lazily for testing purposes

* disable codegen ci tests for now

* clean up rl_ops

* rename spec classes

* td3 with identity test

* identity tests without gin files

* remove gin.configurable from AlgoAgent

* comments about reduction in rl_ops_cc

* address @pzhokhov comments

* fix tests

* more linting

* better tests

* clean up filtering a bit

* fix concat

Author: John Schulman
Date: 2019-01-03 13:23:18 -08:00
Committed by: Peter Zhokhov
Parent: 8fe79aa76d
Commit: 370ee27750
11 changed files with 117 additions and 31 deletions

View File

@@ -221,11 +221,8 @@ class LazyFrames(object):
     def __getitem__(self, i):
         return self._force()[i]

-def make_atari(env_id, timelimit=True):
-    # XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
+def make_atari(env_id):
     env = gym.make(env_id)
-    if not timelimit:
-        env = env.env
     assert 'NoFrameskip' in env.spec.id
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)

View File

@@ -205,7 +205,8 @@ class CategoricalPd(Pd):
 class MultiCategoricalPd(Pd):
     def __init__(self, nvec, flat):
         self.flat = flat
-        self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
+        self.categoricals = list(map(CategoricalPd,
+            tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
     def flatparam(self):
         return self.flat
     def mode(self):

View File

@@ -4,6 +4,7 @@ import os, numpy as np
 import platform
 import shutil
 import subprocess
+import warnings

 def sync_from_root(sess, variables, comm=None):
     """
@@ -81,6 +82,9 @@ def share_file(comm, path):
     comm.Barrier()

 def dict_gather(comm, d, op='mean', assert_all_have_data=True):
+    """
+    Perform a reduction operation over dicts
+    """
     if comm is None: return d
     alldicts = comm.allgather(d)
     size = comm.size
@@ -99,3 +103,27 @@ def dict_gather(comm, d, op='mean', assert_all_have_data=True):
         else:
             assert 0, op
     return result
+
+def mpi_weighted_mean(comm, local_name2valcount):
+    """
+    Perform a weighted average over dicts that are each on a different node
+    Input: local_name2valcount: dict mapping key -> (value, count)
+    Returns: key -> mean
+    """
+    all_name2valcount = comm.gather(local_name2valcount)
+    if comm.rank == 0:
+        name2sum = defaultdict(float)
+        name2count = defaultdict(float)
+        for n2vc in all_name2valcount:
+            for (name, (val, count)) in n2vc.items():
+                try:
+                    val = float(val)
+                except ValueError:
+                    if comm.rank == 0:
+                        warnings.warn(f'WARNING: tried to compute mean on non-float {name}={val}')
+                else:
+                    name2sum[name] += val * count
+                    name2count[name] += count
+        return {name : name2sum[name] / name2count[name] for name in name2sum}
+    else:
+        return {}
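
The reduction that `mpi_weighted_mean` performs on rank 0 can be checked without MPI. The following is a minimal sketch, assuming the per-rank dicts have already been gathered into a list; `weighted_mean_of_dicts` is a hypothetical helper name and not part of baselines, and the non-float warning path from the diff is left out.

```python
from collections import defaultdict


def weighted_mean_of_dicts(all_name2valcount):
    """Combine per-worker {name: (value, count)} dicts into {name: mean}.

    Hypothetical stand-in for the rank-0 branch of mpi_weighted_mean;
    the list argument plays the role of the result of comm.gather(...).
    """
    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for name, (val, count) in n2vc.items():
            # Each worker reports a local mean and how many samples it covers.
            name2sum[name] += float(val) * count
            name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}


# Same inputs as the two ranks in helper_for_mpi_weighted_mean.
gathered = [
    {'a': (10, 2), 'b': (20, 3)},
    {'a': (19, 1), 'c': (42, 3)},
]
result = weighted_mean_of_dicts(gathered)
```

Weighting by count matters for key `'a'`: the two workers report 10 (two samples) and 19 (one sample), so the combined mean is (10 * 2 + 19) / 3 rather than the unweighted average of 10 and 19.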

View File

@@ -0,0 +1,31 @@
+from baselines.common import mpi_util
+from mpi4py import MPI
+import subprocess
+import sys
+from baselines import logger
+
+def helper_for_mpi_weighted_mean():
+    comm = MPI.COMM_WORLD
+    if comm.rank == 0:
+        name2valcount = {'a' : (10, 2), 'b' : (20,3)}
+    elif comm.rank == 1:
+        name2valcount = {'a' : (19, 1), 'c' : (42,3)}
+    else:
+        raise NotImplementedError
+
+    d = mpi_util.mpi_weighted_mean(comm, name2valcount)
+    correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
+    if comm.rank == 0:
+        assert d == correctval, f'{d} != {correctval}'
+
+    for name, (val, count) in name2valcount.items():
+        for _ in range(count):
+            logger.logkv_mean(name, val)
+    d2 = logger.dumpkvs(mpi_mean=True)
+    if comm.rank == 0:
+        assert d2 == correctval
+
+def test_mpi_weighted_mean():
+    subprocess.check_call(['mpirun', '-n', '2', sys.executable, '-c',
+        'from baselines.common import test_mpi_util; test_mpi_util.helper_for_mpi_weighted_mean()'])

View File

@@ -26,7 +26,6 @@ class IdentityEnv(Env):
         self._choose_next_state()
         done = False
         if self.episode_len and self.time >= self.episode_len:
-            rew = 0
             done = True

         return self.state, rew, done, {}
@@ -74,7 +73,7 @@ class BoxIdentityEnv(IdentityEnv):
             episode_len=None,
     ):

-        self.action_space = Box(low=-1.0, high=1.0, shape=shape)
+        self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
         super().__init__(episode_len=episode_len)

     def _get_reward(self, actions):

View File

@@ -168,6 +168,19 @@ class VecEnvWrapper(VecEnv):
     def get_images(self):
         return self.venv.get_images()

+class VecEnvObservationWrapper(VecEnvWrapper):
+    @abstractmethod
+    def process(self, obs):
+        pass
+
+    def reset(self):
+        obs = self.venv.reset()
+        return self.process(obs)
+
+    def step_wait(self):
+        obs, rews, dones, infos = self.venv.step_wait()
+        return self.process(obs), rews, dones, infos
+
 class CloudpickleWrapper(object):
     """
     Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
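
The pattern behind `VecEnvObservationWrapper` is that a subclass only supplies `process`, and both `reset` and `step_wait` route observations through it. The sketch below illustrates this with stand-in classes so that it runs without baselines or gym; `FakeVecEnv`, `ObservationWrapperSketch`, and `HalveObs` are hypothetical names, and plain lists stand in for the batched arrays a real vectorized environment would return.

```python
from abc import ABC, abstractmethod


class FakeVecEnv:
    """Hypothetical stand-in for a vectorized env with two sub-environments."""

    def reset(self):
        return [1.0, 2.0]

    def step_wait(self):
        # (observations, rewards, dones, infos), one entry per sub-environment
        return [3.0, 4.0], [0.0, 1.0], [False, True], [{}, {}]


class ObservationWrapperSketch(ABC):
    """Mirrors the shape of VecEnvObservationWrapper from the diff above."""

    def __init__(self, venv):
        self.venv = venv

    @abstractmethod
    def process(self, obs):
        """Transform a batch of observations."""

    def reset(self):
        return self.process(self.venv.reset())

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        # Only observations are transformed; rewards, dones, infos pass through.
        return self.process(obs), rews, dones, infos


class HalveObs(ObservationWrapperSketch):
    """Example subclass: scale every observation by 0.5."""

    def process(self, obs):
        return [o * 0.5 for o in obs]


env = HalveObs(FakeVecEnv())
first_obs = env.reset()
next_obs, rews, dones, infos = env.step_wait()
```

Keeping the transformation in a single `process` method means `reset` and `step_wait` cannot drift apart, which is the usual failure mode when each wrapper reimplements both.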

View File

@@ -27,7 +27,7 @@ class DummyVecEnv(VecEnv):
         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
         self.buf_infos = [{} for _ in range(self.num_envs)]
         self.actions = None
-        self.specs = [e.spec for e in self.envs]
+        self.spec = self.envs[0].spec

     def step_async(self, actions):
         listify = True

View File

@@ -54,7 +54,6 @@ class ShmemVecEnv(VecEnv):
             proc.start()
             child_pipe.close()
         self.waiting_step = False
-        self.specs = [f().spec for f in env_fns]
         self.viewer = None

     def reset(self):

View File

@@ -1,7 +1,9 @@
 import numpy as np
-from multiprocessing import Process, Pipe
+import multiprocessing as mp
 from . import VecEnv, CloudpickleWrapper

+ctx = mp.get_context('spawn')
+
 def worker(remote, parent_remote, env_fn_wrapper):
     parent_remote.close()
     env = env_fn_wrapper.x()
@@ -21,8 +23,8 @@ def worker(remote, parent_remote, env_fn_wrapper):
             elif cmd == 'close':
                 remote.close()
                 break
-            elif cmd == 'get_spaces':
-                remote.send((env.observation_space, env.action_space))
+            elif cmd == 'get_spaces_spec':
+                remote.send((env.observation_space, env.action_space, env.spec))
             else:
                 raise NotImplementedError
     except KeyboardInterrupt:
@@ -45,8 +47,8 @@ class SubprocVecEnv(VecEnv):
         self.waiting = False
         self.closed = False
         nenvs = len(env_fns)
-        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
-        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
+        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
+        self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                    for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
         for p in self.ps:
             p.daemon = True  # if the main process crashes, we should not cause things to hang
@@ -54,10 +56,9 @@ class SubprocVecEnv(VecEnv):
         for remote in self.work_remotes:
             remote.close()

-        self.remotes[0].send(('get_spaces', None))
-        observation_space, action_space = self.remotes[0].recv()
+        self.remotes[0].send(('get_spaces_spec', None))
+        observation_space, action_space, self.spec = self.remotes[0].recv()
         self.viewer = None
-        self.specs = [f().spec for f in env_fns]
         VecEnv.__init__(self, len(env_fns), observation_space, action_space)

     def step_async(self, actions):
@@ -99,9 +100,12 @@ class SubprocVecEnv(VecEnv):
     def _assert_not_closed(self):
         assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"

+    def __del__(self):
+        if not self.closed:
+            self.close()

 def _flatten_obs(obs):
-    assert isinstance(obs, list) or isinstance(obs, tuple)
+    assert isinstance(obs, (list, tuple))
     assert len(obs) > 0

     if isinstance(obs[0], dict):
@@ -111,4 +115,3 @@ def _flatten_obs(obs):
         return {k: np.stack([o[k] for o in obs]) for k in keys}
     else:
         return np.stack(obs)
-

View File

@@ -108,7 +108,7 @@ def learn(*, network, env, total_timesteps,

     # Prepare params.
     params = config.DEFAULT_PARAMS
-    env_name = env.specs[0].id
+    env_name = env.spec.id
     params['env_name'] = env_name
     params['replay_strategy'] = replay_strategy
     if env_name in config.DEFAULT_ENV_PARAMS:

View File

@@ -68,7 +68,8 @@ class HumanOutputFormat(KVWriter, SeqWriter):
         self.file.flush()

     def _truncate(self, s):
-        return s[:20] + '...' if len(s) > 23 else s
+        maxlen = 30
+        return s[:maxlen-3] + '...' if len(s) > maxlen else s

     def writeseq(self, seq):
         seq = list(seq)
@@ -210,14 +211,17 @@ def logkvs(d):
     for (k, v) in d.items():
         logkv(k, v)

-def dumpkvs():
+def dumpkvs(mpi_mean=False):
     """
     Write all of the diagnostics from the current iteration

-    level: int. (see logger.py docs) If the global logger level is higher than
-                the level argument here, don't print to stdout.
+    mpi_mean: whether to average across MPI workers. mpi_mean=False just
+        has each worker write its own stats (and under default settings
+        non-root workers don't write anything), whereas mpi_mean=True has
+        the root worker collect all of the stats and write the average,
+        and no one else writes anything.
     """
-    Logger.CURRENT.dumpkvs()
+    return Logger.CURRENT.dumpkvs(mpi_mean=mpi_mean)

 def getkvs():
     return Logger.CURRENT.name2val
@@ -307,20 +311,30 @@ class Logger(object):
         self.name2val[key] = val

     def logkv_mean(self, key, val):
-        if val is None:
-            self.name2val[key] = None
-            return
         oldval, cnt = self.name2val[key], self.name2cnt[key]
         self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
         self.name2cnt[key] = cnt + 1

-    def dumpkvs(self):
+    def dumpkvs(self, mpi_mean=False):
         if self.level == DISABLED: return
+        if mpi_mean:
+            from baselines.common import mpi_util
+            from mpi4py import MPI
+            comm = MPI.COMM_WORLD
+            d = mpi_util.mpi_weighted_mean(comm,
+                {name : (val, self.name2cnt.get(name, 1))
+                    for (name, val) in self.name2val.items()})
+            if comm.rank != 0:
+                d['dummy'] = 1 # so we don't get a warning about empty dict
+        else:
+            d = self.name2val
+        out = d.copy() # Return the dict for unit testing purposes
         for fmt in self.output_formats:
             if isinstance(fmt, KVWriter):
-                fmt.writekvs(self.name2val)
+                fmt.writekvs(d)
         self.name2val.clear()
         self.name2cnt.clear()
+        return out

     def log(self, *args, level=INFO):
         if self.level <= level:
@@ -456,7 +470,6 @@ def read_tb(path):
     import pandas
     import numpy as np
     from glob import glob
-    from collections import defaultdict
     import tensorflow as tf
     if osp.isdir(path):
         fnames = glob(osp.join(path, "events.*"))
@@ -485,5 +498,7 @@ def read_tb(path):
 # configure the default logger on import
 _configure_default_logger()

 if __name__ == "__main__":
     _demo()
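
Two of the logger changes above are small enough to check in isolation: the new `_truncate` rule and the incremental mean kept by `logkv_mean`. The following is a minimal sketch under stated assumptions: `truncate` and `RunningMeans` are hypothetical standalone names, and the two `defaultdict` fields are assumed to mirror how `Logger` stores `name2val` and `name2cnt`.

```python
from collections import defaultdict


def truncate(s, maxlen=30):
    """Same rule as the new _truncate: cap at maxlen, ending in '...'."""
    return s[:maxlen - 3] + '...' if len(s) > maxlen else s


class RunningMeans:
    """Hypothetical holder for the incremental mean used by logkv_mean."""

    def __init__(self):
        self.name2val = defaultdict(float)  # current mean per key
        self.name2cnt = defaultdict(int)    # number of samples per key

    def logkv_mean(self, key, val):
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        # Update the mean in place without storing past samples.
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1


means = RunningMeans()
for sample in (1.0, 3.0):
    means.logkv_mean('reward', sample)

long_key = truncate('x' * 40)
short_key = truncate('eprewmean')
```

The per-key count kept here is what `dumpkvs(mpi_mean=True)` passes along with each value, so that the root worker can weight every worker's mean by the number of samples behind it.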