Compare commits
2 Commits
peterz_ben
...
peterz_cod
Author | SHA1 | Date | |
---|---|---|---|
|
841da92f4d | ||
|
624231827c |
@@ -1 +1 @@
|
|||||||
ppo2
|
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@
|
|||||||
.pytest_cache
|
.pytest_cache
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.idea
|
.idea
|
||||||
|
.coverage
|
||||||
|
|
||||||
# Setuptools distribution and build folders.
|
# Setuptools distribution and build folders.
|
||||||
/dist/
|
/dist/
|
||||||
|
@@ -139,3 +139,4 @@ To cite this repository in publications:
|
|||||||
journal = {GitHub repository},
|
journal = {GitHub repository},
|
||||||
howpublished = {\url{https://github.com/openai/baselines}},
|
howpublished = {\url{https://github.com/openai/baselines}},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -156,7 +156,7 @@ class FrameStack(gym.Wrapper):
|
|||||||
self.k = k
|
self.k = k
|
||||||
self.frames = deque([], maxlen=k)
|
self.frames = deque([], maxlen=k)
|
||||||
shp = env.observation_space.shape
|
shp = env.observation_space.shape
|
||||||
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)
|
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
ob = self.env.reset()
|
ob = self.env.reset()
|
||||||
@@ -176,6 +176,7 @@ class FrameStack(gym.Wrapper):
|
|||||||
class ScaledFloatFrame(gym.ObservationWrapper):
|
class ScaledFloatFrame(gym.ObservationWrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
gym.ObservationWrapper.__init__(self, env)
|
gym.ObservationWrapper.__init__(self, env)
|
||||||
|
self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
# careful! This undoes the memory optimization, use
|
# careful! This undoes the memory optimization, use
|
||||||
|
@@ -138,7 +138,7 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
def network_fn(X):
|
def network_fn(X):
|
||||||
out = X
|
out = tf.cast(X, tf.float32) / 255.
|
||||||
with tf.variable_scope("convnet"):
|
with tf.variable_scope("convnet"):
|
||||||
for num_outputs, kernel_size, stride in convs:
|
for num_outputs, kernel_size, stride in convs:
|
||||||
out = layers.convolution2d(out,
|
out = layers.convolution2d(out,
|
||||||
|
3
setup.py
3
setup.py
@@ -25,7 +25,8 @@ setup(name='baselines',
|
|||||||
extras_require={
|
extras_require={
|
||||||
'test': [
|
'test': [
|
||||||
'filelock',
|
'filelock',
|
||||||
'pytest'
|
'pytest',
|
||||||
|
'pytest-cov',
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
description='OpenAI baselines: high quality implementations of reinforcement learning algorithms',
|
description='OpenAI baselines: high quality implementations of reinforcement learning algorithms',
|
||||||
|
Reference in New Issue
Block a user