From a4188f4b36a5b6f58e1f61abe1640438fb58dbee Mon Sep 17 00:00:00 2001
From: Christopher Hesse
Date: Tue, 19 Feb 2019 17:55:32 -0800
Subject: [PATCH] minor changes to baselines (#243)

* minor changes to baselines

* fix spaces reference

* remove flake8 disable comments and fix import

* okay maybe don't add spec to vec_env
---
 baselines/bench/benchmarks.py      |  2 +-
 baselines/common/atari_wrappers.py |  6 +++++-
 baselines/common/retro_wrappers.py | 32 ++++++++++--------------------
 baselines/common/tf_util.py        |  3 ++-
 baselines/common/wrappers.py       | 19 ++++++++++++++++++
 5 files changed, 37 insertions(+), 25 deletions(-)
 create mode 100644 baselines/common/wrappers.py

diff --git a/baselines/bench/benchmarks.py b/baselines/bench/benchmarks.py
index 0d63e7a..c381935 100644
--- a/baselines/bench/benchmarks.py
+++ b/baselines/bench/benchmarks.py
@@ -20,7 +20,7 @@ def register_benchmark(benchmark):
     if 'tasks' in benchmark:
         for t in benchmark['tasks']:
             if 'desc' not in t:
-                t['desc'] = remove_version_re.sub('', t['env_id'])
+                t['desc'] = remove_version_re.sub('', t.get('env_id', t.get('id')))
     _BENCHMARKS.append(benchmark)
 
 
diff --git a/baselines/common/atari_wrappers.py b/baselines/common/atari_wrappers.py
index 982843f..2c9b8c6 100644
--- a/baselines/common/atari_wrappers.py
+++ b/baselines/common/atari_wrappers.py
@@ -6,6 +6,8 @@ import gym
 from gym import spaces
 import cv2
 cv2.ocl.setUseOpenCL(False)
 
+from .wrappers import TimeLimit
+
 class NoopResetEnv(gym.Wrapper):
     def __init__(self, env, noop_max=30):
@@ -221,11 +223,13 @@ class LazyFrames(object):
     def __getitem__(self, i):
         return self._force()[i]
 
-def make_atari(env_id):
+def make_atari(env_id, max_episode_steps=None):
     env = gym.make(env_id)
     assert 'NoFrameskip' in env.spec.id
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
+    if max_episode_steps is not None:
+        env = TimeLimit(env, max_episode_steps=max_episode_steps)
     return env
 
 def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
diff --git a/baselines/common/retro_wrappers.py b/baselines/common/retro_wrappers.py
index 1e98044..2c42926 100644
--- a/baselines/common/retro_wrappers.py
+++ b/baselines/common/retro_wrappers.py
@@ -1,25 +1,11 @@
- # flake8: noqa F403, F405
-from .atari_wrappers import *
+from collections import deque
+import cv2
+cv2.ocl.setUseOpenCL(False)
+from .atari_wrappers import WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame
+from .wrappers import TimeLimit
 import numpy as np
 import gym
 
-class TimeLimit(gym.Wrapper):
-    def __init__(self, env, max_episode_steps=None):
-        super(TimeLimit, self).__init__(env)
-        self._max_episode_steps = max_episode_steps
-        self._elapsed_steps = 0
-
-    def step(self, ac):
-        observation, reward, done, info = self.env.step(ac)
-        self._elapsed_steps += 1
-        if self._elapsed_steps >= self._max_episode_steps:
-            done = True
-            info['TimeLimit.truncated'] = True
-        return observation, reward, done, info
-
-    def reset(self, **kwargs):
-        self._elapsed_steps = 0
-        return self.env.reset(**kwargs)
 
 class StochasticFrameSkip(gym.Wrapper):
     def __init__(self, env, n, stickprob):
@@ -99,7 +85,7 @@ class Downsample(gym.ObservationWrapper):
         gym.ObservationWrapper.__init__(self, env)
         (oldh, oldw, oldc) = env.observation_space.shape
         newshape = (oldh//ratio, oldw//ratio, oldc)
-        self.observation_space = spaces.Box(low=0, high=255,
+        self.observation_space = gym.spaces.Box(low=0, high=255,
             shape=newshape, dtype=np.uint8)
 
     def observation(self, frame):
@@ -116,7 +102,7 @@ class Rgb2gray(gym.ObservationWrapper):
         """
         gym.ObservationWrapper.__init__(self, env)
         (oldh, oldw, _oldc) = env.observation_space.shape
-        self.observation_space = spaces.Box(low=0, high=255,
+        self.observation_space = gym.spaces.Box(low=0, high=255,
             shape=(oldh, oldw, 1), dtype=np.uint8)
 
     def observation(self, frame):
@@ -213,8 +199,10 @@ class StartDoingRandomActionsWrapper(gym.Wrapper):
         self.some_random_steps()
         return self.last_obs, rew, done, info
 
-def make_retro(*, game, state, max_episode_steps, **kwargs):
+def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
     import retro
+    if state is None:
+        state = retro.State.DEFAULT
     env = retro.make(game, state, **kwargs)
     env = StochasticFrameSkip(env, n=4, stickprob=0.25)
     if max_episode_steps is not None:
diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py
index 717b7dc..02cad43 100644
--- a/baselines/common/tf_util.py
+++ b/baselines/common/tf_util.py
@@ -1,4 +1,3 @@
-import joblib
 import numpy as np
 import tensorflow as tf # pylint: ignore-module
 import copy
@@ -336,6 +335,7 @@ def save_state(fname, sess=None):
 
 # TODO: ensure there is no subtle differences and remove one
 def save_variables(save_path, variables=None, sess=None):
+    import joblib
     sess = sess or get_session()
     variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
 
@@ -347,6 +347,7 @@ def save_variables(save_path, variables=None, sess=None):
     joblib.dump(save_dict, save_path)
 
 def load_variables(load_path, variables=None, sess=None):
+    import joblib
     sess = sess or get_session()
     variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
 
diff --git a/baselines/common/wrappers.py b/baselines/common/wrappers.py
new file mode 100644
index 0000000..7683d18
--- /dev/null
+++ b/baselines/common/wrappers.py
@@ -0,0 +1,19 @@
+import gym
+
+class TimeLimit(gym.Wrapper):
+    def __init__(self, env, max_episode_steps=None):
+        super(TimeLimit, self).__init__(env)
+        self._max_episode_steps = max_episode_steps
+        self._elapsed_steps = 0
+
+    def step(self, ac):
+        observation, reward, done, info = self.env.step(ac)
+        self._elapsed_steps += 1
+        if self._elapsed_steps >= self._max_episode_steps:
+            done = True
+            info['TimeLimit.truncated'] = True
+        return observation, reward, done, info
+
+    def reset(self, **kwargs):
+        self._elapsed_steps = 0
+        return self.env.reset(**kwargs)
\ No newline at end of file
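
Usage note: a minimal sketch (not part of the diff above) of how the reworked entry points could be exercised once the patch is applied. The environment id and the 2500-step cap are illustrative assumptions, not values the patch prescribes.

    # Build a step-capped Atari env via the new make_atari(env_id, max_episode_steps=...)
    # signature; the cap is enforced by the TimeLimit wrapper now living in
    # baselines/common/wrappers.py.
    from baselines.common.atari_wrappers import make_atari, wrap_deepmind

    env = make_atari('PongNoFrameskip-v4', max_episode_steps=2500)  # example id and cap
    env = wrap_deepmind(env, frame_stack=True)

    obs = env.reset()
    done = False
    while not done:
        obs, rew, done, info = env.step(env.action_space.sample())
    # TimeLimit marks episodes it cuts short, so callers can tell truncation
    # apart from a genuine terminal state.
    print(info.get('TimeLimit.truncated', False))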