Fix result_plotters in vectorized mujoco environments (#533)
* I investigated a bit about running a training in a vectorized monitored mujoco env and found out that the 0.monitor.csv file could not be plotted using baselines.results_plotter.py functions. Moreover the seed is the same in every parallel environments due to the particular behaviour of lambda. this fixes both issues without breaking the function in other files (baselines.acktr.run_mujoco still works) * unifies make_atari_env and make_mujoco_env * redefine make_mujoco_env because of run_mujoco in acktr not compatible with DummyVecEnv and SubprocVecEnv * fix if else * Update run.py
This commit is contained in:
@@ -15,22 +15,28 @@ from baselines.bench import Monitor
|
||||
from baselines.common import set_global_seeds
|
||||
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.common.retro_wrappers import RewardScaler
|
||||
|
||||
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
|
||||
|
||||
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0):
|
||||
"""
|
||||
Create a wrapped, monitored SubprocVecEnv for Atari.
|
||||
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
||||
"""
|
||||
if wrapper_kwargs is None: wrapper_kwargs = {}
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
def make_env(rank): # pylint: disable=C0111
|
||||
def _thunk():
|
||||
env = make_atari(env_id)
|
||||
env = make_atari(env_id) if env_type == 'atari' else gym.make(env_id)
|
||||
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
|
||||
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)))
|
||||
return wrap_deepmind(env, **wrapper_kwargs)
|
||||
if env_type == 'atari': return wrap_deepmind(env, **wrapper_kwargs)
|
||||
elif reward_scale != 1: return RewardScaler(env, reward_scale)
|
||||
else: return env
|
||||
return _thunk
|
||||
set_global_seeds(seed)
|
||||
return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
|
||||
if num_env > 1: return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
|
||||
else: return DummyVecEnv([make_env(start_index)])
|
||||
|
||||
def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
||||
"""
|
||||
@@ -43,11 +49,9 @@ def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
||||
logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
|
||||
env = Monitor(env, logger_path, allow_early_resets=True)
|
||||
env.seed(seed)
|
||||
|
||||
if reward_scale != 1.0:
|
||||
from baselines.common.retro_wrappers import RewardScaler
|
||||
env = RewardScaler(env, reward_scale)
|
||||
|
||||
return env
|
||||
|
||||
def make_robotics_env(env_id, seed, rank=0):
|
||||
@@ -122,6 +126,3 @@ def parse_unknown_args(args):
|
||||
retval[key] = value
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
|
||||
|
@@ -7,14 +7,13 @@ import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
|
||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
|
||||
from baselines.common.tf_util import get_session
|
||||
from baselines import bench, logger
|
||||
from importlib import import_module
|
||||
|
||||
from baselines.common.vec_env.vec_normalize import VecNormalize
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from baselines.common import atari_wrappers, retro_wrappers
|
||||
|
||||
try:
|
||||
@@ -77,7 +76,7 @@ def train(args, extra_args):
|
||||
|
||||
def build_env(args):
|
||||
ncpu = multiprocessing.cpu_count()
|
||||
if sys.platform == 'darwin': ncpu //= 2
|
||||
if sys.platform == 'darwin': ncpu /= 2
|
||||
nenv = args.num_env or ncpu
|
||||
alg = args.alg
|
||||
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
@@ -90,15 +89,15 @@ def build_env(args):
|
||||
inter_op_parallelism_threads=1))
|
||||
|
||||
if args.num_env:
|
||||
env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
|
||||
env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
|
||||
else:
|
||||
env = DummyVecEnv([lambda: make_mujoco_env(env_id, seed, args.reward_scale)])
|
||||
env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
|
||||
|
||||
env = VecNormalize(env)
|
||||
|
||||
elif env_type == 'atari':
|
||||
if alg == 'acer':
|
||||
env = make_atari_env(env_id, nenv, seed)
|
||||
env = make_vec_env(env_id, env_type, nenv, seed)
|
||||
elif alg == 'deepq':
|
||||
env = atari_wrappers.make_atari(env_id)
|
||||
env.seed(seed)
|
||||
@@ -113,7 +112,7 @@ def build_env(args):
|
||||
env.seed(seed)
|
||||
else:
|
||||
frame_stack_size = 4
|
||||
env = VecFrameStack(make_atari_env(env_id, nenv, seed), frame_stack_size)
|
||||
env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)
|
||||
|
||||
elif env_type == 'retro':
|
||||
import retro
|
||||
|
Reference in New Issue
Block a user