Fix alien syntax and apply PEP 8 style (#554)

This commit is contained in:
Alfredo Canziani
2018-08-30 20:21:25 -04:00
committed by pzhokhov
parent b29c8020d7
commit 1937826784

View File

@@ -30,7 +30,7 @@ for env in gym.envs.registry.all():
# reading benchmark names directly from retro requires
# importing retro here, and for some reason that crashes tensorflow
# in ubuntu
_game_envs['retro'] = set([
_game_envs['retro'] = {
'BubbleBobble-Nes',
'SuperMarioBros-Nes',
'TwinBee3PokoPokoDaimaou-Nes',
@@ -39,7 +39,7 @@ _game_envs['retro'] = set([
'Vectorman-Genesis',
'FinalFight-Snes',
'SpaceInvaders-Snes',
])
}
def train(args, extra_args):
@@ -60,8 +60,6 @@ def train(args, extra_args):
if alg_kwargs.get('network') is None:
alg_kwargs['network'] = get_default_network(env_type)
print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
model = learn(
@@ -117,7 +115,8 @@ def build_env(args):
elif env_type == 'retro':
import retro
gamestate = args.gamestate or 'Level1-1'
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
use_restricted_actions=retro.Actions.DISCRETE)
env.seed(args.seed)
env = bench.Monitor(env, logger.get_dir())
env = retro_wrappers.wrap_deepmind_retro(env)
@@ -151,6 +150,7 @@ def get_env_type(env_id):
return env_type, env_id
def get_default_network(env_type):
if env_type == 'mujoco' or env_type == 'classic_control':
return 'mlp'
@@ -159,6 +159,7 @@ def get_default_network(env_type):
raise ValueError('Unknown env_type {}'.format(env_type))
def get_alg_module(alg, submodule=None):
submodule = submodule or alg
try:
@@ -174,6 +175,7 @@ def get_alg_module(alg, submodule=None):
def get_learn_function(alg):
return get_alg_module(alg).learn
def get_learn_function_defaults(alg, env_type):
try:
alg_defaults = get_alg_module(alg, 'defaults')
@@ -182,6 +184,7 @@ def get_learn_function_defaults(alg, env_type):
kwargs = {}
return kwargs
def parse(v):
'''
convert value of a command-line arg to a python object if possible, otherwise, keep as string
@@ -199,14 +202,13 @@ def main():
arg_parser = common_arg_parser()
args, unknown_args = arg_parser.parse_known_args()
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
extra_args = {k: parse(v) for k, v in parse_unknown_args(unknown_args).items()}
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
rank = 0
logger.configure()
else:
logger.configure(format_strs = [])
logger.configure(format_strs=[])
rank = MPI.COMM_WORLD.Get_rank()
model, _ = train(args, extra_args)
@@ -215,7 +217,6 @@ def main():
save_path = osp.expanduser(args.save_path)
model.save(save_path)
if args.play:
logger.log("Running trained model")
env = build_env(args)
@@ -230,6 +231,5 @@ def main():
obs = env.reset()
if __name__ == '__main__':
main()