minor changes to baselines (#243)

* minor changes to baselines
* fix spaces reference
* remove flake8 disable comments and fix import
* okay maybe don't add spec to vec_env

Committed by Peter Zhokhov
parent fb6fd51fe6 · commit a4188f4b36
baselines/bench/benchmarks.py
@@ -20,7 +20,7 @@ def register_benchmark(benchmark):
     if 'tasks' in benchmark:
         for t in benchmark['tasks']:
             if 'desc' not in t:
-                t['desc'] = remove_version_re.sub('', t['env_id'])
+                t['desc'] = remove_version_re.sub('', t.get('env_id', t.get('id')))
     _BENCHMARKS.append(benchmark)
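Note: `t.get('env_id', t.get('id'))` lets a task entry name its environment under either `env_id` or `id`. A minimal sketch of the updated lookup, assuming `remove_version_re` strips a trailing `-vN` suffix (the task dicts here are illustrative, not from the repo):

```python
import re

remove_version_re = re.compile(r'-v\d+$')  # assumed form of the repo's pattern

# Tasks may carry either key; both now get a version-free description.
for t in [{'env_id': 'Pong-v4'}, {'id': 'Breakout-v0'}]:
    if 'desc' not in t:
        t['desc'] = remove_version_re.sub('', t.get('env_id', t.get('id')))
    print(t['desc'])  # Pong, then Breakout
```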
baselines/common/atari_wrappers.py
@@ -6,6 +6,8 @@ import gym
 from gym import spaces
 import cv2
 cv2.ocl.setUseOpenCL(False)
 
+from .wrappers import TimeLimit
+
 class NoopResetEnv(gym.Wrapper):
     def __init__(self, env, noop_max=30):
@@ -221,11 +223,13 @@ class LazyFrames(object):
     def __getitem__(self, i):
         return self._force()[i]
 
-def make_atari(env_id):
+def make_atari(env_id, max_episode_steps=None):
     env = gym.make(env_id)
     assert 'NoFrameskip' in env.spec.id
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
+    if max_episode_steps is not None:
+        env = TimeLimit(env, max_episode_steps=max_episode_steps)
     return env
 
 def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
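With the new keyword argument, callers can cap episode length when the env is built instead of re-implementing a limit downstream. A hedged usage sketch (the env id is chosen for illustration):

```python
from baselines.common.atari_wrappers import make_atari, wrap_deepmind

# Cap at 10000 wrapped steps; each MaxAndSkipEnv step covers 4 emulator frames.
# Omitting max_episode_steps keeps the previous unbounded behaviour.
env = make_atari('BreakoutNoFrameskip-v4', max_episode_steps=10000)
env = wrap_deepmind(env, frame_stack=True)
obs = env.reset()
```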
baselines/common/retro_wrappers.py
@@ -1,25 +1,11 @@
-# flake8: noqa F403, F405
-from .atari_wrappers import *
+from collections import deque
+import cv2
+cv2.ocl.setUseOpenCL(False)
+from .atari_wrappers import WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame
+from .wrappers import TimeLimit
 import numpy as np
 import gym
 
-class TimeLimit(gym.Wrapper):
-    def __init__(self, env, max_episode_steps=None):
-        super(TimeLimit, self).__init__(env)
-        self._max_episode_steps = max_episode_steps
-        self._elapsed_steps = 0
-
-    def step(self, ac):
-        observation, reward, done, info = self.env.step(ac)
-        self._elapsed_steps += 1
-        if self._elapsed_steps >= self._max_episode_steps:
-            done = True
-            info['TimeLimit.truncated'] = True
-        return observation, reward, done, info
-
-    def reset(self, **kwargs):
-        self._elapsed_steps = 0
-        return self.env.reset(**kwargs)
-
 class StochasticFrameSkip(gym.Wrapper):
     def __init__(self, env, n, stickprob):
@@ -99,7 +85,7 @@ class Downsample(gym.ObservationWrapper):
         gym.ObservationWrapper.__init__(self, env)
         (oldh, oldw, oldc) = env.observation_space.shape
         newshape = (oldh//ratio, oldw//ratio, oldc)
-        self.observation_space = spaces.Box(low=0, high=255,
+        self.observation_space = gym.spaces.Box(low=0, high=255,
             shape=newshape, dtype=np.uint8)
 
     def observation(self, frame):
@@ -116,7 +102,7 @@ class Rgb2gray(gym.ObservationWrapper):
         """
         gym.ObservationWrapper.__init__(self, env)
         (oldh, oldw, _oldc) = env.observation_space.shape
-        self.observation_space = spaces.Box(low=0, high=255,
+        self.observation_space = gym.spaces.Box(low=0, high=255,
             shape=(oldh, oldw, 1), dtype=np.uint8)
 
     def observation(self, frame):
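These two hunks are the "fix spaces reference" item from the commit message: once `from .atari_wrappers import *` is gone (see the import hunk above), the bare name `spaces`, previously picked up through atari_wrappers' `from gym import spaces`, is undefined in this module, so the `Box` constructors are qualified as `gym.spaces.Box`. Standalone equivalent, for reference:

```python
import gym
import numpy as np

# Fully qualified; no `from gym import spaces` needed in the calling module.
obs_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
```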
@@ -213,8 +199,10 @@ class StartDoingRandomActionsWrapper(gym.Wrapper):
         self.some_random_steps()
         return self.last_obs, rew, done, info
 
-def make_retro(*, game, state, max_episode_steps, **kwargs):
+def make_retro(*, game, state=None, max_episode_steps=4500, **kwargs):
     import retro
+    if state is None:
+        state = retro.State.DEFAULT
     env = retro.make(game, state, **kwargs)
     env = StochasticFrameSkip(env, n=4, stickprob=0.25)
     if max_episode_steps is not None:
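Both keyword-only arguments now have defaults, so the minimal call needs just a game name. A sketch, assuming the gym-retro bundled game 'Airstriker-Genesis' is available:

```python
from baselines.common.retro_wrappers import make_retro

# state falls back to retro.State.DEFAULT; episodes cap at 4500 wrapped steps
env = make_retro(game='Airstriker-Genesis')

# passing max_episode_steps=None skips the TimeLimit wrapper entirely
env_uncapped = make_retro(game='Airstriker-Genesis', max_episode_steps=None)
```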
baselines/common/tf_util.py
@@ -1,4 +1,3 @@
-import joblib
 import numpy as np
 import tensorflow as tf  # pylint: ignore-module
 import copy
@@ -336,6 +335,7 @@ def save_state(fname, sess=None):
     # TODO: ensure there is no subtle differences and remove one
 
 
 def save_variables(save_path, variables=None, sess=None):
+    import joblib
     sess = sess or get_session()
     variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
@@ -347,6 +347,7 @@ def save_variables(save_path, variables=None, sess=None):
     joblib.dump(save_dict, save_path)
 
 
 def load_variables(load_path, variables=None, sess=None):
+    import joblib
     sess = sess or get_session()
     variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
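Moving `import joblib` from module level into the two functions that use it means importing tf_util no longer requires joblib at all; the dependency cost is paid only on save/load. A round-trip sketch under TF1-style sessions (the path and variable are illustrative):

```python
import tensorflow as tf
from baselines.common.tf_util import get_session, save_variables, load_variables

v = tf.Variable(42.0, name='v')
get_session().run(tf.global_variables_initializer())

save_variables('/tmp/model.joblib')  # joblib imported here, at call time
load_variables('/tmp/model.joblib')  # restores tf.GraphKeys.GLOBAL_VARIABLES
```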
baselines/common/wrappers.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+import gym
+
+class TimeLimit(gym.Wrapper):
+    def __init__(self, env, max_episode_steps=None):
+        super(TimeLimit, self).__init__(env)
+        self._max_episode_steps = max_episode_steps
+        self._elapsed_steps = 0
+
+    def step(self, ac):
+        observation, reward, done, info = self.env.step(ac)
+        self._elapsed_steps += 1
+        if self._elapsed_steps >= self._max_episode_steps:
+            done = True
+            info['TimeLimit.truncated'] = True
+        return observation, reward, done, info
+
+    def reset(self, **kwargs):
+        self._elapsed_steps = 0
+        return self.env.reset(**kwargs)
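The extracted wrapper ends an episode after `max_episode_steps` wrapped steps and flags the cut-off in `info`, letting algorithms tell a time-out apart from a true terminal state. A minimal sketch (the env id is illustrative; note the wrapper needs an integer `max_episode_steps`, since comparing against the default `None` would raise on the first step):

```python
import gym
from baselines.common.wrappers import TimeLimit

env = TimeLimit(gym.make('CartPole-v0'), max_episode_steps=50)
obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())

# True only if the episode ended because the step budget ran out
print(info.get('TimeLimit.truncated', False))
```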