wrap retro envs correctly for other (non-deepq) algorithms (#669)
* wrap retro envs correctly for other (non-deepq) algorithms * flake and csh comments * flake and csh comments
This commit is contained in:
@@ -7,13 +7,12 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
|
||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
|
||||
from baselines.common.tf_util import get_session
|
||||
from baselines import bench, logger
|
||||
from baselines import logger
|
||||
from importlib import import_module
|
||||
|
||||
from baselines.common.vec_env.vec_normalize import VecNormalize
|
||||
from baselines.common import atari_wrappers, retro_wrappers
|
||||
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
@@ -87,38 +86,21 @@ def build_env(args):
|
||||
if sys.platform == 'darwin': ncpu //= 2
|
||||
nenv = args.num_env or ncpu
|
||||
alg = args.alg
|
||||
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
seed = args.seed
|
||||
|
||||
env_type, env_id = get_env_type(args.env)
|
||||
|
||||
if env_type == 'atari':
|
||||
if env_type in {'atari', 'retro'}:
|
||||
if alg == 'acer':
|
||||
env = make_vec_env(env_id, env_type, nenv, seed)
|
||||
elif alg == 'deepq':
|
||||
env = atari_wrappers.make_atari(env_id)
|
||||
env.seed(seed)
|
||||
env = bench.Monitor(env, logger.get_dir())
|
||||
env = atari_wrappers.wrap_deepmind(env, frame_stack=True)
|
||||
env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
|
||||
elif alg == 'trpo_mpi':
|
||||
env = atari_wrappers.make_atari(env_id)
|
||||
env.seed(seed)
|
||||
env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
|
||||
env = atari_wrappers.wrap_deepmind(env)
|
||||
# TODO check if the second seeding is necessary, and eventually remove
|
||||
env.seed(seed)
|
||||
env = make_env(env_id, env_type, seed=seed)
|
||||
else:
|
||||
frame_stack_size = 4
|
||||
env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)
|
||||
|
||||
elif env_type == 'retro':
|
||||
import retro
|
||||
gamestate = args.gamestate or retro.State.DEFAULT
|
||||
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
|
||||
use_restricted_actions=retro.Actions.DISCRETE)
|
||||
env.seed(args.seed)
|
||||
env = bench.Monitor(env, logger.get_dir())
|
||||
env = retro_wrappers.wrap_deepmind_retro(env)
|
||||
env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
|
||||
env = VecFrameStack(env, frame_stack_size)
|
||||
|
||||
else:
|
||||
config = tf.ConfigProto(allow_soft_placement=True,
|
||||
|
Reference in New Issue
Block a user