1.5 months of codegen changes (#196)

* play with resnet

* feed_dict version

* coinrun prob and more stats

* fixes to get_choices_specs & hp search

* minor prob fixes

* minor fixes

* minor

* alternative version of rl_algo stuff

* pylint fixes

* fix bugs, move node_filters to soup

* changed how get_algo works

* change how get_algo works, probably broke all tests

* continue previous refactor

* get eval_agent running again

* fixing tests

* fix tests

* fix more tests

* clean up cma stuff

* fix experiment

* minor changes to eval_agent to make ppo_metal use gpu

* make dict space work

* modify mac makefile to use conda

* recurrent layers

* play with bn and resnets

* minor hp changes

* minor

* got rid of use_fb argument and jtft (joint-train-fine-tune) functionality
built test phase directly into AlgoProb

* make new rl algos generateable

* pylint; start fixing tests

* fixing tests

* more test fixes

* pylint

* fix search

* work on search

* hack around infinite loop caused by scan

* algo search fixes

* misc changes for search expt

* enable annealing, overriding options of Op

* pylint fixes

* identity op

* achieve use_last_output through masking so it automatically works in other distributions

* fix tests

* minor

* discrete

* use_last_output to be just a preference, not a hard constraint

* pred delay, pruning

* require nontrivial inputs

* aliases for get_sm

* add probname to probs

* fixes

* small fixes

* fix tests

* fix tests

* fix tests

* minor

* test scripts

* dualgru network improvements

* minor

* work on mysterious bugs

* rcall gpu-usage command for kube

* use cache dir that’s not in code folder, so that it doesn’t get removed by rcall code rsync

* add power mode to gpu usage

* make sure train/test actually different

* remove VR for now

* minor fixes

* simplify soln_db

* minor

* big refactor of mpi eda

* improve mpieda for multitask

* - get rid of timelimit hack
- add __del__ to cleanup SubprocVecEnv

* get multitask working better

* fixes

* working on atari, various

* annotate ops with whether they’re parametrized

* minor

* gym version

* rand atari prob

* minor

* SolnDb bugfix and name change

* pyspy script

* switch conv layers

* fix roboschool/bullet3

* nenvs assertion

* fix rand atari

* get rid of blanket exception catching
fix soln_db bug

* fix rand_atari

* dynamic routing as cmdline arg

* slight modifications to test_mpi_map and pyspy-all

* max_tries argument for run_until_successs

* dedup option in train_mle

* simplify soln_db

* increase atari horizon for 1 experiment

* start implementing reward increment

* ent multiplier

* create cc dsl
other misc fixes

* cc ops

* q_func -> qs in rl_algos_cc.py

* fix PredictDistr

* rl_ops_cc fixes, MakeAction op

* augment algo agent to support cc stuff

* work on ddpg experiments

* fix blocking
temporarily change logger

* allow layer scaling

* pylint fixes

* spawn_method

* isolate ddpg hacks

* improve pruning

* use spawn for subproc

* remove use of python -c in rcall

* fix pylint warning

* fix static

* maybe fix local backend

* switch to DummyVecEnv

* making some fixes via pylint

* pylint fixes

* fixing tests

* fix tests

* fix tests

* write scaffolding for SSL in Codegen

* logger fix

* fix error

* add EMA op to sl_ops

* save many changes

* save

* add upsampler

* add sl ops, enhance state machine

* get ssl search working — some gross hacking

* fix session/graph issue

* fix importing

* work on mle

* - scale embeddings in gru model
- better exception handling in sl_prob
- use emas for test/val
- use non-contrib batch_norm layer

* improve logging

* option to average before dumping in logger

* default arguments, etc

* new ddpg and identity test

* concat fix

* minor

* move realistic ssl stuff to third-party (underscore to dash)

* fixes

* remove realistic_ssl_evaluation

* pylint fixes

* use gym master

* try again

* pass around args without gin

* fix tests

* separate line to install gym

* rename failing tests that should be ignored

* add data aug

* ssl improvements

* use fixed time limit

* try to fix baselines tests

* add score_floor, max_walltime, fiddle with lr decay

* realistic_ssl

* autopep8

* various ssl
- enable blocking grad for simplification
- kl
- multiple final prediction

* fix pruning

* misc ssl stuff

* bring back linear schedule, don’t use allgather for collecting stats
(i’ve been getting nondeterministic errors from the old code)

* save/load weights in SSL, big stepsize

* cleanup SslProb

* fix

* get rid of kl coef

* fix simplification, lower lr

* search over hps

* minor fixes

* minor

* static analysis

* move files and rename things for improved consistency.
still broken, and just saving before making nontrivial changes

* various

* make tests pass

* move coinrun_train to codegen since it depends on codegen

* fixes

* pylint fixes

* improve tests
fix some things

* improve tests

* lint

* fix up db_info.py, tests

* mostly restore master version of envs directory, except for makefile changes

* fix tests

* improve printing

* minor fixes

* fix fixmes

* pruning test

* fixes

* lint

* write new test that makes tf graphs of random algos; fix some bugs it caught

* add —delete flag to rcall upload-code command

* lint

* get cifar10 lazily for testing purposes

* disable codegen ci tests for now

* clean up rl_ops

* rename spec classes

* td3 with identity test

* identity tests without gin files

* remove gin.configurable from AlgoAgent

* comments about reduction in rl_ops_cc

* address @pzhokhov comments

* fix tests

* more linting

* better tests

* clean up filtering a bit

* fix concat
This commit is contained in:
John Schulman
2019-01-03 13:23:18 -08:00
committed by Peter Zhokhov
parent 8fe79aa76d
commit 370ee27750
11 changed files with 117 additions and 31 deletions

View File

@@ -221,11 +221,8 @@ class LazyFrames(object):
def __getitem__(self, i):
return self._force()[i]
def make_atari(env_id, timelimit=True):
# XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
def make_atari(env_id):
env = gym.make(env_id)
if not timelimit:
env = env.env
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)

