Fix alien syntax and apply PEP 8 style (#554)
This commit is contained in:
committed by
pzhokhov
parent
b29c8020d7
commit
1937826784
@@ -1,5 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import gym
|
import gym
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@@ -30,7 +30,7 @@ for env in gym.envs.registry.all():
|
|||||||
# reading benchmark names directly from retro requires
|
# reading benchmark names directly from retro requires
|
||||||
# importing retro here, and for some reason that crashes tensorflow
|
# importing retro here, and for some reason that crashes tensorflow
|
||||||
# in ubuntu
|
# in ubuntu
|
||||||
_game_envs['retro'] = set([
|
_game_envs['retro'] = {
|
||||||
'BubbleBobble-Nes',
|
'BubbleBobble-Nes',
|
||||||
'SuperMarioBros-Nes',
|
'SuperMarioBros-Nes',
|
||||||
'TwinBee3PokoPokoDaimaou-Nes',
|
'TwinBee3PokoPokoDaimaou-Nes',
|
||||||
@@ -39,7 +39,7 @@ _game_envs['retro'] = set([
|
|||||||
'Vectorman-Genesis',
|
'Vectorman-Genesis',
|
||||||
'FinalFight-Snes',
|
'FinalFight-Snes',
|
||||||
'SpaceInvaders-Snes',
|
'SpaceInvaders-Snes',
|
||||||
])
|
}
|
||||||
|
|
||||||
|
|
||||||
def train(args, extra_args):
|
def train(args, extra_args):
|
||||||
@@ -60,8 +60,6 @@ def train(args, extra_args):
|
|||||||
if alg_kwargs.get('network') is None:
|
if alg_kwargs.get('network') is None:
|
||||||
alg_kwargs['network'] = get_default_network(env_type)
|
alg_kwargs['network'] = get_default_network(env_type)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
|
print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
|
||||||
|
|
||||||
model = learn(
|
model = learn(
|
||||||
@@ -91,7 +89,7 @@ def build_env(args):
|
|||||||
if args.num_env:
|
if args.num_env:
|
||||||
env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
|
env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
|
||||||
else:
|
else:
|
||||||
env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
|
env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
|
||||||
|
|
||||||
env = VecNormalize(env)
|
env = VecNormalize(env)
|
||||||
|
|
||||||
@@ -117,7 +115,8 @@ def build_env(args):
|
|||||||
elif env_type == 'retro':
|
elif env_type == 'retro':
|
||||||
import retro
|
import retro
|
||||||
gamestate = args.gamestate or 'Level1-1'
|
gamestate = args.gamestate or 'Level1-1'
|
||||||
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
|
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
|
||||||
|
use_restricted_actions=retro.Actions.DISCRETE)
|
||||||
env.seed(args.seed)
|
env.seed(args.seed)
|
||||||
env = bench.Monitor(env, logger.get_dir())
|
env = bench.Monitor(env, logger.get_dir())
|
||||||
env = retro_wrappers.wrap_deepmind_retro(env)
|
env = retro_wrappers.wrap_deepmind_retro(env)
|
||||||
@@ -140,7 +139,7 @@ def build_env(args):
|
|||||||
def get_env_type(env_id):
|
def get_env_type(env_id):
|
||||||
if env_id in _game_envs.keys():
|
if env_id in _game_envs.keys():
|
||||||
env_type = env_id
|
env_type = env_id
|
||||||
env_id = [g for g in _game_envs[env_type]][0]
|
env_id = [g for g in _game_envs[env_type]][0]
|
||||||
else:
|
else:
|
||||||
env_type = None
|
env_type = None
|
||||||
for g, e in _game_envs.items():
|
for g, e in _game_envs.items():
|
||||||
@@ -151,6 +150,7 @@ def get_env_type(env_id):
|
|||||||
|
|
||||||
return env_type, env_id
|
return env_type, env_id
|
||||||
|
|
||||||
|
|
||||||
def get_default_network(env_type):
|
def get_default_network(env_type):
|
||||||
if env_type == 'mujoco' or env_type == 'classic_control':
|
if env_type == 'mujoco' or env_type == 'classic_control':
|
||||||
return 'mlp'
|
return 'mlp'
|
||||||
@@ -159,6 +159,7 @@ def get_default_network(env_type):
|
|||||||
|
|
||||||
raise ValueError('Unknown env_type {}'.format(env_type))
|
raise ValueError('Unknown env_type {}'.format(env_type))
|
||||||
|
|
||||||
|
|
||||||
def get_alg_module(alg, submodule=None):
|
def get_alg_module(alg, submodule=None):
|
||||||
submodule = submodule or alg
|
submodule = submodule or alg
|
||||||
try:
|
try:
|
||||||
@@ -174,6 +175,7 @@ def get_alg_module(alg, submodule=None):
|
|||||||
def get_learn_function(alg):
|
def get_learn_function(alg):
|
||||||
return get_alg_module(alg).learn
|
return get_alg_module(alg).learn
|
||||||
|
|
||||||
|
|
||||||
def get_learn_function_defaults(alg, env_type):
|
def get_learn_function_defaults(alg, env_type):
|
||||||
try:
|
try:
|
||||||
alg_defaults = get_alg_module(alg, 'defaults')
|
alg_defaults = get_alg_module(alg, 'defaults')
|
||||||
@@ -182,6 +184,7 @@ def get_learn_function_defaults(alg, env_type):
|
|||||||
kwargs = {}
|
kwargs = {}
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
def parse(v):
|
def parse(v):
|
||||||
'''
|
'''
|
||||||
convert value of a command-line arg to a python object if possible, othewise, keep as string
|
convert value of a command-line arg to a python object if possible, othewise, keep as string
|
||||||
@@ -199,14 +202,13 @@ def main():
|
|||||||
|
|
||||||
arg_parser = common_arg_parser()
|
arg_parser = common_arg_parser()
|
||||||
args, unknown_args = arg_parser.parse_known_args()
|
args, unknown_args = arg_parser.parse_known_args()
|
||||||
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
|
extra_args = {k: parse(v) for k, v in parse_unknown_args(unknown_args).items()}
|
||||||
|
|
||||||
|
|
||||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||||
rank = 0
|
rank = 0
|
||||||
logger.configure()
|
logger.configure()
|
||||||
else:
|
else:
|
||||||
logger.configure(format_strs = [])
|
logger.configure(format_strs=[])
|
||||||
rank = MPI.COMM_WORLD.Get_rank()
|
rank = MPI.COMM_WORLD.Get_rank()
|
||||||
|
|
||||||
model, _ = train(args, extra_args)
|
model, _ = train(args, extra_args)
|
||||||
@@ -215,14 +217,13 @@ def main():
|
|||||||
save_path = osp.expanduser(args.save_path)
|
save_path = osp.expanduser(args.save_path)
|
||||||
model.save(save_path)
|
model.save(save_path)
|
||||||
|
|
||||||
|
|
||||||
if args.play:
|
if args.play:
|
||||||
logger.log("Running trained model")
|
logger.log("Running trained model")
|
||||||
env = build_env(args)
|
env = build_env(args)
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
while True:
|
while True:
|
||||||
actions = model.step(obs)[0]
|
actions = model.step(obs)[0]
|
||||||
obs, _, done, _ = env.step(actions)
|
obs, _, done, _ = env.step(actions)
|
||||||
env.render()
|
env.render()
|
||||||
done = done.any() if isinstance(done, np.ndarray) else done
|
done = done.any() if isinstance(done, np.ndarray) else done
|
||||||
|
|
||||||
@@ -230,6 +231,5 @@ def main():
|
|||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Reference in New Issue
Block a user