fixes to readme and baselines/run.py (#80)
* fixes to readme and baselines/run.py * polish installation section of baselines README * polish installation section of baselines README
This commit is contained in:
@@ -13,7 +13,6 @@ from baselines import bench, logger
|
||||
from importlib import import_module
|
||||
|
||||
from baselines.common.vec_env.vec_normalize import VecNormalize
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.common import atari_wrappers, retro_wrappers
|
||||
|
||||
try:
|
||||
@@ -92,19 +91,8 @@ def build_env(args):
|
||||
seed = args.seed
|
||||
|
||||
env_type, env_id = get_env_type(args.env)
|
||||
if env_type == 'mujoco':
|
||||
get_session(tf.ConfigProto(allow_soft_placement=True,
|
||||
intra_op_parallelism_threads=1,
|
||||
inter_op_parallelism_threads=1))
|
||||
|
||||
if args.num_env:
|
||||
env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
|
||||
else:
|
||||
env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
|
||||
|
||||
env = VecNormalize(env)
|
||||
|
||||
elif env_type == 'atari':
|
||||
if env_type == 'atari':
|
||||
if alg == 'acer':
|
||||
env = make_vec_env(env_id, env_type, nenv, seed)
|
||||
elif alg == 'deepq':
|
||||
@@ -132,17 +120,15 @@ def build_env(args):
|
||||
env = bench.Monitor(env, logger.get_dir())
|
||||
env = retro_wrappers.wrap_deepmind_retro(env)
|
||||
|
||||
elif env_type == 'classic_control':
|
||||
def make_env():
|
||||
e = gym.make(env_id)
|
||||
e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
|
||||
e.seed(seed)
|
||||
return e
|
||||
else:
|
||||
get_session(tf.ConfigProto(allow_soft_placement=True,
|
||||
intra_op_parallelism_threads=1,
|
||||
inter_op_parallelism_threads=1))
|
||||
|
||||
env = DummyVecEnv([make_env])
|
||||
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
|
||||
|
||||
else:
|
||||
raise ValueError('Unknown env_type {}'.format(env_type))
|
||||
if env_type == 'mujoco':
|
||||
env = VecNormalize(env)
|
||||
|
||||
return env
|
||||
|
||||
@@ -163,10 +149,10 @@ def get_env_type(env_id):
|
||||
|
||||
|
||||
def get_default_network(env_type):
|
||||
if env_type == 'mujoco' or env_type == 'classic_control':
|
||||
return 'mlp'
|
||||
if env_type == 'atari':
|
||||
return 'cnn'
|
||||
else:
|
||||
return 'mlp'
|
||||
|
||||
raise ValueError('Unknown env_type {}'.format(env_type))
|
||||
|
||||
|
Reference in New Issue
Block a user