Update Atari envs to v4 and warn Python 2 users.
This commit is contained in:
@@ -355,43 +355,6 @@ def dropout(x, pkeep, phase=None, mask=None):
|
|||||||
return switch(phase, mask * x, pkeep * x)
|
return switch(phase, mask * x, pkeep * x)
|
||||||
|
|
||||||
|
|
||||||
def batchnorm(x, name, phase, updates, gamma=0.96):
|
|
||||||
k = x.get_shape()[1]
|
|
||||||
runningmean = tf.get_variable(name + "/mean",
|
|
||||||
shape=[1, k],
|
|
||||||
initializer=tf.constant_initializer(0.0),
|
|
||||||
trainable=False)
|
|
||||||
runningvar = tf.get_variable(name + "/var",
|
|
||||||
shape=[1, k],
|
|
||||||
initializer=tf.constant_initializer(1e-4),
|
|
||||||
trainable=False)
|
|
||||||
testy = (x - runningmean) / tf.sqrt(runningvar)
|
|
||||||
|
|
||||||
mean_ = mean(x, axis=0, keepdims=True)
|
|
||||||
var_ = mean(tf.square(x), axis=0, keepdims=True)
|
|
||||||
std = tf.sqrt(var_)
|
|
||||||
trainy = (x - mean_) / std
|
|
||||||
|
|
||||||
updates.extend([
|
|
||||||
tf.assign(runningmean, runningmean * gamma + mean_ * (1 - gamma)),
|
|
||||||
tf.assign(runningvar, runningvar * gamma + var_ * (1 - gamma))
|
|
||||||
])
|
|
||||||
|
|
||||||
y = switch(phase, trainy, testy)
|
|
||||||
|
|
||||||
scaling = tf.get_variable(name + "/scaling",
|
|
||||||
shape=[1, k],
|
|
||||||
initializer=tf.constant_initializer(1.0),
|
|
||||||
trainable=True)
|
|
||||||
|
|
||||||
translation = tf.get_variable(name + "/translation",
|
|
||||||
shape=[1, k],
|
|
||||||
initializer=tf.constant_initializer(0.0),
|
|
||||||
trainable=True)
|
|
||||||
|
|
||||||
return y * scaling + translation
|
|
||||||
|
|
||||||
|
|
||||||
# ================================================================
|
# ================================================================
|
||||||
# Theano-like Function
|
# Theano-like Function
|
||||||
# ================================================================
|
# ================================================================
|
||||||
|
@@ -29,7 +29,7 @@ def parse_args():
|
|||||||
|
|
||||||
|
|
||||||
def make_env(game_name):
|
def make_env(game_name):
|
||||||
env = gym.make(game_name + "NoFrameskip-v3")
|
env = gym.make(game_name + "NoFrameskip-v4")
|
||||||
env = SimpleMonitor(env)
|
env = SimpleMonitor(env)
|
||||||
env = wrap_dqn(env)
|
env = wrap_dqn(env)
|
||||||
return env
|
return env
|
||||||
|
@@ -57,7 +57,7 @@ def parse_args():
|
|||||||
|
|
||||||
|
|
||||||
def make_env(game_name):
|
def make_env(game_name):
|
||||||
env = gym.make(game_name + "NoFrameskip-v3")
|
env = gym.make(game_name + "NoFrameskip-v4")
|
||||||
monitored_env = SimpleMonitor(env) # puts rewards and number of steps in info, before environment is wrapped
|
monitored_env = SimpleMonitor(env) # puts rewards and number of steps in info, before environment is wrapped
|
||||||
env = wrap_dqn(monitored_env) # applies a bunch of modification to simplify the observation space (downsample, make b/w)
|
env = wrap_dqn(monitored_env) # applies a bunch of modification to simplify the observation space (downsample, make b/w)
|
||||||
return env, monitored_env
|
return env, monitored_env
|
||||||
|
@@ -12,7 +12,7 @@ from baselines.deepq.experiments.atari.model import model, dueling_model
|
|||||||
|
|
||||||
|
|
||||||
def make_env(game_name):
|
def make_env(game_name):
|
||||||
env = gym.make(game_name + "NoFrameskip-v3")
|
env = gym.make(game_name + "NoFrameskip-v4")
|
||||||
env_monitored = SimpleMonitor(env)
|
env_monitored = SimpleMonitor(env)
|
||||||
env = wrap_dqn(env_monitored)
|
env = wrap_dqn(env_monitored)
|
||||||
return env_monitored, env
|
return env_monitored, env
|
||||||
|
@@ -5,7 +5,7 @@ from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFram
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
env = gym.make("PongNoFrameskip-v3")
|
env = gym.make("PongNoFrameskip-v4")
|
||||||
env = ScaledFloatFrame(wrap_dqn(env))
|
env = ScaledFloatFrame(wrap_dqn(env))
|
||||||
act = deepq.load("pong_model.pkl")
|
act = deepq.load("pong_model.pkl")
|
||||||
|
|
||||||
|
@@ -5,7 +5,7 @@ from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFram
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
env = gym.make("PongNoFrameskip-v3")
|
env = gym.make("PongNoFrameskip-v4")
|
||||||
env = ScaledFloatFrame(wrap_dqn(env))
|
env = ScaledFloatFrame(wrap_dqn(env))
|
||||||
model = deepq.models.cnn_to_mlp(
|
model = deepq.models.cnn_to_mlp(
|
||||||
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
|
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
|
||||||
|
11
setup.py
11
setup.py
@@ -1,14 +1,15 @@
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
import os
|
import sys
|
||||||
|
|
||||||
|
|
||||||
repo_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
|
if sys.version_info.major != 3:
|
||||||
|
print("This Python is only compatible with Python 3, but you are running "
|
||||||
|
"Python {}. The installation will likely fail.".format(sys.version_info.major))
|
||||||
|
|
||||||
setup(name='baselines',
|
setup(name='baselines',
|
||||||
packages=[package for package in find_packages()
|
packages=[package for package in find_packages()
|
||||||
if package.startswith('baselines')],
|
if package.startswith('baselines')],
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
'gym',
|
||||||
'scipy',
|
'scipy',
|
||||||
'tqdm',
|
'tqdm',
|
||||||
'joblib',
|
'joblib',
|
||||||
@@ -22,4 +23,4 @@ setup(name='baselines',
|
|||||||
author="OpenAI",
|
author="OpenAI",
|
||||||
url='https://github.com/openai/baselines',
|
url='https://github.com/openai/baselines',
|
||||||
author_email="gym@openai.com",
|
author_email="gym@openai.com",
|
||||||
version="0.1.0")
|
version="0.1.3")
|
||||||
|
Reference in New Issue
Block a user