git subrepo pull (merge) baselines
subrepo:
  subdir:   "baselines"
  merged:   "39f8be8f"
upstream:
  origin:   "git@github.com:openai/baselines.git"
  branch:   "master"
  commit:   "0a40206c"
git-subrepo:
  version:  "0.4.0"
  origin:   "git@github.com:ingydotnet/git-subrepo.git"
  commit:   "74339e8"
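The block above is the footer that git-subrepo records when it pulls upstream changes into a subdirectory. As a rough sketch of the workflow that produces such a commit (assuming the baselines subdirectory was previously attached with "git subrepo clone"; exact flags can vary between git-subrepo versions):

    # Fetch upstream openai/baselines and merge it into the baselines/ subdir;
    # git-subrepo squashes the result into one commit carrying the footer above.
    git subrepo pull baselines

    # Inspect the tracked upstream remote, branch, and last pulled commit.
    git subrepo status baselines

The unified diff below shows the upstream changes that were merged into the subdirectory.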
@@ -1,5 +1,5 @@
 import sys
-import multiprocessing
+import multiprocessing
 import os.path as osp
 import gym
 from collections import defaultdict
@@ -30,7 +30,7 @@ for env in gym.envs.registry.all():
 # reading benchmark names directly from retro requires
 # importing retro here, and for some reason that crashes tensorflow
 # in ubuntu
-_game_envs['retro'] = set([
+_game_envs['retro'] = {
     'BubbleBobble-Nes',
     'SuperMarioBros-Nes',
     'TwinBee3PokoPokoDaimaou-Nes',
@@ -39,7 +39,7 @@ _game_envs['retro'] = set([
     'Vectorman-Genesis',
     'FinalFight-Snes',
     'SpaceInvaders-Snes',
-])
+}


 def train(args, extra_args):
@@ -60,8 +60,6 @@ def train(args, extra_args):
     if alg_kwargs.get('network') is None:
         alg_kwargs['network'] = get_default_network(env_type)

-
-
     print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))

     model = learn(
@@ -76,7 +74,7 @@ def train(args, extra_args):

 def build_env(args):
     ncpu = multiprocessing.cpu_count()
-    if sys.platform == 'darwin': ncpu /= 2
+    if sys.platform == 'darwin': ncpu //= 2
     nenv = args.num_env or ncpu
     alg = args.alg
     rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
@@ -91,7 +89,7 @@ def build_env(args):
         if args.num_env:
             env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
         else:
-            env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
+            env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)

         env = VecNormalize(env)

@@ -117,7 +115,8 @@ def build_env(args):
     elif env_type == 'retro':
         import retro
         gamestate = args.gamestate or 'Level1-1'
-        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
+        env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
+                                        use_restricted_actions=retro.Actions.DISCRETE)
         env.seed(args.seed)
         env = bench.Monitor(env, logger.get_dir())
         env = retro_wrappers.wrap_deepmind_retro(env)
@@ -140,7 +139,7 @@ def build_env(args):
 def get_env_type(env_id):
     if env_id in _game_envs.keys():
         env_type = env_id
-        env_id = [g for g in _game_envs[env_type]][0]
+        env_id = [g for g in _game_envs[env_type]][0]
     else:
         env_type = None
         for g, e in _game_envs.items():
@@ -151,6 +150,7 @@ def get_env_type(env_id):

     return env_type, env_id

+
 def get_default_network(env_type):
     if env_type == 'mujoco' or env_type == 'classic_control':
         return 'mlp'
@@ -159,6 +159,7 @@ def get_default_network(env_type):

     raise ValueError('Unknown env_type {}'.format(env_type))

+
 def get_alg_module(alg, submodule=None):
     submodule = submodule or alg
     try:
@@ -174,6 +175,7 @@ def get_alg_module(alg, submodule=None):
 def get_learn_function(alg):
     return get_alg_module(alg).learn

+
 def get_learn_function_defaults(alg, env_type):
     try:
         alg_defaults = get_alg_module(alg, 'defaults')
@@ -182,6 +184,7 @@ def get_learn_function_defaults(alg, env_type):
         kwargs = {}
     return kwargs

+
 def parse(v):
     '''
     convert value of a command-line arg to a python object if possible, othewise, keep as string
@@ -199,14 +202,13 @@ def main():

     arg_parser = common_arg_parser()
     args, unknown_args = arg_parser.parse_known_args()
-    extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
-
+    extra_args = {k: parse(v) for k, v in parse_unknown_args(unknown_args).items()}

     if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
         rank = 0
         logger.configure()
     else:
-        logger.configure(format_strs = [])
+        logger.configure(format_strs=[])
         rank = MPI.COMM_WORLD.Get_rank()

     model, _ = train(args, extra_args)
@@ -215,14 +217,13 @@ def main():
         save_path = osp.expanduser(args.save_path)
         model.save(save_path)

-
     if args.play:
         logger.log("Running trained model")
         env = build_env(args)
         obs = env.reset()
         while True:
             actions = model.step(obs)[0]
-            obs, _, done, _ = env.step(actions)
+            obs, _, done, _ = env.step(actions)
             env.render()
             done = done.any() if isinstance(done, np.ndarray) else done

@@ -230,6 +231,5 @@ def main():
                 obs = env.reset()


-
 if __name__ == '__main__':
     main()