update Atari envs to v4 and warn Python 2 users.
This commit is contained in:
@@ -355,43 +355,6 @@ def dropout(x, pkeep, phase=None, mask=None):
|
||||
return switch(phase, mask * x, pkeep * x)
|
||||
|
||||
|
||||
def batchnorm(x, name, phase, updates, gamma=0.96):
|
||||
k = x.get_shape()[1]
|
||||
runningmean = tf.get_variable(name + "/mean",
|
||||
shape=[1, k],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
trainable=False)
|
||||
runningvar = tf.get_variable(name + "/var",
|
||||
shape=[1, k],
|
||||
initializer=tf.constant_initializer(1e-4),
|
||||
trainable=False)
|
||||
testy = (x - runningmean) / tf.sqrt(runningvar)
|
||||
|
||||
mean_ = mean(x, axis=0, keepdims=True)
|
||||
var_ = mean(tf.square(x), axis=0, keepdims=True)
|
||||
std = tf.sqrt(var_)
|
||||
trainy = (x - mean_) / std
|
||||
|
||||
updates.extend([
|
||||
tf.assign(runningmean, runningmean * gamma + mean_ * (1 - gamma)),
|
||||
tf.assign(runningvar, runningvar * gamma + var_ * (1 - gamma))
|
||||
])
|
||||
|
||||
y = switch(phase, trainy, testy)
|
||||
|
||||
scaling = tf.get_variable(name + "/scaling",
|
||||
shape=[1, k],
|
||||
initializer=tf.constant_initializer(1.0),
|
||||
trainable=True)
|
||||
|
||||
translation = tf.get_variable(name + "/translation",
|
||||
shape=[1, k],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
trainable=True)
|
||||
|
||||
return y * scaling + translation
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Theano-like Function
|
||||
# ================================================================
|
||||
|
@@ -29,7 +29,7 @@ def parse_args():
|
||||
|
||||
|
||||
def make_env(game_name):
|
||||
env = gym.make(game_name + "NoFrameskip-v3")
|
||||
env = gym.make(game_name + "NoFrameskip-v4")
|
||||
env = SimpleMonitor(env)
|
||||
env = wrap_dqn(env)
|
||||
return env
|
||||
|
@@ -57,7 +57,7 @@ def parse_args():
|
||||
|
||||
|
||||
def make_env(game_name):
|
||||
env = gym.make(game_name + "NoFrameskip-v3")
|
||||
env = gym.make(game_name + "NoFrameskip-v4")
|
||||
monitored_env = SimpleMonitor(env) # puts rewards and number of steps in info, before environment is wrapped
|
||||
env = wrap_dqn(monitored_env) # applies a bunch of modification to simplify the observation space (downsample, make b/w)
|
||||
return env, monitored_env
|
||||
|
@@ -12,7 +12,7 @@ from baselines.deepq.experiments.atari.model import model, dueling_model
|
||||
|
||||
|
||||
def make_env(game_name):
|
||||
env = gym.make(game_name + "NoFrameskip-v3")
|
||||
env = gym.make(game_name + "NoFrameskip-v4")
|
||||
env_monitored = SimpleMonitor(env)
|
||||
env = wrap_dqn(env_monitored)
|
||||
return env_monitored, env
|
||||
|
@@ -5,7 +5,7 @@ from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFram
|
||||
|
||||
|
||||
def main():
|
||||
env = gym.make("PongNoFrameskip-v3")
|
||||
env = gym.make("PongNoFrameskip-v4")
|
||||
env = ScaledFloatFrame(wrap_dqn(env))
|
||||
act = deepq.load("pong_model.pkl")
|
||||
|
||||
|
@@ -5,7 +5,7 @@ from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFram
|
||||
|
||||
|
||||
def main():
|
||||
env = gym.make("PongNoFrameskip-v3")
|
||||
env = gym.make("PongNoFrameskip-v4")
|
||||
env = ScaledFloatFrame(wrap_dqn(env))
|
||||
model = deepq.models.cnn_to_mlp(
|
||||
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
|
||||
|
11
setup.py
11
setup.py
@@ -1,14 +1,15 @@
|
||||
from setuptools import setup, find_packages
|
||||
import os
|
||||
|
||||
|
||||
repo_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
import sys
|
||||
|
||||
if sys.version_info.major != 3:
|
||||
print("This Python is only compatible with Python 3, but you are running "
|
||||
"Python {}. The installation will likely fail.".format(sys.version_info.major))
|
||||
|
||||
setup(name='baselines',
|
||||
packages=[package for package in find_packages()
|
||||
if package.startswith('baselines')],
|
||||
install_requires=[
|
||||
'gym',
|
||||
'scipy',
|
||||
'tqdm',
|
||||
'joblib',
|
||||
@@ -22,4 +23,4 @@ setup(name='baselines',
|
||||
author="OpenAI",
|
||||
url='https://github.com/openai/baselines',
|
||||
author_email="gym@openai.com",
|
||||
version="0.1.0")
|
||||
version="0.1.3")
|
||||
|
Reference in New Issue
Block a user