# baselines/baselines/run.py

import sys
import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import numpy as np

from baselines.common.vec_env.vec_normalize import VecNormalize
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, env_thunk
from baselines import logger
from baselines.registry import registry

try:
    from mpi4py import MPI
except ImportError:
    MPI = None

try:
    import pybullet_envs
except ImportError:
    pybullet_envs = None

try:
    import roboschool
except ImportError:
    roboschool = None

_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
    # TODO: solve this with regexes
    env_type = env._entry_point.split(':')[0].split('.')[-1]
    _game_envs[env_type].add(env.id)
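
# Illustrative: an entry point of 'gym.envs.atari:AtariEnv' maps to env_type
# 'atari'; 'gym.envs.mujoco:HalfCheetahEnv' maps to 'mujoco'.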

# Reading benchmark names directly from retro requires importing retro here,
# and for some reason that crashes tensorflow on Ubuntu.
_game_envs['retro'] = {
    'BubbleBobble-Nes',
    'SuperMarioBros-Nes',
    'TwinBee3PokoPokoDaimaou-Nes',
    'SpaceHarrier-Nes',
    'SonicTheHedgehog-Genesis',
    'Vectorman-Genesis',
    'FinalFight-Snes',
    'SpaceInvaders-Snes',
}
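
# At this point _game_envs maps an env type to the set of env ids of that type,
# e.g. (illustrative -- exact contents depend on the installed gym version):
#   _game_envs['atari']  might contain 'PongNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...
#   _game_envs['mujoco'] might contain 'HalfCheetah-v2', 'Hopper-v2', ...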

def train(args, extra_args):
    '''
    Build the environment, look up the learn function for args.alg,
    and run it. Returns the trained model and the (still open) environment.
    '''
    env_type, env_id = get_env_type(args.env)
    print('env_type: {}'.format(env_type))

    total_timesteps = int(args.num_timesteps)
    seed = args.seed

    learn = get_learn_function(args.alg)
    alg_kwargs = get_learn_function_defaults(args.alg, env_type)
    alg_kwargs.update(extra_args)

    env = build_env(args)

    if args.network:
        alg_kwargs['network'] = args.network
    else:
        if alg_kwargs.get('network') is None:
            alg_kwargs['network'] = get_default_network(env_type)

    print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))

    model = learn(
        env=env,
        seed=seed,
        total_timesteps=total_timesteps,
        **alg_kwargs
    )

    return model, env

def build_env(args):
    # Construct either a vectorized env (for algorithms that support vecenvs)
    # or a single env thunk, with frame stacking for atari/retro env types.
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        ncpu //= 2
    nenv = args.num_env or ncpu
    alg = args.alg
    seed = args.seed

    env_type, env_id = get_env_type(args.env)
    assert alg in registry, 'Unknown algorithm {}'.format(alg)

    if env_type in {'atari', 'retro'}:
        frame_stack_size = 4
    else:
        frame_stack_size = 1

    if registry[alg]['supports_vecenv']:
        env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate,
                           reward_scale=args.reward_scale, frame_stack_size=frame_stack_size)
    else:
        env = env_thunk(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': frame_stack_size > 1})

    # VecNormalize keeps running statistics of observations and returns and
    # normalizes them online; this is the usual setup for mujoco control tasks.
    if env_type == 'mujoco' and registry[alg]['supports_vecenv']:
        env = VecNormalize(env)

    return env

def get_env_type(env_id):
    # env_id may be either a concrete env id or an env type; in the latter
    # case an arbitrary env id of that type is picked.
    if env_id in _game_envs.keys():
        env_type = env_id
        env_id = [g for g in _game_envs[env_type]][0]
    else:
        env_type = None
        for g, e in _game_envs.items():
            if env_id in e:
                env_type = g
                break
        assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(env_id, _game_envs.keys())

    return env_type, env_id
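
# Illustrative (exact ids depend on the installed gym version):
#   get_env_type('PongNoFrameskip-v4') -> ('atari', 'PongNoFrameskip-v4')
#   get_env_type('retro')              -> ('retro', <some id from _game_envs['retro']>)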

def get_default_network(env_type):
    if env_type == 'atari':
        return 'cnn'
    else:
        return 'mlp'

def get_alg_module(alg, submodule=None):
    import inspect
    entry = registry.get(alg)
    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
    module = inspect.getmodule(entry['fn']).__name__
    if submodule is not None:
        module = '.'.join([module, submodule])
    return module

def get_learn_function(alg):
    entry = registry.get(alg)
    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
    return entry['fn']

def get_learn_function_defaults(alg, env_type):
    entry = registry.get(alg)
    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
    return entry['defaults'].get(env_type, {})
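
# As used in this file, each registry entry appears to be a dict with at least
# 'fn' (the algorithm's learn function), 'defaults' (per-env-type kwargs) and
# 'supports_vecenv' (see build_env); see baselines.registry for the definition.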

def parse_cmdline_kwargs(args):
    '''
    Convert a list of '='-separated command-line arguments to a dictionary,
    evaluating Python objects when possible.
    '''
    def parse(v):
        assert isinstance(v, str)
        try:
            return eval(v)
        except (NameError, SyntaxError):
            return v

    return {k: parse(v) for k, v in parse_unknown_args(args).items()}
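
# Illustrative: ['--lr=3e-4', '--network=mlp'] yields
# {'lr': 0.0003, 'network': 'mlp'} -- '3e-4' eval()s to a float, while 'mlp'
# raises NameError inside eval() and is therefore kept as a string.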

def main():
    # configure logger, disable logging in child MPI processes (with rank > 0)
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args()
    extra_args = parse_cmdline_kwargs(unknown_args)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        logger.configure()
    else:
        logger.configure(format_strs=[])
        rank = MPI.COMM_WORLD.Get_rank()

    model, env = train(args, extra_args)
    env.close()

    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        env = build_env(args)
        obs = env.reset()

        # Dummy recurrent state and done mask so that model.step works for
        # both recurrent and non-recurrent policies.
        def initialize_placeholders(nlstm=128, **kwargs):
            return np.zeros((args.num_env or 1, 2 * nlstm)), np.zeros((1,))

        state, dones = initialize_placeholders(**extra_args)
        while True:
            actions, _, state, _ = model.step(obs, S=state, M=dones)
            obs, _, done, _ = env.step(actions)
            env.render()
            done = done.any() if isinstance(done, np.ndarray) else done
            if done:
                obs = env.reset()

        # Only reached if the loop above exits via an exception (e.g. KeyboardInterrupt).
        env.close()

if __name__ == '__main__':
    main()
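
# Example invocation (flags as defined by common_arg_parser; assumes ppo2 is
# registered):
#   python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=1e7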