2018-01-25 18:33:48 -08:00
"""
Helpers for scripts like run_atari.py.
"""
import os
2018-08-13 09:56:44 -07:00
try :
from mpi4py import MPI
except ImportError :
MPI = None
2018-01-25 18:33:48 -08:00
import gym
2018-02-26 17:40:16 +01:00
from gym . wrappers import FlattenDictWrapper
2018-01-25 18:33:48 -08:00
from baselines import logger
from baselines . bench import Monitor
from baselines . common import set_global_seeds
from baselines . common . atari_wrappers import make_atari , wrap_deepmind
from baselines . common . vec_env . subproc_vec_env import SubprocVecEnv
2018-08-29 01:48:56 +01:00
from baselines . common . vec_env . dummy_vec_env import DummyVecEnv
2018-10-22 18:36:39 -07:00
from baselines . common import retro_wrappers
2018-01-25 18:33:48 -08:00
2018-12-19 14:44:08 -08:00
def make_vec_env(env_id, env_type, num_env, seed,
                 wrapper_kwargs=None,
                 start_index=0,
                 reward_scale=1.0,
                 flatten_dict_observations=True,
                 gamestate=None):
    """
    Build a vectorized environment whose members are created by make_env.

    Returns a SubprocVecEnv holding num_env environments when num_env > 1,
    otherwise a DummyVecEnv holding a single environment. Member i is built
    with subrank ``i + start_index``; the remaining arguments are forwarded
    to make_env unchanged.

    When a seed is given it is shifted by ``10000 * mpi_rank`` (rank 0 when
    mpi4py is unavailable) before being handed to set_global_seeds and to
    every member environment.
    """
    if not wrapper_kwargs:
        wrapper_kwargs = {}

    # MPI is None when mpi4py could not be imported; act as a single process.
    mpi_rank = 0
    if MPI:
        mpi_rank = MPI.COMM_WORLD.Get_rank()

    if seed is not None:
        seed = seed + 10000 * mpi_rank

    # Resolve the log directory once, here, and pass it down explicitly.
    logger_dir = logger.get_dir()

    def make_thunk(rank):
        # Deferred constructor: the vec env wrappers expect callables, not envs.
        def _thunk():
            return make_env(
                env_id=env_id,
                env_type=env_type,
                mpi_rank=mpi_rank,
                subrank=rank,
                seed=seed,
                reward_scale=reward_scale,
                gamestate=gamestate,
                flatten_dict_observations=flatten_dict_observations,
                wrapper_kwargs=wrapper_kwargs,
                logger_dir=logger_dir,
            )
        return _thunk

    set_global_seeds(seed)

    if num_env > 1:
        thunks = [make_thunk(i + start_index) for i in range(num_env)]
        return SubprocVecEnv(thunks)
    return DummyVecEnv([make_thunk(start_index)])
2019-01-15 09:59:27 -08:00
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
    """
    Create a single seeded, monitored environment.

    The base environment comes from make_atari for env_type 'atari', from
    retro_wrappers.make_retro for 'retro', and from gym.make otherwise.
    Dict observation spaces are flattened when flatten_dict_observations is
    set. The environment is seeded with ``seed + subrank`` (or None), wrapped
    in a Monitor that writes to ``<logger_dir>/<mpi_rank>.<subrank>`` when a
    logger_dir is given, then wrapped with the atari/retro DeepMind-style
    wrappers (receiving wrapper_kwargs) and, if reward_scale != 1, a
    RewardScaler.
    """
    if not wrapper_kwargs:
        wrapper_kwargs = {}

    # --- base environment -------------------------------------------------
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        # Imported lazily so that retro is only required for retro runs.
        import retro
        if not gamestate:
            gamestate = retro.State.DEFAULT
        env = retro_wrappers.make_retro(
            game=env_id,
            max_episode_steps=10000,
            use_restricted_actions=retro.Actions.DISCRETE,
            state=gamestate,
        )
    else:
        env = gym.make(env_id)

    # --- observation flattening -------------------------------------------
    if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
        dict_keys = list(env.observation_space.spaces.keys())
        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=dict_keys)

    # --- seeding and monitoring -------------------------------------------
    env_seed = None if seed is None else seed + subrank
    env.seed(env_seed)

    # A falsy logger_dir is passed through to Monitor untouched.
    monitor_path = logger_dir
    if logger_dir:
        monitor_path = os.path.join(logger_dir, '{}.{}'.format(mpi_rank, subrank))
    env = Monitor(env, monitor_path, allow_early_resets=True)

    # --- type-specific wrappers -------------------------------------------
    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)

    return env
