Fix result_plotters in vectorized mujoco environments (#533)
* I investigated running training in a vectorized, monitored MuJoCo environment and found that the 0.monitor.csv file could not be plotted using the functions in baselines.results_plotter. Moreover, the seed was the same in every parallel environment due to the late-binding behaviour of lambda. This fixes both issues without breaking the function in other files (baselines.acktr.run_mujoco still works). * Unifies make_atari_env and make_mujoco_env. * Redefines make_mujoco_env because run_mujoco in acktr is not compatible with DummyVecEnv and SubprocVecEnv. * Fix if/else. * Update run.py
This commit is contained in:
@@ -15,22 +15,28 @@ from baselines.bench import Monitor
|
|||||||
from baselines.common import set_global_seeds
|
from baselines.common import set_global_seeds
|
||||||
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
||||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||||
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
|
from baselines.common.retro_wrappers import RewardScaler
|
||||||
|
|
||||||
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
|
|
||||||
|
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0):
|
||||||
"""
|
"""
|
||||||
Create a wrapped, monitored SubprocVecEnv for Atari.
|
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
||||||
"""
|
"""
|
||||||
if wrapper_kwargs is None: wrapper_kwargs = {}
|
if wrapper_kwargs is None: wrapper_kwargs = {}
|
||||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||||
def make_env(rank): # pylint: disable=C0111
|
def make_env(rank): # pylint: disable=C0111
|
||||||
def _thunk():
|
def _thunk():
|
||||||
env = make_atari(env_id)
|
env = make_atari(env_id) if env_type == 'atari' else gym.make(env_id)
|
||||||
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
|
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
|
||||||
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)))
|
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)))
|
||||||
return wrap_deepmind(env, **wrapper_kwargs)
|
if env_type == 'atari': return wrap_deepmind(env, **wrapper_kwargs)
|
||||||
|
elif reward_scale != 1: return RewardScaler(env, reward_scale)
|
||||||
|
else: return env
|
||||||
return _thunk
|
return _thunk
|
||||||
set_global_seeds(seed)
|
set_global_seeds(seed)
|
||||||
return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
|
if num_env > 1: return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
|
||||||
|
else: return DummyVecEnv([make_env(start_index)])
|
||||||
|
|
||||||
def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
||||||
"""
|
"""
|
||||||
@@ -43,11 +49,9 @@ def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
|||||||
logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
|
logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
|
||||||
env = Monitor(env, logger_path, allow_early_resets=True)
|
env = Monitor(env, logger_path, allow_early_resets=True)
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
|
|
||||||
if reward_scale != 1.0:
|
if reward_scale != 1.0:
|
||||||
from baselines.common.retro_wrappers import RewardScaler
|
from baselines.common.retro_wrappers import RewardScaler
|
||||||
env = RewardScaler(env, reward_scale)
|
env = RewardScaler(env, reward_scale)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def make_robotics_env(env_id, seed, rank=0):
|
def make_robotics_env(env_id, seed, rank=0):
|
||||||
@@ -122,6 +126,3 @@ def parse_unknown_args(args):
|
|||||||
retval[key] = value
|
retval[key] = value
|
||||||
|
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -7,14 +7,13 @@ import tensorflow as tf
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
|
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
|
||||||
from baselines.common.tf_util import get_session
|
from baselines.common.tf_util import get_session
|
||||||
from baselines import bench, logger
|
from baselines import bench, logger
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
from baselines.common.vec_env.vec_normalize import VecNormalize
|
from baselines.common.vec_env.vec_normalize import VecNormalize
|
||||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
|
||||||
from baselines.common import atari_wrappers, retro_wrappers
|
from baselines.common import atari_wrappers, retro_wrappers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -77,7 +76,7 @@ def train(args, extra_args):
|
|||||||
|
|
||||||
def build_env(args):
|
def build_env(args):
|
||||||
ncpu = multiprocessing.cpu_count()
|
ncpu = multiprocessing.cpu_count()
|
||||||
if sys.platform == 'darwin': ncpu //= 2
|
if sys.platform == 'darwin': ncpu /= 2
|
||||||
nenv = args.num_env or ncpu
|
nenv = args.num_env or ncpu
|
||||||
alg = args.alg
|
alg = args.alg
|
||||||
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||||
@@ -90,15 +89,15 @@ def build_env(args):
|
|||||||
inter_op_parallelism_threads=1))
|
inter_op_parallelism_threads=1))
|
||||||
|
|
||||||
if args.num_env:
|
if args.num_env:
|
||||||
env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
|
env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
|
||||||
else:
|
else:
|
||||||
env = DummyVecEnv([lambda: make_mujoco_env(env_id, seed, args.reward_scale)])
|
env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
|
||||||
|
|
||||||
env = VecNormalize(env)
|
env = VecNormalize(env)
|
||||||
|
|
||||||
elif env_type == 'atari':
|
elif env_type == 'atari':
|
||||||
if alg == 'acer':
|
if alg == 'acer':
|
||||||
env = make_atari_env(env_id, nenv, seed)
|
env = make_vec_env(env_id, env_type, nenv, seed)
|
||||||
elif alg == 'deepq':
|
elif alg == 'deepq':
|
||||||
env = atari_wrappers.make_atari(env_id)
|
env = atari_wrappers.make_atari(env_id)
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
@@ -113,7 +112,7 @@ def build_env(args):
|
|||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
else:
|
else:
|
||||||
frame_stack_size = 4
|
frame_stack_size = 4
|
||||||
env = VecFrameStack(make_atari_env(env_id, nenv, seed), frame_stack_size)
|
env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)
|
||||||
|
|
||||||
elif env_type == 'retro':
|
elif env_type == 'retro':
|
||||||
import retro
|
import retro
|
||||||
|
Reference in New Issue
Block a user