minor changes to baselines (#243)

* minor changes to baselines

* fix spaces reference

* remove flake8 disable comments and fix import

* okay maybe don't add spec to vec_env
This commit is contained in:
Christopher Hesse
2019-02-19 17:55:32 -08:00
committed by Peter Zhokhov
parent fb6fd51fe6
commit a4188f4b36
5 changed files with 37 additions and 25 deletions

View File

@@ -20,7 +20,7 @@ def register_benchmark(benchmark):
if 'tasks' in benchmark:
for t in benchmark['tasks']:
if 'desc' not in t:
t['desc'] = remove_version_re.sub('', t['env_id'])
t['desc'] = remove_version_re.sub('', t.get('env_id', t.get('id')))
_BENCHMARKS.append(benchmark)

View File

@@ -6,6 +6,8 @@ import gym
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)
from .wrappers import TimeLimit
class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
@@ -221,11 +223,13 @@ class LazyFrames(object):
def __getitem__(self, i):
return self._force()[i]
def make_atari(env_id):
def make_atari(env_id, max_episode_steps=None):
env = gym.make(env_id)
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps=max_episode_steps)
return env
def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):

View File

@@ -1,25 +1,11 @@
# flake8: noqa F403, F405
from .atari_wrappers import *
from collections import deque
import cv2
cv2.ocl.setUseOpenCL(False)
from .atari_wrappers import WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame
from .wrappers import TimeLimit
import numpy as np
import gym
class TimeLimit(gym.Wrapper):
def __init__(self, env, max_episode_steps=None):
super(TimeLimit, self).__init__(env)
self._max_episode_steps = max_episode_steps
self._elapsed_steps = 0
def step(self, ac):
observation, reward, done, info = self.env.step(ac)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
done = True
info['TimeLimit.truncated'] = True
return observation, reward, done, info
def reset(self, **kwargs):
self._elapsed_steps = 0
return self.env.reset(**kwargs)
class StochasticFrameSkip(gym.Wrapper):
def __init__(self, env, n, stickprob):
@@ -99,7 +85,7 @@ class Downsample(gym.ObservationWrapper):
gym.ObservationWrapper.__init__(self, env)
(oldh, oldw, oldc) = env.observation_space.shape
newshape = (oldh//ratio, oldw//ratio, oldc)
self.observation_space = spaces.Box(low=0, high=255,
self.observation_space = gym.spaces.Box(low=0, high=255,
shape=newshape, dtype=np.uint8)
def observation(self, frame):
@@ -116,7 +102,7 @@ class Rgb2gray(gym.ObservationWrapper):
"""
gym.ObservationWrapper.__init__(self, env)
(oldh, oldw, _oldc) = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255,
self.observation_space = gym.spaces.Box(low=0, high=255,
shape=(oldh, oldw, 1), dtype=np.uint8)
def observation(self, frame):
@@ -213,8 +199,10 @@ class StartDoingRandomActionsWrapper(gym.Wrapper):
self.some_random_steps()
return self.last_obs, rew, done, info
def make_retro(*, game, state, max_episode_steps, **kwargs):
def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
import retro
if state is None:
state = retro.State.DEFAULT
env = retro.make(game, state, **kwargs)
env = StochasticFrameSkip(env, n=4, stickprob=0.25)
if max_episode_steps is not None:

View File

@@ -1,4 +1,3 @@
import joblib
import numpy as np
import tensorflow as tf # pylint: ignore-module
import copy
@@ -336,6 +335,7 @@ def save_state(fname, sess=None):
# TODO: ensure there is no subtle differences and remove one
def save_variables(save_path, variables=None, sess=None):
import joblib
sess = sess or get_session()
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
@@ -347,6 +347,7 @@ def save_variables(save_path, variables=None, sess=None):
joblib.dump(save_dict, save_path)
def load_variables(load_path, variables=None, sess=None):
import joblib
sess = sess or get_session()
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

View File

@@ -0,0 +1,19 @@
import gym
class TimeLimit(gym.Wrapper):
def __init__(self, env, max_episode_steps=None):
super(TimeLimit, self).__init__(env)
self._max_episode_steps = max_episode_steps
self._elapsed_steps = 0
def step(self, ac):
observation, reward, done, info = self.env.step(ac)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
done = True
info['TimeLimit.truncated'] = True
return observation, reward, done, info
def reset(self, **kwargs):
self._elapsed_steps = 0
return self.env.reset(**kwargs)