Merge branch 'internal' of github.com:openai/baselines into internal

This commit is contained in:
Peter Zhokhov
2018-11-05 14:07:52 -08:00
3 changed files with 25 additions and 13 deletions

View File

@@ -129,18 +129,26 @@ class ClipRewardEnv(gym.RewardWrapper):
return np.sign(reward) return np.sign(reward)
class WarpFrame(gym.ObservationWrapper): class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, width=84, height=84): def __init__(self, env, width=84, height=84, grayscale=True):
"""Warp frames to 84x84 as done in the Nature paper and later work.""" """Warp frames to 84x84 as done in the Nature paper and later work."""
gym.ObservationWrapper.__init__(self, env) gym.ObservationWrapper.__init__(self, env)
self.width = width self.width = width
self.height = height self.height = height
self.observation_space = spaces.Box(low=0, high=255, self.grayscale = grayscale
shape=(self.height, self.width, 1), dtype=np.uint8) if self.grayscale:
self.observation_space = spaces.Box(low=0, high=255,
shape=(self.height, self.width, 1), dtype=np.uint8)
else:
self.observation_space = spaces.Box(low=0, high=255,
shape=(self.height, self.width, 3), dtype=np.uint8)
def observation(self, frame): def observation(self, frame):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) if self.grayscale:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
return frame[:, :, None] if self.grayscale:
frame = np.expand_dims(frame, -1)
return frame
class FrameStack(gym.Wrapper): class FrameStack(gym.Wrapper):
def __init__(self, env, k): def __init__(self, env, k):
@@ -156,7 +164,7 @@ class FrameStack(gym.Wrapper):
self.k = k self.k = k
self.frames = deque([], maxlen=k) self.frames = deque([], maxlen=k)
shp = env.observation_space.shape shp = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype) self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)
def reset(self): def reset(self):
ob = self.env.reset() ob = self.env.reset()
@@ -197,7 +205,7 @@ class LazyFrames(object):
def _force(self): def _force(self):
if self._out is None: if self._out is None:
self._out = np.concatenate(self._frames, axis=2) self._out = np.concatenate(self._frames, axis=-1)
self._frames = None self._frames = None
return self._out return self._out

View File

@@ -62,7 +62,7 @@ class CategoricalPdType(PdType):
def pdclass(self): def pdclass(self):
return CategoricalPd return CategoricalPd
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias) pdparam = _matching_fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
return self.pdfromflat(pdparam), pdparam return self.pdfromflat(pdparam), pdparam
def param_shape(self): def param_shape(self):
@@ -82,7 +82,7 @@ class MultiCategoricalPdType(PdType):
return MultiCategoricalPd(self.ncats, flat) return MultiCategoricalPd(self.ncats, flat)
def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0): def pdfromlatent(self, latent, init_scale=1.0, init_bias=0.0):
pdparam = fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias) pdparam = _matching_fc(latent, 'pi', self.ncats.sum(), init_scale=init_scale, init_bias=init_bias)
return self.pdfromflat(pdparam), pdparam return self.pdfromflat(pdparam), pdparam
def param_shape(self): def param_shape(self):
@@ -99,7 +99,7 @@ class DiagGaussianPdType(PdType):
return DiagGaussianPd return DiagGaussianPd
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias) mean = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer()) logstd = tf.get_variable(name='pi/logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1) pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
return self.pdfromflat(pdparam), mean return self.pdfromflat(pdparam), mean
@@ -123,7 +123,7 @@ class BernoulliPdType(PdType):
def sample_dtype(self): def sample_dtype(self):
return tf.int32 return tf.int32
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
pdparam = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias) pdparam = _matching_fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
return self.pdfromflat(pdparam), pdparam return self.pdfromflat(pdparam), pdparam
# WRONG SECOND DERIVATIVES # WRONG SECOND DERIVATIVES
@@ -345,3 +345,9 @@ def validate_probtype(probtype, pdparam):
assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
print('ok on', probtype, pdparam) print('ok on', probtype, pdparam)
def _matching_fc(tensor, name, size, init_scale, init_bias):
if tensor.shape[-1] == size:
return tensor
else:
return fc(tensor, name, size, init_scale=init_scale, init_bias=init_bias)

View File

@@ -20,8 +20,6 @@ class DummyVecEnv(VecEnv):
env = self.envs[0] env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
obs_space = env.observation_space obs_space = env.observation_space
if isinstance(obs_space, spaces.MultiDiscrete):
obs_space.shape = obs_space.shape[0]
self.keys, shapes, dtypes = obs_space_info(obs_space) self.keys, shapes, dtypes = obs_space_info(obs_space)