baselines/baselines/run.py
pzhokhov fe06c6b4db continuous action spaces for codegen + some benchmarking (#82)
* add some docstrings

* start making big changes

* state machine redesign

* sampling seems to work

* some reorg

* fixed sampling of real vals

* json conversion

* made it possible to register new commands
got nontrivial version of Pred working

* consolidate command definitions

* add more macro blocks

* revived visualization

* rename Userdata -> CmdInterpreter
make AlgoSmInstance subclass of SmInstance that uses appropriate userdata argument

* replace userdata by ci when appropriate

* minor test fixes

* revamped handmade dir, can run ppo_metal

* seed to avoid random test failure

* implement AlgoAgent

* Autogenerated object that performs all ops and macros

* more CmdRecorder changes

* move files around

* move MatchProb and JtftProb

* remove obsolete

* fix tests involving AlgoAgent (pending the next commit on ppo_metal code)

* ppo_metal: reduce duplication in policy_gen, make sess an attribute of PpoAgent and StochasticPolicy instead of using get_default_session everywhere.

* maze_env reformatting, move algo_search script (but still broken)

* move agent.py

* fix test on handcrafted agents

* tuning/fixing ppo_metal baseline

* minor

* Fix ppo_metal baseline

* Don’t set epcount, tcount unless they’re being used

* get rid of old ppo_metal baseline

* fixes for handmade/run.py tuning

* fix codegen ppo

* fix handmade ppo hps

* fix test, go back to safe_div

* switch to more complex filtering

* make sure all handcrafted algos have finite probability

* train to maximize logprob of provided samples
Trex changes to avoid segfault

* AlgoSm also includes global hyperparams

* don’t duplicate global hyperparam defaults

* create generic_ob_ac_space function

* use sorted list of outkeys

* revive tsne

* todo changes

* determinism test

* todo + test fix

* remove a few deprecated files, rename other tests so they don’t run automatically, fix real test failure

* continuous control with codegen

* continuous control with codegen

* implement continuous action space algodistr

* ppo with trex RUN BENCHMARKS

* wrap trex in a monitor

* dummy commit to RUN BENCHMARKS

* adding monitor to trex env RUN BENCHMARKS

* adding monitor to trex RUN BENCHMARKS

* include monitor into trex env RUN BENCHMARKS

* generate nll and predmean using Distribution node

* dummy commit to RUN BENCHMARKS

* include pybullet into baselines optional dependencies

* dummy commit to RUN BENCHMARKS

* install games for cron rcall user RUN BENCHMARKS

* add --yes flag to install.py in rcall config for cron user RUN BENCHMARKS

* both continuous and discrete versions seem to run

* fixes to monitor to work with vecenv-like info and rewards RUN BENCHMARKS

* dummy commit to RUN BENCHMARKS

* removed shape check from one-hot encoding logic in distributions.CategoricalPd

* reset logger configuration in codegen/handmade/run.py to be in-line with baselines RUN BENCHMARKS

* merged peterz_codegen_benchmarks RUN BENCHMARKS

* skip tests RUN BENCHMARKS

* working on test failures

* save benchmark dicts RUN BENCHMARK

* merged peterz_codegen_benchmark RUN BENCHMARKS

* add get_git_commit_message to the baselines.common.console_util

* dummy commit to RUN BENCHMARKS

* merged fixes from peterz_codegen_benchmark RUN BENCHMARKS

* fixing failure in test_algo_nll WIP

* test_algo_nll passes with both ppo and softq

* re-enabled tests

* run trex on gpus for 100k total (horizon=100k / 16) RUN BENCHMARKS

* merged latest peterz_codegen_benchmarks RUN BENCHMARKS

* fixing codegen test failures (logging-related)

* fixed name collision in run-benchmarks-new.py RUN BENCHMARKS

* fixed name collision in run-benchmarks-new.py RUN BENCHMARKS

* fixed import in node_filters.py

* test_algo_search passes

* some cleanup

* dummy commit to RUN BENCHMARKS

* merge fast fail for subprocvecenv RUN BENCHMARKS

* use SubprocVecEnv in sonic_prob

* added deprecation note to shmem_vec_env

* allow indexing of distributions

* add timeout to pipeline.yaml

* typo in pipeline.yml

* run tests with --forked option

* resolved merge conflict in rl_algs.bench.benchmarks

* re-enable parallel tests

* fix remaining merge conflicts and syntax

* Update trex_prob.py

* fixes to ResultsWriter

* take baselines/run.py from peterz_codegen branch

* actually save stuff to file in VecMonitor RUN BENCHMARKS

* enable parallel tests

* merge stricter flake8

* merge peterz_codegen_benchmark, resolve conflicts

* autopep8

* remove traces of Monitor from trex env, check shapes before encoding in CategoricalPd

* asserts and warnings to make q -> distribution change more explicit

* fixed assert in CategoricalPd

* add header to vec_monitor output file RUN BENCHMARKS

* make VecMonitor write header to the output file

* remove deprecation message from shmem_vec_env RUN BENCHMARKS

* autopep8

* proper shape test in distributions.py

* ResultsWriter can take dict headers

* dummy commit to RUN BENCHMARKS

* replace assert len(qs)==1 with warning RUN BENCHMARKS

* removed pdb from ppo2 RUN BENCHMARKS
2018-09-14 15:43:49 -07:00

237 lines
6.6 KiB
Python

import sys
import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
from baselines.common.tf_util import get_session
from baselines import bench, logger
from importlib import import_module
from baselines.common.vec_env.vec_normalize import VecNormalize
from baselines.common import atari_wrappers, retro_wrappers
try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env._entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)

# reading benchmark names directly from retro requires
# importing retro here, and for some reason that crashes tensorflow
# in ubuntu
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}


def train(args, extra_args):
    env_type, env_id = get_env_type(args.env)
    print('env_type: {}'.format(env_type))

    total_timesteps = int(args.num_timesteps)
    seed = args.seed

    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    env = build_env(args)

    if args.network:
        alg_kwargs['network'] = args.network
    else:
        if alg_kwargs.get('network') is None:
            alg_kwargs['network'] = get_default_network(env_type)

    print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))

    model = learn(
        env=env,
        seed=seed,
        total_timesteps=total_timesteps,
        **alg_kwargs
    )

    return model, env


def build_env(args):
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)

    if env_type == 'atari':
        if alg == 'acer':
            env = make_vec_env(env_id, env_type, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
                                        use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    else:
        get_session(tf.ConfigProto(allow_soft_placement=True,
                                   intra_op_parallelism_threads=1,
                                   inter_op_parallelism_threads=1))

        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)

        if env_type == 'mujoco':
            env = VecNormalize(env)

    return env


def get_env_type(env_id):
    if env_id in _game_envs.keys():
        # env_id names an environment type: pick one environment of that type
        env_type = env_id
        env_id = [g for g in _game_envs[env_type]][0]
    else:
        # otherwise find the type whose set of ids contains env_id
        env_type = None
        for g, e in _game_envs.items():
            if env_id in e:
                env_type = g
                break
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(env_id, _game_envs.keys())

    return env_type, env_id


def get_default_network(env_type):
    if env_type == 'atari':
        return 'cnn'
    else:
        return 'mlp'


def get_alg_module(alg, submodule=None):
    submodule = submodule or alg
    try:
        # first try to import the alg module from baselines
        alg_module = import_module('.'.join(['baselines', alg, submodule]))
    except ImportError:
        # then from rl_algs
        alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))

    return alg_module


def get_learn_function(alg):
    return get_alg_module(alg).learn


def get_learn_function_defaults(alg, env_type):
    try:
        alg_defaults = get_alg_module(alg, 'defaults')
        kwargs = getattr(alg_defaults, env_type)()
    except (ImportError, AttributeError):
        kwargs = {}
    return kwargs


def parse_cmdline_kwargs(args):
    '''
    Convert a list of '='-separated command-line arguments to a dictionary,
    evaluating Python objects when possible.
    '''
    def parse(v):
        assert isinstance(v, str)
        try:
            return eval(v)
        except (NameError, SyntaxError):
            # not a valid Python expression: keep the raw string
            return v

    return {k: parse(v) for k, v in parse_unknown_args(args).items()}


def main():
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    # configure logger, disable logging in child MPI processes (with rank > 0)
    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    env.close()

    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        obs = env.reset()
        while True:
            actions = model.step(obs)[0]
            obs, _, done, _ = env.step(actions)
            env.render()
            done = done.any() if isinstance(done, np.ndarray) else done

            if done:
                obs = env.reset()

        env.close()


if __name__ == '__main__':
    main()
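
As a hedged illustration of what `parse_cmdline_kwargs` does with the arguments that `parse_known_args` leaves over, the sketch below turns `--key=value` strings into a dictionary with typed values. It is not the code from this file: `split_key_value_args` is a hypothetical stand-in for `baselines.common.cmd_util.parse_unknown_args` (its exact behavior is assumed here), and `ast.literal_eval` is swapped in for `eval`, so only Python literals are evaluated and anything else is kept as a string.

```python
import ast


def split_key_value_args(args):
    """Hypothetical stand-in for parse_unknown_args: '--key=value' -> {'key': 'value'}."""
    result = {}
    for arg in args:
        # Assumption: every extra argument has the form --key=value.
        if arg.startswith('--') and '=' in arg:
            key, value = arg[2:].split('=', 1)
            result[key] = value
    return result


def parse_value(text):
    """Evaluate text as a Python literal; fall back to the raw string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def parse_kwargs(args):
    """Build a kwargs dictionary from a list of '--key=value' strings."""
    return {k: parse_value(v) for k, v in split_key_value_args(args).items()}


example = parse_kwargs(['--lr=3e-4', '--nsteps=128', '--network=mlp', '--layers=[64, 64]'])
print(example)
```

Unlike `eval`, `ast.literal_eval` does not evaluate names or expressions such as `2*64`, so those values stay strings in this sketch; whether that trade-off is acceptable depends on how the extra arguments are used.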