wrap retro envs correctly for other (non-deepq) algorithms (#669)
* wrap retro envs correctly for other (non-deepq) algorithms * flake and csh comments * flake and csh comments
This commit is contained in:
@@ -16,30 +16,56 @@ from baselines.common import set_global_seeds
|
|||||||
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
||||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
from baselines.common.retro_wrappers import RewardScaler
|
from baselines.common import retro_wrappers
|
||||||
|
|
||||||
|
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
|
||||||
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0):
|
|
||||||
"""
|
"""
|
||||||
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
||||||
"""
|
"""
|
||||||
if wrapper_kwargs is None: wrapper_kwargs = {}
|
if wrapper_kwargs is None: wrapper_kwargs = {}
|
||||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||||
def make_env(rank): # pylint: disable=C0111
|
seed = seed + 10000 * mpi_rank if seed is not None else None
|
||||||
def _thunk():
|
def make_thunk(rank):
|
||||||
env = make_atari(env_id) if env_type == 'atari' else gym.make(env_id)
|
return lambda: make_env(
|
||||||
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
|
env_id=env_id,
|
||||||
env = Monitor(env,
|
env_type=env_type,
|
||||||
logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)),
|
subrank = rank,
|
||||||
allow_early_resets=True)
|
seed=seed,
|
||||||
|
reward_scale=reward_scale,
|
||||||
|
gamestate=gamestate
|
||||||
|
)
|
||||||
|
|
||||||
if env_type == 'atari': return wrap_deepmind(env, **wrapper_kwargs)
|
|
||||||
elif reward_scale != 1: return RewardScaler(env, reward_scale)
|
|
||||||
else: return env
|
|
||||||
return _thunk
|
|
||||||
set_global_seeds(seed)
|
set_global_seeds(seed)
|
||||||
if num_env > 1: return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
|
if num_env > 1:
|
||||||
else: return DummyVecEnv([make_env(start_index)])
|
return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
|
||||||
|
else:
|
||||||
|
return DummyVecEnv([make_thunk(start_index)])
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
|
||||||
|
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||||
|
if env_type == 'atari':
|
||||||
|
env = make_atari(env_id)
|
||||||
|
elif env_type == 'retro':
|
||||||
|
import retro
|
||||||
|
gamestate = gamestate or retro.State.DEFAULT
|
||||||
|
env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
|
||||||
|
else:
|
||||||
|
env = gym.make(env_id)
|
||||||
|
|
||||||
|
env.seed(seed + subrank if seed is not None else None)
|
||||||
|
env = Monitor(env,
|
||||||
|
logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
|
||||||
|
allow_early_resets=True)
|
||||||
|
|
||||||
|
if env_type == 'atari':
|
||||||
|
return wrap_deepmind(env, **wrapper_kwargs)
|
||||||
|
elif reward_scale != 1:
|
||||||
|
return retro_wrappers.RewardScaler(env, reward_scale)
|
||||||
|
else:
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
||||||
"""
|
"""
|
||||||
|
@@ -7,13 +7,12 @@ import tensorflow as tf
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
|
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
|
||||||
from baselines.common.tf_util import get_session
|
from baselines.common.tf_util import get_session
|
||||||
from baselines import bench, logger
|
from baselines import logger
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
from baselines.common.vec_env.vec_normalize import VecNormalize
|
from baselines.common.vec_env.vec_normalize import VecNormalize
|
||||||
from baselines.common import atari_wrappers, retro_wrappers
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
@@ -87,38 +86,21 @@ def build_env(args):
|
|||||||
if sys.platform == 'darwin': ncpu //= 2
|
if sys.platform == 'darwin': ncpu //= 2
|
||||||
nenv = args.num_env or ncpu
|
nenv = args.num_env or ncpu
|
||||||
alg = args.alg
|
alg = args.alg
|
||||||
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
|
||||||
seed = args.seed
|
seed = args.seed
|
||||||
|
|
||||||
env_type, env_id = get_env_type(args.env)
|
env_type, env_id = get_env_type(args.env)
|
||||||
|
|
||||||
if env_type == 'atari':
|
if env_type in {'atari', 'retro'}:
|
||||||
if alg == 'acer':
|
if alg == 'acer':
|
||||||
env = make_vec_env(env_id, env_type, nenv, seed)
|
env = make_vec_env(env_id, env_type, nenv, seed)
|
||||||
elif alg == 'deepq':
|
elif alg == 'deepq':
|
||||||
env = atari_wrappers.make_atari(env_id)
|
env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
|
||||||
env.seed(seed)
|
|
||||||
env = bench.Monitor(env, logger.get_dir())
|
|
||||||
env = atari_wrappers.wrap_deepmind(env, frame_stack=True)
|
|
||||||
elif alg == 'trpo_mpi':
|
elif alg == 'trpo_mpi':
|
||||||
env = atari_wrappers.make_atari(env_id)
|
env = make_env(env_id, env_type, seed=seed)
|
||||||
env.seed(seed)
|
|
||||||
env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
|
|
||||||
env = atari_wrappers.wrap_deepmind(env)
|
|
||||||
# TODO check if the second seeding is necessary, and eventually remove
|
|
||||||
env.seed(seed)
|
|
||||||
else:
|
else:
|
||||||
frame_stack_size = 4
|
frame_stack_size = 4
|
||||||
env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)
|
env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
|
||||||
|
env = VecFrameStack(env, frame_stack_size)
|
||||||
elif env_type == 'retro':
|
|
||||||
import retro
|
|
||||||
gamestate = args.gamestate or retro.State.DEFAULT
|
|
||||||
env = retro_wrappers.make_retro(game=args.env, state=gamestate, max_episode_steps=10000,
|
|
||||||
use_restricted_actions=retro.Actions.DISCRETE)
|
|
||||||
env.seed(args.seed)
|
|
||||||
env = bench.Monitor(env, logger.get_dir())
|
|
||||||
env = retro_wrappers.wrap_deepmind_retro(env)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
config = tf.ConfigProto(allow_soft_placement=True,
|
config = tf.ConfigProto(allow_soft_placement=True,
|
||||||
|
Reference in New Issue
Block a user