Fix alien syntax and apply PEP 8 style (#554)
This commit is contained in:
committed by
pzhokhov
parent
b29c8020d7
commit
1937826784
@@ -30,7 +30,7 @@ for env in gym.envs.registry.all():
|
||||
# reading benchmark names directly from retro requires
|
||||
# importing retro here, and for some reason that crashes tensorflow
|
||||
# in ubuntu
|
||||
_game_envs['retro'] = set([
|
||||
_game_envs['retro'] = {
|
||||
'BubbleBobble-Nes',
|
||||
'SuperMarioBros-Nes',
|
||||
'TwinBee3PokoPokoDaimaou-Nes',
|
||||
@@ -39,7 +39,7 @@ _game_envs['retro'] = set([
|
||||
'Vectorman-Genesis',
|
||||
'FinalFight-Snes',
|
||||
'SpaceInvaders-Snes',
|
||||
])
|
||||
}
|
||||
|
||||
|
||||
def train(args, extra_args):
|
||||
@@ -60,8 +60,6 @@ def train(args, extra_args):
|
||||
if alg_kwargs.get('network') is None:
|
||||
alg_kwargs['network'] = get_default_network(env_type)
|
||||
|
||||
|
||||
|
||||
print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
|
||||
|
||||
model = learn(
|
||||
@@ -117,7 +115,8 @@ def build_env(args):
|
||||
elif env_type == 'retro':
|
||||
import retro
|
||||
gamestate = args.gamestate or 'Level1-1'
|
||||
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
|
||||
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
|
||||
use_restricted_actions=retro.Actions.DISCRETE)
|
||||
env.seed(args.seed)
|
||||
env = bench.Monitor(env, logger.get_dir())
|
||||
env = retro_wrappers.wrap_deepmind_retro(env)
|
||||
@@ -151,6 +150,7 @@ def get_env_type(env_id):
|
||||
|
||||
return env_type, env_id
|
||||
|
||||
|
||||
def get_default_network(env_type):
|
||||
if env_type == 'mujoco' or env_type == 'classic_control':
|
||||
return 'mlp'
|
||||
@@ -159,6 +159,7 @@ def get_default_network(env_type):
|
||||
|
||||
raise ValueError('Unknown env_type {}'.format(env_type))
|
||||
|
||||
|
||||
def get_alg_module(alg, submodule=None):
|
||||
submodule = submodule or alg
|
||||
try:
|
||||
@@ -174,6 +175,7 @@ def get_alg_module(alg, submodule=None):
|
||||
def get_learn_function(alg):
|
||||
return get_alg_module(alg).learn
|
||||
|
||||
|
||||
def get_learn_function_defaults(alg, env_type):
|
||||
try:
|
||||
alg_defaults = get_alg_module(alg, 'defaults')
|
||||
@@ -182,6 +184,7 @@ def get_learn_function_defaults(alg, env_type):
|
||||
kwargs = {}
|
||||
return kwargs
|
||||
|
||||
|
||||
def parse(v):
|
||||
'''
|
||||
convert value of a command-line arg to a python object if possible, othewise, keep as string
|
||||
@@ -201,7 +204,6 @@ def main():
|
||||
args, unknown_args = arg_parser.parse_known_args()
|
||||
extra_args = {k: parse(v) for k, v in parse_unknown_args(unknown_args).items()}
|
||||
|
||||
|
||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||
rank = 0
|
||||
logger.configure()
|
||||
@@ -215,7 +217,6 @@ def main():
|
||||
save_path = osp.expanduser(args.save_path)
|
||||
model.save(save_path)
|
||||
|
||||
|
||||
if args.play:
|
||||
logger.log("Running trained model")
|
||||
env = build_env(args)
|
||||
@@ -230,6 +231,5 @@ def main():
|
||||
obs = env.reset()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Reference in New Issue
Block a user