2018-10-22 18:36:39 -07:00
2018-01-25 18:33:48 -08:00
2018-08-13 09:56:44 -07:00
def make_mujoco_env(env_id, seed, reward_scale=1.0):
    """
    Create a wrapped, monitored gym.Env for MuJoCo.

    :param env_id: id passed to gym.make
    :param seed: base RNG seed, or None to leave seeding to the libraries
    :param reward_scale: when different from 1.0 the environment is wrapped
        in a RewardScaler with this factor
    :return: the Monitor-wrapped (and possibly reward-scaled) environment
    """
    # mpi4py is an optional dependency: MPI is None when its import failed.
    # Fall back to rank 0 in that case, exactly as make_vec_env does, instead
    # of crashing with AttributeError on None.
    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0

    # Global (python / numpy / tf) seeds are decorrelated across MPI workers.
    myseed = seed + 1000 * rank if seed is not None else None
    set_global_seeds(myseed)

    env = gym.make(env_id)

    # One monitor file per rank; no file at all when logging is not configured.
    logger_dir = logger.get_dir()
    logger_path = None if logger_dir is None else os.path.join(logger_dir, str(rank))
    env = Monitor(env, logger_path, allow_early_resets=True)

    # NOTE(review): the environment is seeded with the base `seed`, not the
    # per-rank `myseed`, so every MPI worker's env shares the same seed.
    # Kept as-is to preserve existing run reproducibility - confirm intent.
    env.seed(seed)

    if reward_scale != 1.0:
        from baselines.common.retro_wrappers import RewardScaler
        env = RewardScaler(env, reward_scale)
    return env
2018-02-26 17:40:16 +01:00
def make_robotics_env(env_id, seed, rank=0):
    """
    Create a wrapped, monitored gym.Env for a robotics task.

    The dict observation is flattened down to its 'observation' and
    'desired_goal' entries, and the Monitor additionally records the
    'is_success' info keyword. Monitor output goes to ``<log dir>/<rank>``
    when the logger has a directory configured.
    """
    set_global_seeds(seed)

    base_env = gym.make(env_id)
    flat_env = FlattenDictWrapper(base_env, ['observation', 'desired_goal'])

    # A falsy log dir (logging disabled) is handed to Monitor unchanged.
    log_dir = logger.get_dir()
    monitor_path = log_dir and os.path.join(log_dir, str(rank))

    env = Monitor(flat_env, monitor_path, info_keywords=('is_success',))
    env.seed(seed)
    return env
2018-01-25 18:33:48 -08:00
def arg_parser():
    """
    Return a fresh argparse.ArgumentParser with no arguments registered.

    The parser uses ArgumentDefaultsHelpFormatter, so ``--help`` output
    shows each argument's default value.
    """
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    return ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
def atari_arg_parser():
    """
    Deprecated alias kept for run_atari.py-style scripts.

    Prints an obsolescence notice and hands back the parser built by
    common_arg_parser.
    """
    print('Obsolete - use common_arg_parser instead')
    parser = common_arg_parser()
    return parser
2018-01-25 18:33:48 -08:00
def mujoco_arg_parser():
    """
    Deprecated alias: prints an obsolescence notice and returns the parser
    built by common_arg_parser.
    """
    print('Obsolete - use common_arg_parser instead')
    parser = common_arg_parser()
    return parser
def common_arg_parser():
    """
    Create the argparse.ArgumentParser shared by the run scripts.

    Registers the environment, seeding, algorithm, network, parallelism,
    reward scaling, model saving, video recording, play-mode and
    extra-import options on top of the empty parser from arg_parser.
    """
    parser = arg_parser()
    add = parser.add_argument  # short alias; registration order is preserved

    # Environment / run selection
    add('--env', help='environment ID', type=str, default='Reacher-v2')
    add('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
    add('--seed', help='RNG seed', type=int, default=None)
    add('--alg', help='Algorithm', type=str, default='ppo2')
    # Parsed as float so that values such as 1e6 are accepted on the command line.
    add('--num_timesteps', type=float, default=1e6)
    add('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    add('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    add('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
    add('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)

    # Output: model checkpoints and videos
    add('--save_path', help='Path to save trained model to', default=None, type=str)
    add('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    add('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)

    # Miscellaneous
    add('--play', default=False, action='store_true')
    add('--extra_import', help='Extra module to import to access external environments', type=str, default=None)

    return parser
def robotics_arg_parser():
    """
    Create an argparse.ArgumentParser for the robotics run scripts.

    Registers --env (default 'FetchReach-v0'), --seed and --num-timesteps on
    top of the empty parser from arg_parser.
    """
    parser = arg_parser()
    add = parser.add_argument
    add('--env', help='environment ID', type=str, default='FetchReach-v0')
    add('--seed', help='RNG seed', type=int, default=None)
    add('--num-timesteps', type=int, default=int(1e6))
    return parser
2018-08-13 09:56:44 -07:00
def parse_unknown_args(args):
    """
    Parse arguments not consumed by the arg parser into a dictionary.

    Two spellings are recognised:

    * ``--key=value`` - everything after the FIRST ``=`` is the value, so
      values may themselves contain ``=``.
    * ``--key value`` - the token immediately following the flag is the
      value.

    Tokens that are neither a ``--`` flag nor the value directly following
    a bare flag are ignored. A bare flag that is followed by another flag
    (or by nothing) contributes no entry. Later occurrences of a key
    override earlier ones.

    :param args: iterable of command-line tokens (strings)
    :return: dict mapping key (without the leading ``--``) to its string value
    """
    retval = {}
    # Name of a bare ``--key`` still waiting for its value, or None.
    pending_key = None
    for arg in args:
        if arg.startswith('--'):
            # partition splits on the first '=' only, so '--x=a=b' -> 'a=b'.
            name, sep, value = arg[2:].partition('=')
            if sep:
                retval[name] = value
                # This flag is complete; an earlier bare flag must not
                # capture the next token under this flag's name.
                pending_key = None
            else:
                pending_key = name
        elif pending_key is not None:
            retval[pending_key] = arg
            pending_key = None
    return retval