View File

@@ -205,7 +205,8 @@ class CategoricalPd(Pd):
class MultiCategoricalPd(Pd):
def __init__(self, nvec, flat):
self.flat = flat
self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
self.categoricals = list(map(CategoricalPd,
tf.split(flat, np.array(nvec, dtype=np.int32), axis=-1)))
def flatparam(self):
return self.flat
def mode(self):

View File

@@ -4,6 +4,7 @@ import os, numpy as np
import platform
import shutil
import subprocess
import warnings
def sync_from_root(sess, variables, comm=None):
"""
@@ -81,6 +82,9 @@ def share_file(comm, path):
comm.Barrier()
def dict_gather(comm, d, op='mean', assert_all_have_data=True):
"""
Perform a reduction operation over dicts
"""
if comm is None: return d
alldicts = comm.allgather(d)
size = comm.size
@@ -99,3 +103,27 @@ def dict_gather(comm, d, op='mean', assert_all_have_data=True):
else:
assert 0, op
return result
def mpi_weighted_mean(comm, local_name2valcount):
"""
Perform a weighted average over dicts that are each on a different node
Input: local_name2valcount: dict mapping key -> (value, count)
Returns: key -> mean
"""
all_name2valcount = comm.gather(local_name2valcount)
if comm.rank == 0:
name2sum = defaultdict(float)
name2count = defaultdict(float)
for n2vc in all_name2valcount:
for (name, (val, count)) in n2vc.items():
try:
val = float(val)
except ValueError:
if comm.rank == 0:
warnings.warn(f'WARNING: tried to compute mean on non-float {name}={val}')
else:
name2sum[name] += val * count
name2count[name] += count
return {name : name2sum[name] / name2count[name] for name in name2sum}
else:
return {}

View File

@@ -0,0 +1,31 @@
from baselines.common import mpi_util
from mpi4py import MPI
import subprocess
import sys
from baselines import logger
def helper_for_mpi_weighted_mean():
comm = MPI.COMM_WORLD
if comm.rank == 0:
name2valcount = {'a' : (10, 2), 'b' : (20,3)}
elif comm.rank == 1:
name2valcount = {'a' : (19, 1), 'c' : (42,3)}
else:
raise NotImplementedError
d = mpi_util.mpi_weighted_mean(comm, name2valcount)
correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
if comm.rank == 0:
assert d == correctval, f'{d} != {correctval}'
for name, (val, count) in name2valcount.items():
for _ in range(count):
logger.logkv_mean(name, val)
d2 = logger.dumpkvs(mpi_mean=True)
if comm.rank == 0:
assert d2 == correctval
def test_mpi_weighted_mean():
subprocess.check_call(['mpirun', '-n', '2', sys.executable, '-c',
'from baselines.common import test_mpi_util; test_mpi_util.helper_for_mpi_weighted_mean()'])

View File

@@ -26,7 +26,6 @@ class IdentityEnv(Env):
self._choose_next_state()
done = False
if self.episode_len and self.time >= self.episode_len:
rew = 0
done = True
return self.state, rew, done, {}
@@ -74,7 +73,7 @@ class BoxIdentityEnv(IdentityEnv):
episode_len=None,
):
self.action_space = Box(low=-1.0, high=1.0, shape=shape)
self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
super().__init__(episode_len=episode_len)
def _get_reward(self, actions):

View File

@@ -168,6 +168,19 @@ class VecEnvWrapper(VecEnv):
def get_images(self):
return self.venv.get_images()
class VecEnvObservationWrapper(VecEnvWrapper):
@abstractmethod
def process(self, obs):
pass
def reset(self):
obs = self.venv.reset()
return self.process(obs)
def step_wait(self):
obs, rews, dones, infos = self.venv.step_wait()
return self.process(obs), rews, dones, infos
class CloudpickleWrapper(object):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)

View File

@@ -27,7 +27,7 @@ class DummyVecEnv(VecEnv):
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos = [{} for _ in range(self.num_envs)]
self.actions = None
self.specs = [e.spec for e in self.envs]
self.spec = self.envs[0].spec
def step_async(self, actions):
listify = True

View File

@@ -54,7 +54,6 @@ class ShmemVecEnv(VecEnv):
proc.start()
child_pipe.close()
self.waiting_step = False
self.specs = [f().spec for f in env_fns]
self.viewer = None
def reset(self):

View File

@@ -1,7 +1,9 @@
import numpy as np
from multiprocessing import Process, Pipe
import multiprocessing as mp
from . import VecEnv, CloudpickleWrapper
ctx = mp.get_context('spawn')
def worker(remote, parent_remote, env_fn_wrapper):
parent_remote.close()
env = env_fn_wrapper.x()
@@ -21,8 +23,8 @@ def worker(remote, parent_remote, env_fn_wrapper):
elif cmd == 'close':
remote.close()
break
elif cmd == 'get_spaces':
remote.send((env.observation_space, env.action_space))
elif cmd == 'get_spaces_spec':
remote.send((env.observation_space, env.action_space, env.spec))
else:
raise NotImplementedError
except KeyboardInterrupt:
@@ -45,8 +47,8 @@ class SubprocVecEnv(VecEnv):
self.waiting = False
self.closed = False
nenvs = len(env_fns)
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
for p in self.ps:
p.daemon = True # if the main process crashes, we should not cause things to hang
@@ -54,10 +56,9 @@ class SubprocVecEnv(VecEnv):
for remote in self.work_remotes:
remote.close()
self.remotes[0].send(('get_spaces', None))
observation_space, action_space = self.remotes[0].recv()
self.remotes[0].send(('get_spaces_spec', None))
observation_space, action_space, self.spec = self.remotes[0].recv()
self.viewer = None
self.specs = [f().spec for f in env_fns]
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
def step_async(self, actions):
@@ -99,9 +100,12 @@ class SubprocVecEnv(VecEnv):
def _assert_not_closed(self):
assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
def __del__(self):
if not self.closed:
self.close()
def _flatten_obs(obs):
assert isinstance(obs, list) or isinstance(obs, tuple)
assert isinstance(obs, (list, tuple))
assert len(obs) > 0
if isinstance(obs[0], dict):
@@ -111,4 +115,3 @@ def _flatten_obs(obs):
return {k: np.stack([o[k] for o in obs]) for k in keys}
else:
return np.stack(obs)

View File

@@ -108,7 +108,7 @@ def learn(*, network, env, total_timesteps,
# Prepare params.
params = config.DEFAULT_PARAMS
env_name = env.specs[0].id
env_name = env.spec.id
params['env_name'] = env_name
params['replay_strategy'] = replay_strategy
if env_name in config.DEFAULT_ENV_PARAMS:

View File

@@ -68,7 +68,8 @@ class HumanOutputFormat(KVWriter, SeqWriter):
self.file.flush()
def _truncate(self, s):
return s[:20] + '...' if len(s) > 23 else s
maxlen = 30
return s[:maxlen-3] + '...' if len(s) > maxlen else s
def writeseq(self, seq):
seq = list(seq)
@@ -210,14 +211,17 @@ def logkvs(d):
for (k, v) in d.items():
logkv(k, v)
def dumpkvs():
def dumpkvs(mpi_mean=False):
"""
Write all of the diagnostics from the current iteration
level: int. (see logger.py docs) If the global logger level is higher than
the level argument here, don't print to stdout.
mpi_mean: whether to average across MPI workers. mpi_mean=False just
has each worker write its own stats (and under default settings
non-root workers don't write anything), whereas mpi_mean=True has
the root worker collect all of the stats and write the average,
and no one else writes anything.
"""
Logger.CURRENT.dumpkvs()
return Logger.CURRENT.dumpkvs(mpi_mean=mpi_mean)
def getkvs():
return Logger.CURRENT.name2val
@@ -307,20 +311,30 @@ class Logger(object):
self.name2val[key] = val
def logkv_mean(self, key, val):
if val is None:
self.name2val[key] = None
return
oldval, cnt = self.name2val[key], self.name2cnt[key]
self.name2val[key] = oldval*cnt/(cnt+1) + val/(cnt+1)
self.name2cnt[key] = cnt + 1
def dumpkvs(self):
def dumpkvs(self, mpi_mean=False):
if self.level == DISABLED: return
if mpi_mean:
from baselines.common import mpi_util
from mpi4py import MPI
comm = MPI.COMM_WORLD
d = mpi_util.mpi_weighted_mean(comm,
{name : (val, self.name2cnt.get(name, 1))
for (name, val) in self.name2val.items()})
if comm.rank != 0:
d['dummy'] = 1 # so we don't get a warning about empty dict
else:
d = self.name2val
out = d.copy() # Return the dict for unit testing purposes
for fmt in self.output_formats:
if isinstance(fmt, KVWriter):
fmt.writekvs(self.name2val)
fmt.writekvs(d)
self.name2val.clear()
self.name2cnt.clear()
return out
def log(self, *args, level=INFO):
if self.level <= level:
@@ -456,7 +470,6 @@ def read_tb(path):
import pandas
import numpy as np
from glob import glob
from collections import defaultdict
import tensorflow as tf
if osp.isdir(path):
fnames = glob(osp.join(path, "events.*"))
@@ -485,5 +498,7 @@ def read_tb(path):
# configure the default logger on import
_configure_default_logger()
if __name__ == "__main__":
_demo()