Merge branch 'internal' of github.com:openai/baselines into internal
This commit is contained in:
@@ -92,7 +92,7 @@ class Model(object):
|
||||
self.initial_state = step_model.initial_state
|
||||
tf.global_variables_initializer().run(session=sess)
|
||||
|
||||
def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20,
|
||||
def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=100, nprocs=32, nsteps=20,
|
||||
ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
|
||||
kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):
|
||||
set_global_seeds(seed)
|
||||
|
@@ -11,7 +11,7 @@ KFAC_DEBUG = False
|
||||
|
||||
|
||||
class KfacOptimizer():
|
||||
|
||||
# note that KfacOptimizer will be truly synchronous (and thus deterministic) only if a single-threaded session is used
|
||||
def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01, kfac_update=2, stats_accum_iter=60, full_stats_init=False, cold_iter=100, cold_lr=None, is_async=False, async_stats=False, epsilon=1e-2, stats_decay=0.95, blockdiag_bias=False, channel_fac=False, factored_damping=False, approxT2=False, use_float64=False, weight_decay_dict={},max_grad_norm=0.5):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
self._lr = learning_rate
|
||||
|
@@ -130,27 +130,60 @@ class ClipRewardEnv(gym.RewardWrapper):
|
||||
"""Bin reward to {+1, 0, -1} by its sign."""
|
||||
return np.sign(reward)
|
||||
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
def __init__(self, env, width=84, height=84, grayscale=True):
|
||||
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.grayscale = grayscale
|
||||
if self.grayscale:
|
||||
self.observation_space = spaces.Box(low=0, high=255,
|
||||
shape=(self.height, self.width, 1), dtype=np.uint8)
|
||||
else:
|
||||
self.observation_space = spaces.Box(low=0, high=255,
|
||||
shape=(self.height, self.width, 3), dtype=np.uint8)
|
||||
|
||||
def observation(self, frame):
|
||||
if self.grayscale:
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
|
||||
"""
|
||||
Warp frames to 84x84 as done in the Nature paper and later work.
|
||||
|
||||
If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
|
||||
observation should be warped.
|
||||
"""
|
||||
super().__init__(env)
|
||||
self._width = width
|
||||
self._height = height
|
||||
self._grayscale = grayscale
|
||||
self._key = dict_space_key
|
||||
if self._grayscale:
|
||||
num_colors = 1
|
||||
else:
|
||||
num_colors = 3
|
||||
|
||||
new_space = gym.spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self._height, self._width, num_colors),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
if self._key is None:
|
||||
original_space = self.observation_space
|
||||
self.observation_space = new_space
|
||||
else:
|
||||
original_space = self.observation_space.spaces[self._key]
|
||||
self.observation_space.spaces[self._key] = new_space
|
||||
assert original_space.dtype == np.uint8 and len(original_space.shape) == 3
|
||||
|
||||
def observation(self, obs):
|
||||
if self._key is None:
|
||||
frame = obs
|
||||
else:
|
||||
frame = obs[self._key]
|
||||
|
||||
if self._grayscale:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
||||
if self.grayscale:
|
||||
frame = cv2.resize(
|
||||
frame, (self._width, self._height), interpolation=cv2.INTER_AREA
|
||||
)
|
||||
if self._grayscale:
|
||||
frame = np.expand_dims(frame, -1)
|
||||
return frame
|
||||
|
||||
if self._key is None:
|
||||
obs = frame
|
||||
else:
|
||||
obs = obs.copy()
|
||||
obs[self._key] = frame
|
||||
return obs
|
||||
|
||||
|
||||
class FrameStack(gym.Wrapper):
|
||||
def __init__(self, env, k):
|
||||
|
@@ -17,9 +17,11 @@ from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.common import retro_wrappers
|
||||
from baselines.common.wrappers import ClipActionsWrapper
|
||||
|
||||
def make_vec_env(env_id, env_type, num_env, seed,
|
||||
wrapper_kwargs=None,
|
||||
env_kwargs=None,
|
||||
start_index=0,
|
||||
reward_scale=1.0,
|
||||
flatten_dict_observations=True,
|
||||
@@ -28,6 +30,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
|
||||
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
||||
"""
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
env_kwargs = env_kwargs or {}
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
seed = seed + 10000 * mpi_rank if seed is not None else None
|
||||
logger_dir = logger.get_dir()
|
||||
@@ -42,6 +45,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
|
||||
gamestate=gamestate,
|
||||
flatten_dict_observations=flatten_dict_observations,
|
||||
wrapper_kwargs=wrapper_kwargs,
|
||||
env_kwargs=env_kwargs,
|
||||
logger_dir=logger_dir
|
||||
)
|
||||
|
||||
@@ -52,8 +56,15 @@ def make_vec_env(env_id, env_type, num_env, seed,
|
||||
return DummyVecEnv([make_thunk(start_index)])
|
||||
|
||||
|
||||
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
|
||||
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None):
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
env_kwargs = env_kwargs or {}
|
||||
if ':' in env_id:
|
||||
import re
|
||||
import importlib
|
||||
module_name = re.sub(':.*','',env_id)
|
||||
env_id = re.sub('.*:', '', env_id)
|
||||
importlib.import_module(module_name)
|
||||
if env_type == 'atari':
|
||||
env = make_atari(env_id)
|
||||
elif env_type == 'retro':
|
||||
@@ -61,7 +72,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
|
||||
gamestate = gamestate or retro.State.DEFAULT
|
||||
env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
|
||||
else:
|
||||
env = gym.make(env_id)
|
||||
env = gym.make(env_id, **env_kwargs)
|
||||
|
||||
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
|
||||
keys = env.observation_space.spaces.keys()
|
||||
@@ -72,6 +83,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
|
||||
logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
|
||||
allow_early_resets=True)
|
||||
|
||||
|
||||
if env_type == 'atari':
|
||||
env = wrap_deepmind(env, **wrapper_kwargs)
|
||||
elif env_type == 'retro':
|
||||
@@ -79,6 +91,9 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
|
||||
wrapper_kwargs['frame_stack'] = 1
|
||||
env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)
|
||||
|
||||
if isinstance(env.action_space, gym.spaces.Box):
|
||||
env = ClipActionsWrapper(env)
|
||||
|
||||
if reward_scale != 1:
|
||||
env = retro_wrappers.RewardScaler(env, reward_scale)
|
||||
|
||||
|
@@ -90,6 +90,8 @@ def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_
|
||||
sum_y *= interstep_decay
|
||||
count_y *= interstep_decay
|
||||
while True:
|
||||
if luoi >= len(xolds):
|
||||
break
|
||||
xold = xolds[luoi]
|
||||
if xold <= xnew:
|
||||
decay = np.exp(- (xnew - xold) / decay_period)
|
||||
@@ -98,8 +100,6 @@ def one_sided_ema(xolds, yolds, low=None, high=None, n=512, decay_steps=1., low_
|
||||
luoi += 1
|
||||
else:
|
||||
break
|
||||
if luoi >= len(xolds):
|
||||
break
|
||||
sum_ys[i] = sum_y
|
||||
count_ys[i] = count_y
|
||||
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from baselines.common import mpi_util
|
||||
from mpi4py import MPI
|
||||
from baselines import logger
|
||||
from baselines.common.tests.test_with_mpi import with_mpi
|
||||
from baselines.common import mpi_util
|
||||
|
||||
@with_mpi()
|
||||
def test_mpi_weighted_mean():
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
with logger.scoped_configure(comm=comm):
|
||||
if comm.rank == 0:
|
||||
@@ -13,7 +13,6 @@ def test_mpi_weighted_mean():
|
||||
name2valcount = {'a' : (19, 1), 'c' : (42,3)}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
d = mpi_util.mpi_weighted_mean(comm, name2valcount)
|
||||
correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
|
||||
if comm.rank == 0:
|
||||
@@ -24,4 +23,4 @@ def test_mpi_weighted_mean():
|
||||
logger.logkv_mean(name, val)
|
||||
d2 = logger.dumpkvs()
|
||||
if comm.rank == 0:
|
||||
assert d2 == correctval
|
||||
assert d2 == correctval
|
@@ -5,6 +5,12 @@ from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
N_TRIALS = 10000
|
||||
N_EPISODES = 100
|
||||
|
||||
# Session configuration shared by the test helpers below: soft device placement
# is allowed and both TensorFlow thread pools are limited to a single thread.
_sess_config = tf.ConfigProto(
    inter_op_parallelism_threads=1,
    intra_op_parallelism_threads=1,
    allow_soft_placement=True,
)
|
||||
|
||||
def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||
def seeded_env_fn():
|
||||
env = env_fn()
|
||||
@@ -13,7 +19,7 @@ def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||
|
||||
np.random.seed(0)
|
||||
env = DummyVecEnv([seeded_env_fn])
|
||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||
with tf.Graph().as_default(), tf.Session(config=_sess_config).as_default():
|
||||
tf.set_random_seed(0)
|
||||
model = learn_fn(env)
|
||||
sum_rew = 0
|
||||
@@ -34,7 +40,7 @@ def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
|
||||
|
||||
def reward_per_episode_test(env_fn, learn_fn, min_avg_reward, n_trials=N_EPISODES):
|
||||
env = DummyVecEnv([env_fn])
|
||||
with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
|
||||
with tf.Graph().as_default(), tf.Session(config=_sess_config).as_default():
|
||||
model = learn_fn(env)
|
||||
N_TRIALS = 100
|
||||
observations, actions, rewards = rollout(env, model, N_TRIALS)
|
||||
|
@@ -145,8 +145,7 @@ class VecEnvWrapper(VecEnv):
|
||||
|
||||
def __init__(self, venv, observation_space=None, action_space=None):
|
||||
self.venv = venv
|
||||
VecEnv.__init__(self,
|
||||
num_envs=venv.num_envs,
|
||||
super().__init__(num_envs=venv.num_envs,
|
||||
observation_space=observation_space or venv.observation_space,
|
||||
action_space=action_space or venv.action_space)
|
||||
|
||||
@@ -170,6 +169,11 @@ class VecEnvWrapper(VecEnv):
|
||||
def get_images(self):
|
||||
return self.venv.get_images()
|
||||
|
||||
def __getattr__(self, name):
    """
    Fall back to the wrapped ``venv`` for attributes missing on this wrapper.

    Python only calls ``__getattr__`` after normal attribute lookup has failed.

    :param name: name of the attribute that was not found on the wrapper
    :return: ``getattr(self.venv, name)``
    :raises AttributeError: if ``name`` starts with an underscore (private
        attributes are never forwarded), or if ``name`` is ``venv`` itself,
        which means the wrapped env has not been assigned yet
    """
    if name.startswith('_'):
        raise AttributeError("attempted to get missing private attribute '{}'".format(name))
    if name == 'venv':
        # Reaching this point means `self.venv` is not set (e.g. __init__ has not
        # run yet). Evaluating `self.venv` below would re-enter __getattr__ and
        # recurse until RecursionError, so fail with a regular AttributeError.
        raise AttributeError("attempted to get missing attribute '{}'".format(name))
    return getattr(self.venv, name)
|
||||
|
||||
class VecEnvObservationWrapper(VecEnvWrapper):
|
||||
@abstractmethod
|
||||
def process(self, obs):
|
||||
|
@@ -5,16 +5,18 @@ import time
|
||||
from collections import deque
|
||||
|
||||
class VecMonitor(VecEnvWrapper):
|
||||
def __init__(self, venv, filename=None, keep_buf=0):
|
||||
def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
self.eprets = None
|
||||
self.eplens = None
|
||||
self.epcount = 0
|
||||
self.tstart = time.time()
|
||||
if filename:
|
||||
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart})
|
||||
self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart},
|
||||
extra_keys=info_keywords)
|
||||
else:
|
||||
self.results_writer = None
|
||||
self.info_keywords = info_keywords
|
||||
self.keep_buf = keep_buf
|
||||
if self.keep_buf:
|
||||
self.epret_buf = deque([], maxlen=keep_buf)
|
||||
@@ -30,11 +32,16 @@ class VecMonitor(VecEnvWrapper):
|
||||
obs, rews, dones, infos = self.venv.step_wait()
|
||||
self.eprets += rews
|
||||
self.eplens += 1
|
||||
newinfos = []
|
||||
for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)):
|
||||
info = info.copy()
|
||||
if done:
|
||||
|
||||
newinfos = infos[:]
|
||||
for i in range(len(dones)):
|
||||
if dones[i]:
|
||||
info = infos[i].copy()
|
||||
ret = self.eprets[i]
|
||||
eplen = self.eplens[i]
|
||||
epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
|
||||
for k in self.info_keywords:
|
||||
epinfo[k] = info[k]
|
||||
info['episode'] = epinfo
|
||||
if self.keep_buf:
|
||||
self.epret_buf.append(ret)
|
||||
@@ -44,6 +51,5 @@ class VecMonitor(VecEnvWrapper):
|
||||
self.eplens[i] = 0
|
||||
if self.results_writer:
|
||||
self.results_writer.write_row(epinfo)
|
||||
newinfos.append(info)
|
||||
|
||||
newinfos[i] = info
|
||||
return obs, rews, dones, newinfos
|
||||
|
@@ -1,8 +1,6 @@
|
||||
from . import VecEnvWrapper
|
||||
from baselines.common.running_mean_std import RunningMeanStd
|
||||
import numpy as np
|
||||
|
||||
|
||||
class VecNormalize(VecEnvWrapper):
|
||||
"""
|
||||
A vectorized wrapper that normalizes the observations
|
||||
@@ -11,6 +9,7 @@ class VecNormalize(VecEnvWrapper):
|
||||
|
||||
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
from baselines.common.running_mean_std import RunningMeanStd
|
||||
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||
self.clipob = clipob
|
||||
|
@@ -1,6 +1,5 @@
|
||||
from .vec_env import VecEnvObservationWrapper
|
||||
|
||||
|
||||
class VecExtractDictObs(VecEnvObservationWrapper):
|
||||
def __init__(self, venv, key):
|
||||
self.key = key
|
||||
@@ -8,4 +7,4 @@ class VecExtractDictObs(VecEnvObservationWrapper):
|
||||
observation_space=venv.observation_space.spaces[self.key])
|
||||
|
||||
def process(self, obs):
|
||||
return obs[self.key]
|
||||
return obs[self.key]
|
||||
|
@@ -16,4 +16,14 @@ class TimeLimit(gym.Wrapper):
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self._elapsed_steps = 0
|
||||
return self.env.reset(**kwargs)
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
class ClipActionsWrapper(gym.Wrapper):
    """
    Wrapper that sanitizes an action before handing it to the wrapped env.

    The action is passed through ``np.nan_to_num`` and then clipped to the
    ``low`` / ``high`` bounds of ``self.action_space``.
    """

    def step(self, action):
        """Clean and clip ``action``, then return the wrapped env's ``step`` result."""
        # numpy is imported locally, exactly as the original implementation does
        import numpy as np

        finite_action = np.nan_to_num(action)
        space = self.action_space
        clipped_action = np.clip(finite_action, space.low, space.high)
        return self.env.step(clipped_action)

    def reset(self, **kwargs):
        """Forward ``reset`` and its keyword arguments to the wrapped env unchanged."""
        return self.env.reset(**kwargs)
