* add some docstrings
* start making big changes
* state machine redesign
* sampling seems to work
* some reorg
* fixed sampling of real vals
* json conversion
* made it possible to register new commands got nontrivial version of Pred working
* consolidate command definitions
* add more macro blocks
* revived visualization
* rename Userdata -> CmdInterpreter make AlgoSmInstance subclass of SmInstance that uses appropriate userdata argument
* replace userdata by ci when appropriate
* minor test fixes
* revamped handmade dir, can run ppo_metal
* seed to avoid random test failure
* implement AlgoAgent
* Autogenerated object that performs all ops and macros
* more CmdRecorder changes
* move files around
* move MatchProb and JtftProb
* remove obsolete
* fix tests involving AlgoAgent (pending the next commit on ppo_metal code)
* ppo_metal: reduce duplication in policy_gen, make sess an attribute of PpoAgent and StochasticPolicy instead of using get_default_session everywhere.
* maze_env reformatting, move algo_search script (but still broken)
* move agent.py
* fix test on handcrafted agents
* tuning/fixing ppo_metal baseline
* minor
* Fix ppo_metal baseline
* Don’t set epcount, tcount unless they’re being used
* get rid of old ppo_metal baseline
* fixes for handmade/run.py tuning
* fix codegen ppo
* fix handmade ppo hps
* fix test, go back to safe_div
* switch to more complex filtering
* make sure all handcrafted algos have finite probability
* train to maximize logprob of provided samples Trex changes to avoid segfault
* AlgoSm also includes global hyperparams
* don’t duplicate global hyperparam defaults
* create generic_ob_ac_space function
* use sorted list of outkeys
* revive tsne
* todo changes
* determinism test
* todo + test fix
* remove a few deprecated files, rename other tests so they don’t run automatically, fix real test failure
* continuous control with codegen
* continuous control with codegen
* implement continuous action space algodistr
* ppo with trex RUN BENCHMARKS
* wrap trex in a monitor
* dummy commit to RUN BENCHMARKS
* adding monitor to trex env RUN BENCHMARKS
* adding monitor to trex RUN BENCHMARKS
* include monitor into trex env RUN BENCHMARKS
* generate nll and predmean using Distribution node
* dummy commit to RUN BENCHMARKS
* include pybullet into baselines optional dependencies
* dummy commit to RUN BENCHMARKS
* install games for cron rcall user RUN BENCHMARKS
* add --yes flag to install.py in rcall config for cron user RUN BENCHMARKS
* both continuous and discrete versions seem to run
* fixes to monitor to work with vecenv-like info and rewards RUN BENCHMARKS
* dummy commit to RUN BENCHMARKS
* removed shape check from one-hot encoding logic in distributions.CategoricalPd
* reset logger configuration in codegen/handmade/run.py to be in-line with baselines RUN BENCHMARKS
* merged peterz_codegen_benchmarks RUN BENCHMARKS
* skip tests RUN BENCHMARKS
* working on test failures
* save benchmark dicts RUN BENCHMARKS
* merged peterz_codegen_benchmark RUN BENCHMARKS
* add get_git_commit_message to the baselines.common.console_util
* dummy commit to RUN BENCHMARKS
* merged fixes from peterz_codegen_benchmark RUN BENCHMARKS
* fixing failure in test_algo_nll WIP
* test_algo_nll passes with both ppo and softq
* re-enabled tests
* run trex on gpus for 100k total (horizon=100k / 16) RUN BENCHMARKS
* merged latest peterz_codegen_benchmarks RUN BENCHMARKS
* fixing codegen test failures (logging-related)
* fixed name collision in run-benchmarks-new.py RUN BENCHMARKS
* fixed name collision in run-benchmarks-new.py RUN BENCHMARKS
* fixed import in node_filters.py
* test_algo_search passes
* some cleanup
* dummy commit to RUN BENCHMARKS
* merge fast fail for subprocvecenv RUN BENCHMARKS
* use SubprocVecEnv in sonic_prob
* added deprecation note to shmem_vec_env
* allow indexing of distributions
* add timeout to pipeline.yaml
* typo in pipeline.yml
* run tests with --forked option
* resolved merge conflict in rl_algs.bench.benchmarks
* re-enable parallel tests
* fix remaining merge conflicts and syntax
* Update trex_prob.py
* fixes to ResultsWriter
* take baselines/run.py from peterz_codegen branch
* actually save stuff to file in VecMonitor RUN BENCHMARKS
* enable parallel tests
* merge stricter flake8
* merge peterz_codegen_benchmark, resolve conflicts
* autopep8
* remove traces of Monitor from trex env, check shapes before encoding in CategoricalPd
* asserts and warnings to make q -> distribution change more explicit
* fixed assert in CategoricalPd
* add header to vec_monitor output file RUN BENCHMARKS
* make VecMonitor write header to the output file
* remove deprecation message from shmem_vec_env RUN BENCHMARKS
* autopep8
* proper shape test in distributions.py
* ResultsWriter can take dict headers
* dummy commit to RUN BENCHMARKS
* replace assert len(qs)==1 with warning RUN BENCHMARKS
* removed pdb from ppo2 RUN BENCHMARKS
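"""Command-line entry point for training baselines algorithms.

A typical invocation (illustrative; any extra --key=value arguments are
forwarded to the algorithm's learn function, see parse_cmdline_kwargs):

    python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=1e6
"""
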
import sys
import multiprocessing
import os.path as osp
from collections import defaultdict
from importlib import import_module

import gym
import tensorflow as tf
import numpy as np

from baselines import bench, logger
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.vec_env.vec_normalize import VecNormalize
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
from baselines.common.tf_util import get_session
from baselines.common import atari_wrappers, retro_wrappers

# MPI, pybullet and roboschool are optional dependencies; importing the
# latter two registers their environments with gym as a side effect.
try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs  # noqa: F401
except ImportError:
    pybullet_envs = None

try:
    import roboschool  # noqa: F401
except ImportError:
    roboschool = None

_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env._entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)
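
# Illustrative example of the parsing above (entry points vary by gym version):
# an env whose _entry_point is 'gym.envs.atari:AtariEnv' is filed under
# env_type 'atari', so _game_envs['atari'] collects ids like 'PongNoFrameskip-v4'.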

# reading benchmark names directly from retro requires
# importing retro here, and for some reason that crashes tensorflow
# in ubuntu
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}


def train(args, extra_args):
    """Train args.alg on args.env and return (model, env).

    Hyperparameters come from the algorithm's per-env-type defaults,
    overridden by any extra key=value command-line arguments.
    """
    env_type, env_id = get_env_type(args.env)
    print('env_type: {}'.format(env_type))

    total_timesteps = int(args.num_timesteps)
    seed = args.seed

    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    env = build_env(args)

    if args.network:
        alg_kwargs['network'] = args.network
    elif alg_kwargs.get('network') is None:
        alg_kwargs['network'] = get_default_network(env_type)

    print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))

    model = learn(
        env=env,
        seed=seed,
        total_timesteps=total_timesteps,
        **alg_kwargs
    )

    return model, env


def build_env(args):
    """Construct the (vectorized and/or wrapped) environment specified by args."""
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    seed = args.seed

    env_type, env_id = get_env_type(args.env)

    if env_type == 'atari':
        if alg == 'acer':
            env = make_vec_env(env_id, env_type, nenv, seed)
        elif alg == 'deepq':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir())
            env = atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True)
        elif alg == 'trpo_mpi':
            env = atari_wrappers.make_atari(env_id)
            env.seed(seed)
            env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
            env = atari_wrappers.wrap_deepmind(env)
            # TODO check if the second seeding is necessary, and eventually remove
            env.seed(seed)
        else:
            frame_stack_size = 4
            env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)

    elif env_type == 'retro':
        import retro
        gamestate = args.gamestate or 'Level1-1'
        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
                                        use_restricted_actions=retro.Actions.DISCRETE)
        env.seed(args.seed)
        env = bench.Monitor(env, logger.get_dir())
        env = retro_wrappers.wrap_deepmind_retro(env)

    else:
        get_session(tf.ConfigProto(allow_soft_placement=True,
                                   intra_op_parallelism_threads=1,
                                   inter_op_parallelism_threads=1))

        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)

        if env_type == 'mujoco':
            env = VecNormalize(env)

    return env
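
# For instance (assuming standard flags): with --alg=ppo2 --env=PongNoFrameskip-v4,
# ppo2 matches none of the special-cased algorithms above, so the final atari
# branch returns a 4-frame VecFrameStack over nenv parallel environments.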


def get_env_type(env_id):
    if env_id in _game_envs.keys():
        env_type = env_id
        env_id = [g for g in _game_envs[env_type]][0]
    else:
        env_type = None
        for g, e in _game_envs.items():
            if env_id in e:
                env_type = g
                break
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(env_id, _game_envs.keys())

    return env_type, env_id
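
# Illustrative behavior (actual ids depend on the registered envs):
#   get_env_type('PongNoFrameskip-v4') -> ('atari', 'PongNoFrameskip-v4')
#   get_env_type('retro')              -> ('retro', <an arbitrary id from _game_envs['retro']>)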


def get_default_network(env_type):
    if env_type == 'atari':
        return 'cnn'
    else:
        return 'mlp'


def get_alg_module(alg, submodule=None):
    submodule = submodule or alg
    try:
        # first try to import the alg module from baselines
        alg_module = import_module('.'.join(['baselines', alg, submodule]))
    except ImportError:
        # then from rl_algs
        alg_module = import_module('.'.join(['rl_algs', alg, submodule]))

    return alg_module


def get_learn_function(alg):
    return get_alg_module(alg).learn


def get_learn_function_defaults(alg, env_type):
    try:
        alg_defaults = get_alg_module(alg, 'defaults')
        kwargs = getattr(alg_defaults, env_type)()
    except (ImportError, AttributeError):
        kwargs = {}
    return kwargs
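
# e.g. get_learn_function_defaults('ppo2', 'atari') returns the dict produced by
# baselines.ppo2.defaults.atari(), or {} if no such module or function exists.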


def parse_cmdline_kwargs(args):
    '''
    convert a list of '='-separated command-line arguments to a dictionary, evaluating python objects when possible
    '''
    def parse(v):
        assert isinstance(v, str)
        try:
            return eval(v)
        except (NameError, SyntaxError):
            return v

    return {k: parse(v) for k, v in parse_unknown_args(args).items()}
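
# For example (hypothetical arguments):
#   parse_cmdline_kwargs(['--lr=3e-4', '--network=mlp'])
#   -> {'lr': 0.0003, 'network': 'mlp'}
# ('3e-4' evaluates to a float; eval('mlp') raises NameError, so it stays a string)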


def main():
    # configure logger, disable logging in child MPI processes (with rank > 0)
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    env.close()

    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        obs = env.reset()
        while True:
            actions = model.step(obs)[0]
            obs, _, done, _ = env.step(actions)
            env.render()
            done = done.any() if isinstance(done, np.ndarray) else done

            if done:
                obs = env.reset()

        env.close()


if __name__ == '__main__':
    main()