|
||||
|
@@ -217,7 +217,9 @@ def learn(network, env,
|
||||
stats = agent.get_stats()
|
||||
combined_stats = stats.copy()
|
||||
combined_stats['rollout/return'] = np.mean(epoch_episode_rewards)
|
||||
combined_stats['rollout/return_std'] = np.std(epoch_episode_rewards)
|
||||
combined_stats['rollout/return_history'] = np.mean(episode_rewards_history)
|
||||
combined_stats['rollout/return_history_std'] = np.std(episode_rewards_history)
|
||||
combined_stats['rollout/episode_steps'] = np.mean(epoch_episode_steps)
|
||||
combined_stats['rollout/actions_mean'] = np.mean(epoch_actions)
|
||||
combined_stats['rollout/Q_mean'] = np.mean(epoch_qs)
|
||||
|
@@ -361,7 +361,7 @@ class Logger(object):
|
||||
if isinstance(fmt, SeqWriter):
|
||||
fmt.writeseq(map(str, args))
|
||||
|
||||
def configure(dir=None, format_strs=None, comm=None):
|
||||
def configure(dir=None, format_strs=None, comm=None, log_suffix=''):
|
||||
"""
|
||||
If comm is provided, average all numerical stats across that comm
|
||||
"""
|
||||
@@ -373,7 +373,6 @@ def configure(dir=None, format_strs=None, comm=None):
|
||||
assert isinstance(dir, str)
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
log_suffix = ''
|
||||
rank = 0
|
||||
# check environment variables here instead of importing mpi4py
|
||||
# to avoid calling MPI_Init() when this module is imported
|
||||
@@ -381,7 +380,7 @@ def configure(dir=None, format_strs=None, comm=None):
|
||||
if varname in os.environ:
|
||||
rank = int(os.environ[varname])
|
||||
if rank > 0:
|
||||
log_suffix = "-rank%03i" % rank
|
||||
log_suffix = log_suffix + "-rank%03i" % rank
|
||||
|
||||
if format_strs is None:
|
||||
if rank == 0:
|
||||
|
Reference in New Issue
Block a user