Initial commit

Szymon Sidor
2017-05-17 14:41:46 -07:00
commit 958810ed1e
35 changed files with 3839 additions and 0 deletions

29
.gitignore vendored Normal file

@@ -0,0 +1,29 @@
*.swp
*.pyc
*.py~
.DS_Store
.idea
# Setuptools distribution and build folders.
/dist/
/build
# Virtualenv
/env
# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
*.sublime-project
*.sublime-workspace
.idea
logs/
.ipynb_checkpoints
ghostdriver.log
htmlcov
junk

21
LICENSE Normal file

@@ -0,0 +1,21 @@
The MIT License
Copyright (c) 2017 OpenAI (http://openai.com)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

66
README.md Normal file

@@ -0,0 +1,66 @@
<img src="data/logo.jpg" width=25% align="right" />
# Baselines
We're releasing OpenAI Baselines, a set of high-quality implementations of reinforcement learning algorithms. To start, we're making available an open source version of Deep Q-Learning and three of its variants.
These algorithms will make it easier for the research community to replicate, refine, and identify new ideas, and will create good baselines to build research on top of. Our DQN implementation and its variants are roughly on par with the scores in published papers. We expect they will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones.
You can install it by typing:
```bash
pip install baselines
```
## If you are curious.
##### Train a Cartpole agent and watch it play once it converges!
Here's a list of commands to run to quickly get a working example:
<img src="data/cartpole.gif" width="25%" />
```bash
# Train model and save the results to cartpole_model.pkl
python -m baselines.deepq.experiments.train_cartpole
# Load the model saved in cartpole_model.pkl and visualize the learned policy
python -m baselines.deepq.experiments.enjoy_cartpole
```
Be sure to check out the source code of [both](baselines/deepq/experiments/train_cartpole.py) [files](baselines/deepq/experiments/enjoy_cartpole.py)!
## If you wish to apply DQN to solve a problem.
Check out our simple agents trained with the one-stop-shop `deepq.learn` function.
- `baselines/deepq/experiments/train_cartpole.py` - train a Cartpole agent.
- `baselines/deepq/experiments/train_pong.py` - train a Pong agent using convolutional neural networks.
In particular, notice that once `deepq.learn` finishes training it returns an `act` function which can be used to select actions in the environment. Once trained, you can easily save it and load it at a later time. For both of the files listed above there are complementary files `enjoy_cartpole.py` and `enjoy_pong.py`, respectively, that load and visualize the learned policy.
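For illustration, here is a minimal sketch of that workflow (the exact `deepq.learn` keyword arguments and the `mlp` model builder used below are assumptions for illustration; the scripts above are the authoritative examples):
```python
import gym
from baselines import deepq

env = gym.make("CartPole-v0")
# q_func and hyperparameters below are illustrative assumptions
act = deepq.learn(env, q_func=deepq.models.mlp([64]), max_timesteps=100000)
act.save("cartpole_model.pkl")          # save the trained policy
act = deepq.load("cartpole_model.pkl")  # ...and load it back later

obs, done = env.reset(), False
while not done:
    obs, rew, done, _ = env.step(act(obs[None])[0])
```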
## If you wish to experiment with the algorithm
##### Check out the examples
- `baselines/deepq/experiments/custom_cartpole.py` - Cartpole training with more fine-grained control over the internals of the DQN algorithm.
- `baselines/deepq/experiments/atari/train.py` - a more robust setup for training at scale.
##### Download a pretrained Atari agent
For some research projects it is sometimes useful to have an already trained agent handy. There's a variety of models to choose from. You can list them all by running:
```bash
python -m baselines.deepq.experiments.atari.download_model
```
Once you pick a model, you can download it and visualize the learned policy. Be sure to pass the `--dueling` flag to the visualization script when using dueling models.
```bash
python -m baselines.deepq.experiments.atari.download_model --blob model-atari-prior-duel-breakout-1 --model-dir /tmp/models
python -m baselines.deepq.experiments.atari.enjoy --model-dir /tmp/models/model-atari-prior-duel-breakout-1 --env Breakout --dueling
```

0
baselines/__init__.py Normal file


@@ -0,0 +1,4 @@
from baselines.common.misc_util import *


@@ -0,0 +1,240 @@
import cv2
import gym
import numpy as np
from collections import deque
from gym import spaces
class NoopResetEnv(gym.Wrapper):
def __init__(self, env=None, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
super(NoopResetEnv, self).__init__(env)
self.noop_max = noop_max
self.override_num_noops = None
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def _reset(self):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset()
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = np.random.randint(1, self.noop_max + 1)
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
if done:
obs = self.env.reset()
return obs
class FireResetEnv(gym.Wrapper):
def __init__(self, env=None):
"""For environments where the user need to press FIRE for the game to start."""
super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def _reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env=None):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
super(EpisodicLifeEnv, self).__init__(env)
self.lives = 0
self.was_real_done = True
self.was_real_reset = False
def _step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def _reset(self):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
self.was_real_reset = True
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.was_real_reset = False
self.lives = self.env.unwrapped.ale.lives()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env=None, skip=4):
"""Return only every `skip`-th frame"""
super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = deque(maxlen=2)
self._skip = skip
def _step(self, action):
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
def _reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
class ProcessFrame84(gym.ObservationWrapper):
def __init__(self, env=None):
super(ProcessFrame84, self).__init__(env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
def _observation(self, obs):
return ProcessFrame84.process(obs)
@staticmethod
def process(frame):
resized_screen = None
if frame.size == 210 * 160 * 3:
img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
elif frame.size == 250 * 160 * 3:
img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
else:
assert False, "Unknown resolution."
img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
x_t = resized_screen[18:102, :]
x_t = np.reshape(x_t, [84, 84, 1])
return x_t.astype(np.uint8)
class ClippedRewardsWrapper(gym.RewardWrapper):
def _reward(self, reward):
"""Change all the positive rewards to 1, negative to -1 and keep zero."""
return np.sign(reward)
class LazyFrames(object):
def __init__(self, frames):
"""This object ensures that common frames between the observations are only stored once.
It exists purely to optimize memory usage, which can be huge for DQN's 1M-frame replay
buffers.
This object should only be converted to a numpy array before being passed to the model.
You'd not believe how complex the previous solution was."""
self._frames = frames
def __array__(self, dtype=None):
out = np.concatenate(self._frames, axis=2)
if dtype is not None:
out = out.astype(dtype)
return out
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Stack k last frames.
Returns lazy array, which is much more memory efficient.
See Also
--------
baselines.common.atari_wrappers.LazyFrames
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k))
def _reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()
def _step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))
class ScaledFloatFrame(gym.ObservationWrapper):
def _observation(self, obs):
# careful! This undoes the memory optimization, use
# with smaller replay buffers only.
return np.array(obs).astype(np.float32) / 255.0
def wrap_dqn(env):
"""Apply a common set of wrappers for Atari games."""
assert 'NoFrameskip' in env.spec.id
env = EpisodicLifeEnv(env)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = ProcessFrame84(env)
env = FrameStack(env, 4)
env = ClippedRewardsWrapper(env)
return env
class A2cProcessFrame(gym.Wrapper):
def __init__(self, env):
gym.Wrapper.__init__(self, env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
def _step(self, action):
ob, reward, done, info = self.env.step(action)
return A2cProcessFrame.process(ob), reward, done, info
def _reset(self):
return A2cProcessFrame.process(self.env.reset())
@staticmethod
def process(frame):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
return frame.reshape(84, 84, 1)
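# Illustrative usage sketch for the wrappers above (the environment id is an
# assumption; any Atari env whose id contains 'NoFrameskip' works with wrap_dqn):
#
#   env = wrap_dqn(gym.make("PongNoFrameskip-v4"))
#   obs = env.reset()                              # LazyFrames of 4 stacked 84x84 frames
#   obs, rew, done, info = env.step(env.action_space.sample())
#   frame = np.array(obs)                          # convert LazyFrames to ndarray before feeding a model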


@@ -0,0 +1,149 @@
import os
import tempfile
import zipfile
from azure.common import AzureMissingResourceHttpError
from azure.storage.blob import BlobService
from shutil import unpack_archive
from threading import Event
"""TODOS:
- use Azure snapshots instead of hacky backups
"""
def fixed_list_blobs(service, *args, **kwargs):
"""By defualt list_containers only returns a subset of results.
This function attempts to fix this.
"""
res = []
next_marker = None
while next_marker is None or len(next_marker) > 0:
kwargs['marker'] = next_marker
gen = service.list_blobs(*args, **kwargs)
for b in gen:
res.append(b.name)
next_marker = gen.next_marker
return res
def make_archive(source_path, dest_path):
if source_path.endswith(os.path.sep):
source_path = source_path.rstrip(os.path.sep)
prefix_path = os.path.dirname(source_path)
with zipfile.ZipFile(dest_path, "w", compression=zipfile.ZIP_STORED) as zf:
if os.path.isdir(source_path):
for dirname, subdirs, files in os.walk(source_path):
zf.write(dirname, os.path.relpath(dirname, prefix_path))
for filename in files:
filepath = os.path.join(dirname, filename)
zf.write(filepath, os.path.relpath(filepath, prefix_path))
else:
zf.write(source_path, os.path.relpath(source_path, prefix_path))
class Container(object):
services = {}
def __init__(self, account_name, account_key, container_name, maybe_create=False):
self._account_name = account_name
self._container_name = container_name
if account_name not in Container.services:
Container.services[account_name] = BlobService(account_name, account_key)
self._service = Container.services[account_name]
if maybe_create:
self._service.create_container(self._container_name, fail_on_exist=False)
def put(self, source_path, blob_name, callback=None):
"""Upload a file or directory from `source_path` to azure blob `blob_name`.
Upload progress can be traced by an optional callback.
"""
upload_done = Event()
def progress_callback(current, total):
if callback:
callback(current, total)
if current >= total:
upload_done.set()
# Attempt to make a backup if an existing version is already available
try:
x_ms_copy_source = "https://{}.blob.core.windows.net/{}/{}".format(
self._account_name,
self._container_name,
blob_name
)
self._service.copy_blob(
container_name=self._container_name,
blob_name=blob_name + ".backup",
x_ms_copy_source=x_ms_copy_source
)
except AzureMissingResourceHttpError:
pass
with tempfile.TemporaryDirectory() as td:
arcpath = os.path.join(td, "archive.zip")
make_archive(source_path, arcpath)
self._service.put_block_blob_from_path(
container_name=self._container_name,
blob_name=blob_name,
file_path=arcpath,
max_connections=4,
progress_callback=progress_callback,
max_retries=10)
upload_done.wait()
def get(self, dest_path, blob_name, callback=None):
"""Download a file or directory to `dest_path` to azure blob `blob_name`.
Warning! If directory is downloaded the `dest_path` is the parent directory.
Upload progress can be traced by an optional callback.
"""
download_done = Event()
def progress_callback(current, total):
if callback:
callback(current, total)
if current >= total:
download_done.set()
with tempfile.TemporaryDirectory() as td:
arcpath = os.path.join(td, "archive.zip")
for backup_blob_name in [blob_name, blob_name + '.backup']:
try:
blob_size = self._service.get_blob_properties(
blob_name=backup_blob_name,
container_name=self._container_name
)['content-length']
if int(blob_size) > 0:
self._service.get_blob_to_path(
container_name=self._container_name,
blob_name=backup_blob_name,
file_path=arcpath,
max_connections=4,
progress_callback=progress_callback,
max_retries=10)
unpack_archive(arcpath, dest_path)
download_done.wait()
return True
except AzureMissingResourceHttpError:
pass
return False
def list(self, prefix=None):
"""List all blobs in the container."""
return fixed_list_blobs(self._service, self._container_name, prefix=prefix)
def exists(self, blob_name):
"""Returns true if `blob_name` exists in container."""
try:
self._service.get_blob_properties(
blob_name=blob_name,
container_name=self._container_name
)
return True
except AzureMissingResourceHttpError:
return False
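# Illustrative usage sketch (account name, key, container and blob names below
# are placeholders, not real values):
#
#   c = Container("myaccount", "mykey", "models", maybe_create=True)
#   c.put("/tmp/model_dir", "model-blob")          # zips the directory and uploads it
#   if c.exists("model-blob"):
#       c.get("/tmp/restore_dir", "model-blob")    # downloads and unpacks the archive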


@@ -0,0 +1,327 @@
import gym
import numpy as np
import os
import pickle
import random
import tempfile
import time
import zipfile
def zipsame(*seqs):
L = len(seqs[0])
assert all(len(seq) == L for seq in seqs[1:])
return zip(*seqs)
def unpack(seq, sizes):
"""
Unpack 'seq' into a sequence of lists, with lengths specified by 'sizes'.
None = just one bare element, not a list
Example:
unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6])
"""
seq = list(seq)
it = iter(seq)
assert sum(1 if s is None else s for s in sizes) == len(seq), "Trying to unpack %s into %s" % (seq, sizes)
for size in sizes:
if size is None:
yield it.__next__()
else:
li = []
for _ in range(size):
li.append(it.__next__())
yield li
class EzPickle(object):
"""Objects that are pickled and unpickled via their constructor
arguments.
Example usage:
class Dog(Animal, EzPickle):
def __init__(self, furcolor, tailkind="bushy"):
Animal.__init__(self)
EzPickle.__init__(self, furcolor, tailkind)
...
When this object is unpickled, a new Dog will be constructed by passing the provided
furcolor and tailkind into the constructor. However, philosophers are still not sure
whether it is still the same dog.
This is generally needed only for environments which wrap C/C++ code, such as MuJoCo
and Atari.
"""
def __init__(self, *args, **kwargs):
self._ezpickle_args = args
self._ezpickle_kwargs = kwargs
def __getstate__(self):
return {"_ezpickle_args": self._ezpickle_args, "_ezpickle_kwargs": self._ezpickle_kwargs}
def __setstate__(self, d):
out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
self.__dict__.update(out.__dict__)
def set_global_seeds(i):
try:
import tensorflow as tf
except ImportError:
pass
else:
tf.set_random_seed(i)
np.random.seed(i)
random.seed(i)
def pretty_eta(seconds_left):
"""Print the number of seconds in human readable format.
Examples:
2 days
2 hours and 37 minutes
less than a minute
Parameters
---------
seconds_left: int
Number of seconds to be converted to the ETA
Returns
-------
eta: str
String representing the pretty ETA.
"""
minutes_left = seconds_left // 60
seconds_left %= 60
hours_left = minutes_left // 60
minutes_left %= 60
days_left = hours_left // 24
hours_left %= 24
def helper(cnt, name):
return "{} {}{}".format(str(cnt), name, ('s' if cnt > 1 else ''))
if days_left > 0:
msg = helper(days_left, 'day')
if hours_left > 0:
msg += ' and ' + helper(hours_left, 'hour')
return msg
if hours_left > 0:
msg = helper(hours_left, 'hour')
if minutes_left > 0:
msg += ' and ' + helper(minutes_left, 'minute')
return msg
if minutes_left > 0:
return helper(minutes_left, 'minute')
return 'less than a minute'
class RunningAvg(object):
def __init__(self, gamma, init_value=None):
"""Keep a running estimate of a quantity. This is a bit like mean
but more sensitive to recent changes.
Parameters
----------
gamma: float
Must be between 0 and 1, where 0 is the most sensitive to recent
changes.
init_value: float or None
Initial value of the estimate. If None, it will be set on the first update.
"""
self._value = init_value
self._gamma = gamma
def update(self, new_val):
"""Update the estimate.
Parameters
----------
new_val: float
new observed value of the estimated quantity.
"""
if self._value is None:
self._value = new_val
else:
self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val
def __float__(self):
"""Get the current estimate"""
return self._value
class SimpleMonitor(gym.Wrapper):
def __init__(self, env=None):
"""Adds two qunatities to info returned by every step:
num_steps: int
Number of steps takes so far
rewards: [float]
All the cumulative rewards for the episodes completed so far.
"""
super().__init__(env)
# current episode state
self._current_reward = None
self._num_steps = None
# temporary monitor state that we do not save
self._time_offset = None
self._total_steps = None
# monitor state
self._episode_rewards = []
self._episode_lengths = []
self._episode_end_times = []
def _reset(self):
obs = self.env.reset()
# recompute temporary state if needed
if self._time_offset is None:
self._time_offset = time.time()
if len(self._episode_end_times) > 0:
self._time_offset -= self._episode_end_times[-1]
if self._total_steps is None:
self._total_steps = sum(self._episode_lengths)
# update monitor state
if self._current_reward is not None:
self._episode_rewards.append(self._current_reward)
self._episode_lengths.append(self._num_steps)
self._episode_end_times.append(time.time() - self._time_offset)
# reset episode state
self._current_reward = 0
self._num_steps = 0
return obs
def _step(self, action):
obs, rew, done, info = self.env.step(action)
self._current_reward += rew
self._num_steps += 1
self._total_steps += 1
info['steps'] = self._total_steps
info['rewards'] = self._episode_rewards
return (obs, rew, done, info)
def get_state(self):
return {
'env_id': self.env.unwrapped.spec.id,
'episode_data': {
'episode_rewards': self._episode_rewards,
'episode_lengths': self._episode_lengths,
'episode_end_times': self._episode_end_times,
'initial_reset_time': 0,
}
}
def set_state(self, state):
assert state['env_id'] == self.env.unwrapped.spec.id
ed = state['episode_data']
self._episode_rewards = ed['episode_rewards']
self._episode_lengths = ed['episode_lengths']
self._episode_end_times = ed['episode_end_times']
def boolean_flag(parser, name, default=False, help=None):
"""Add a boolean flag to argparse parser.
Parameters
----------
parser: argparse.Parser
parser to add the flag to
name: str
--<name> will enable the flag, while --no-<name> will disable it
default: bool or None
default value of the flag
help: str
help string for the flag
"""
parser.add_argument("--" + name, action="store_true", default=default, help=help)
parser.add_argument("--no-" + name, action="store_false", dest=name)
def get_wrapper_by_name(env, classname):
"""Given an a gym environment possibly wrapped multiple times, returns a wrapper
of class named classname or raises ValueError if no such wrapper was applied
Parameters
----------
env: gym.Env or gym.Wrapper
gym environment
classname: str
name of the wrapper
Returns
-------
wrapper: gym.Wrapper
wrapper named classname
"""
currentenv = env
while True:
if classname == currentenv.class_name():
return currentenv
elif isinstance(currentenv, gym.Wrapper):
currentenv = currentenv.env
else:
raise ValueError("Couldn't find wrapper named %s" % classname)
def relatively_safe_pickle_dump(obj, path, compression=False):
"""This is just like regular pickle dump, except from the fact that failure cases are
different:
- It's never possible that we end up with a pickle in corrupted state.
- If a there was a different file at the path, that file will remain unchanged in the
even of failure (provided that filesystem rename is atomic).
- it is sometimes possible that we end up with useless temp file which needs to be
deleted manually (it will be removed automatically on the next function call)
The indended use case is periodic checkpoints of experiment state, such that we never
corrupt previous checkpoints if the current one fails.
Parameters
----------
obj: object
object to pickle
path: str
path to the output file
compression: bool
if true pickle will be compressed
"""
temp_storage = path + ".relatively_safe"
if compression:
# Using gzip here would be simpler, but the size is limited to 2GB
with tempfile.NamedTemporaryFile() as uncompressed_file:
pickle.dump(obj, uncompressed_file)
with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip:
myzip.write(uncompressed_file.name, "data")
else:
with open(temp_storage, "wb") as f:
pickle.dump(obj, f)
os.rename(temp_storage, path)
def pickle_load(path, compression=False):
"""Unpickle a possible compressed pickle.
Parameters
----------
path: str
path to the pickled file
compression: bool
if true assumes that pickle was compressed when created and attempts decompression.
Returns
-------
obj: object
the unpickled object
"""
if compression:
with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip:
with myzip.open("data") as f:
return pickle.load(f)
else:
with open(path, "rb") as f:
return pickle.load(f)
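# Illustrative checkpointing sketch (the path is a placeholder):
#
#   state = {"step": 1000}
#   relatively_safe_pickle_dump(state, "/tmp/checkpoint.pkl.zip", compression=True)
#   state = pickle_load("/tmp/checkpoint.pkl.zip", compression=True)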


@@ -0,0 +1,99 @@
"""This file is used for specifying various schedules that evolve over
time throughout the execution of the algorithm, such as:
- learning rate for the optimizer
- exploration epsilon for the epsilon greedy exploration strategy
- beta parameter for prioritized replay
Each schedule has a function `value(t)` which returns the current value
of the parameter given the timestep t of the optimization procedure.
"""
class Schedule(object):
def value(self, t):
"""Value of the schedule at time t"""
raise NotImplementedError()
class ConstantSchedule(object):
def __init__(self, value):
"""Value remains constant over time.
Parameters
----------
value: float
Constant value of the schedule
"""
self._v = value
def value(self, t):
"""See Schedule.value"""
return self._v
def linear_interpolation(l, r, alpha):
return l + alpha * (r - l)
class PiecewiseSchedule(object):
def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
"""Piecewise schedule.
endpoints: [(int, int)]
list of pairs `(time, value)` meaning that the schedule should output
`value` when `t==time`. All the times must be sorted in
increasing order. When t is between two times, e.g. `(time_a, value_a)`
and `(time_b, value_b)`, such that `time_a <= t < time_b`, then the schedule outputs
`interpolation(value_a, value_b, alpha)` where alpha is the fraction of
time passed between `time_a` and `time_b` at time `t`.
interpolation: lambda float, float, float: float
a function that takes the values to the left and to the right of t according
to the `endpoints`. Alpha is the fraction of the distance from the left endpoint to the
right endpoint that t has covered. See linear_interpolation for an example.
outside_value: float
if the value is requested outside of all the intervals specified in
`endpoints` this value is returned. If None then an AssertionError is
raised when an outside value is requested.
"""
idxes = [e[0] for e in endpoints]
assert idxes == sorted(idxes)
self._interpolation = interpolation
self._outside_value = outside_value
self._endpoints = endpoints
def value(self, t):
"""See Schedule.value"""
for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
if l_t <= t and t < r_t:
alpha = float(t - l_t) / (r_t - l_t)
return self._interpolation(l, r, alpha)
# t does not belong to any of the pieces, so doom.
assert self._outside_value is not None
return self._outside_value
class LinearSchedule(object):
def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
"""Linear interpolation between initial_p and final_p over
schedule_timesteps. After this many timesteps pass final_p is
returned.
Parameters
----------
schedule_timesteps: int
Number of timesteps for which to linearly anneal initial_p
to final_p
initial_p: float
initial output value
final_p: float
final output value
"""
self.schedule_timesteps = schedule_timesteps
self.final_p = final_p
self.initial_p = initial_p
def value(self, t):
"""See Schedule.value"""
fraction = min(float(t) / self.schedule_timesteps, 1.0)
return self.initial_p + fraction * (self.final_p - self.initial_p)
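# Illustrative sketch: an epsilon-greedy exploration schedule annealed from 1.0
# to 0.02 over the first 100k steps (the numbers are arbitrary examples, not
# values prescribed by this file):
#
#   exploration = LinearSchedule(schedule_timesteps=100000, initial_p=1.0, final_p=0.02)
#   exploration.value(0)       # -> 1.0
#   exploration.value(200000)  # -> 0.02 (clipped after schedule_timesteps)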


@@ -0,0 +1,146 @@
import operator
class SegmentTree(object):
def __init__(self, capacity, operation, neutral_element):
"""Build a Segment Tree data structure.
https://en.wikipedia.org/wiki/Segment_tree
Can be used as regular array, but with two
important differences:
a) setting item's value is slightly slower.
It is O(lg capacity) instead of O(1).
b) user has access to an efficient `reduce`
operation which reduces `operation` over
a contiguous subsequence of items in the
array.
Parameters
----------
capacity: int
Total size of the array - must be a power of two.
operation: lambda obj, obj -> obj
an operation for combining elements (e.g. sum, max);
must form a mathematical group together with the set of
possible values for array elements.
neutral_element: obj
neutral element for the operation above. e.g. float('-inf')
for max and 0 for sum.
"""
assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
self._capacity = capacity
self._value = [neutral_element for _ in range(2 * capacity)]
self._operation = operation
def _reduce_helper(self, start, end, node, node_start, node_end):
if start == node_start and end == node_end:
return self._value[node]
mid = (node_start + node_end) // 2
if end <= mid:
return self._reduce_helper(start, end, 2 * node, node_start, mid)
else:
if mid + 1 <= start:
return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
else:
return self._operation(
self._reduce_helper(start, mid, 2 * node, node_start, mid),
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
)
def reduce(self, start=0, end=None):
"""Returns result of applying `self.operation`
to a contiguous subsequence of the array.
self.operation(arr[start], operation(arr[start+1], operation(... arr[end - 1])))
Parameters
----------
start: int
beginning of the subsequence
end: int
end of the subsequence (exclusive)
Returns
-------
reduced: obj
result of reducing self.operation over the specified range of array elements.
"""
if end is None:
end = self._capacity
if end < 0:
end += self._capacity
end -= 1
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
def __setitem__(self, idx, val):
# index of the leaf
idx += self._capacity
self._value[idx] = val
idx //= 2
while idx >= 1:
self._value[idx] = self._operation(
self._value[2 * idx],
self._value[2 * idx + 1]
)
idx //= 2
def __getitem__(self, idx):
assert 0 <= idx < self._capacity
return self._value[self._capacity + idx]
class SumSegmentTree(SegmentTree):
def __init__(self, capacity):
super(SumSegmentTree, self).__init__(
capacity=capacity,
operation=operator.add,
neutral_element=0.0
)
def sum(self, start=0, end=None):
"""Returns arr[start] + ... + arr[end]"""
return super(SumSegmentTree, self).reduce(start, end)
def find_prefixsum_idx(self, prefixsum):
"""Find the highest index `i` in the array such that
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
if array values are probabilities, this function
allows to sample indexes according to the discrete
probability efficiently.
Parameters
----------
prefixsum: float
upper bound on the sum of the array prefix
Returns
-------
idx: int
highest index satisfying the prefixsum constraint
"""
assert 0 <= prefixsum <= self.sum() + 1e-5
idx = 1
while idx < self._capacity: # while non-leaf
if self._value[2 * idx] > prefixsum:
idx = 2 * idx
else:
prefixsum -= self._value[2 * idx]
idx = 2 * idx + 1
return idx - self._capacity
class MinSegmentTree(SegmentTree):
def __init__(self, capacity):
super(MinSegmentTree, self).__init__(
capacity=capacity,
operation=min,
neutral_element=float('inf')
)
def min(self, start=0, end=None):
"""Returns min(arr[start], ..., arr[end])"""
return super(MinSegmentTree, self).reduce(start, end)
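# Illustrative sketch of proportional sampling with SumSegmentTree (priorities
# are arbitrary example values; `random` would need to be imported):
#
#   tree = SumSegmentTree(4)
#   tree[0], tree[1], tree[2], tree[3] = 1.0, 2.0, 3.0, 4.0
#   mass = random.random() * tree.sum()
#   idx = tree.find_prefixsum_idx(mass)   # index i sampled with probability tree[i] / tree.sum()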


@@ -0,0 +1,26 @@
import numpy as np
from baselines.common.schedules import ConstantSchedule, PiecewiseSchedule
def test_piecewise_schedule():
ps = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], outside_value=500)
assert np.isclose(ps.value(-10), 500)
assert np.isclose(ps.value(0), 150)
assert np.isclose(ps.value(5), 200)
assert np.isclose(ps.value(9), 80)
assert np.isclose(ps.value(50), 50)
assert np.isclose(ps.value(80), 50)
assert np.isclose(ps.value(150), 0)
assert np.isclose(ps.value(175), -25)
assert np.isclose(ps.value(201), 500)
assert np.isclose(ps.value(500), 500)
assert np.isclose(ps.value(200 - 1e-10), -50)
def test_constant_schedule():
cs = ConstantSchedule(5)
for i in range(-100, 100):
assert np.isclose(cs.value(i), 5)


@@ -0,0 +1,103 @@
import numpy as np
from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree
def test_tree_set():
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[3] = 3.0
assert np.isclose(tree.sum(), 4.0)
assert np.isclose(tree.sum(0, 2), 0.0)
assert np.isclose(tree.sum(0, 3), 1.0)
assert np.isclose(tree.sum(2, 3), 1.0)
assert np.isclose(tree.sum(2, -1), 1.0)
assert np.isclose(tree.sum(2, 4), 4.0)
def test_tree_set_overlap():
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[2] = 3.0
assert np.isclose(tree.sum(), 3.0)
assert np.isclose(tree.sum(2, 3), 3.0)
assert np.isclose(tree.sum(2, -1), 3.0)
assert np.isclose(tree.sum(2, 4), 3.0)
assert np.isclose(tree.sum(1, 2), 0.0)
def test_prefixsum_idx():
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[3] = 3.0
assert tree.find_prefixsum_idx(0.0) == 2
assert tree.find_prefixsum_idx(0.5) == 2
assert tree.find_prefixsum_idx(0.99) == 2
assert tree.find_prefixsum_idx(1.01) == 3
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(4.00) == 3
def test_prefixsum_idx2():
tree = SumSegmentTree(4)
tree[0] = 0.5
tree[1] = 1.0
tree[2] = 1.0
tree[3] = 3.0
assert tree.find_prefixsum_idx(0.00) == 0
assert tree.find_prefixsum_idx(0.55) == 1
assert tree.find_prefixsum_idx(0.99) == 1
assert tree.find_prefixsum_idx(1.51) == 2
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(5.50) == 3
def test_max_interval_tree():
tree = MinSegmentTree(4)
tree[0] = 1.0
tree[2] = 0.5
tree[3] = 3.0
assert np.isclose(tree.min(), 0.5)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.5)
assert np.isclose(tree.min(0, -1), 0.5)
assert np.isclose(tree.min(2, 4), 0.5)
assert np.isclose(tree.min(3, 4), 3.0)
tree[2] = 0.7
assert np.isclose(tree.min(), 0.7)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.7)
assert np.isclose(tree.min(0, -1), 0.7)
assert np.isclose(tree.min(2, 4), 0.7)
assert np.isclose(tree.min(3, 4), 3.0)
tree[2] = 4.0
assert np.isclose(tree.min(), 1.0)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 1.0)
assert np.isclose(tree.min(0, -1), 1.0)
assert np.isclose(tree.min(2, 4), 3.0)
assert np.isclose(tree.min(2, 3), 4.0)
assert np.isclose(tree.min(2, -1), 4.0)
assert np.isclose(tree.min(3, 4), 3.0)
if __name__ == '__main__':
test_tree_set()
test_tree_set_overlap()
test_prefixsum_idx()
test_prefixsum_idx2()
test_max_interval_tree()


@@ -0,0 +1,69 @@
# tests for tf_util
import tensorflow as tf
from baselines.common.tf_util import (
function,
initialize,
set_value,
single_threaded_session
)
def test_set_value():
a = tf.Variable(42.)
with single_threaded_session():
set_value(a, 5)
assert a.eval() == 5
g = tf.get_default_graph()
g.finalize()
set_value(a, 6)
assert a.eval() == 6
# test the test
try:
assert a.eval() == 7
except AssertionError:
pass
else:
assert False, "assertion should have failed"
def test_function():
tf.reset_default_graph()
x = tf.placeholder(tf.int32, (), name="x")
y = tf.placeholder(tf.int32, (), name="y")
z = 3 * x + 2 * y
lin = function([x, y], z, givens={y: 0})
with single_threaded_session():
initialize()
assert lin(2) == 6
assert lin(x=3) == 9
assert lin(2, 2) == 10
assert lin(x=2, y=3) == 12
def test_multikwargs():
tf.reset_default_graph()
x = tf.placeholder(tf.int32, (), name="x")
with tf.variable_scope("other"):
x2 = tf.placeholder(tf.int32, (), name="x")
z = 3 * x + 2 * x2
lin = function([x, x2], z, givens={x2: 0})
with single_threaded_session():
initialize()
assert lin(2) == 6
assert lin(2, 2) == 10
expt_caught = False
try:
lin(x=2)
except AssertionError:
expt_caught = True
assert expt_caught
if __name__ == '__main__':
test_set_value()
test_function()
test_multikwargs()

789
baselines/common/tf_util.py Normal file

@@ -0,0 +1,789 @@
import numpy as np
import tensorflow as tf # pylint: ignore-module
import builtins
import functools
import copy
import os
import collections
# ================================================================
# Make consistent with numpy
# ================================================================
clip = tf.clip_by_value
def sum(x, axis=None, keepdims=False):
axis = None if axis is None else [axis]
return tf.reduce_sum(x, axis=axis, keep_dims=keepdims)
def mean(x, axis=None, keepdims=False):
axis = None if axis is None else [axis]
return tf.reduce_mean(x, axis=axis, keep_dims=keepdims)
def var(x, axis=None, keepdims=False):
meanx = mean(x, axis=axis, keepdims=keepdims)
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
def std(x, axis=None, keepdims=False):
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
def max(x, axis=None, keepdims=False):
axis = None if axis is None else [axis]
return tf.reduce_max(x, axis=axis, keep_dims=keepdims)
def min(x, axis=None, keepdims=False):
axis = None if axis is None else [axis]
return tf.reduce_min(x, axis=axis, keep_dims=keepdims)
def concatenate(arrs, axis=0):
return tf.concat(axis=axis, values=arrs)
def argmax(x, axis=None):
return tf.argmax(x, axis=axis)
def switch(condition, then_expression, else_expression):
'''Switches between two operations depending on a scalar value (int or bool).
Note that both `then_expression` and `else_expression`
should be symbolic tensors of the *same shape*.
# Arguments
condition: scalar tensor.
then_expression: TensorFlow operation.
else_expression: TensorFlow operation.
'''
x_shape = copy.copy(then_expression.get_shape())
x = tf.cond(tf.cast(condition, 'bool'),
lambda: then_expression,
lambda: else_expression)
x.set_shape(x_shape)
return x
# ================================================================
# Extras
# ================================================================
def l2loss(params):
if len(params) == 0:
return tf.constant(0.0)
else:
return tf.add_n([sum(tf.square(p)) for p in params])
def lrelu(x, leak=0.2):
f1 = 0.5 * (1 + leak)
f2 = 0.5 * (1 - leak)
return f1 * x + f2 * abs(x)
def categorical_sample_logits(X):
# https://github.com/tensorflow/tensorflow/issues/456
U = tf.random_uniform(tf.shape(X))
return argmax(X - tf.log(-tf.log(U)), axis=1)
# ================================================================
# Inputs
# ================================================================
def is_placeholder(x):
return type(x) is tf.Tensor and len(x.op.inputs) == 0
class TfInput(object):
def __init__(self, name="(unnamed)"):
"""Generalized Tensorflow placeholder. The main differences are:
- possibly uses multiple placeholders internally and returns multiple values
- can apply light postprocessing to the value feed to placeholder.
"""
self.name = name
def get(self):
"""Return the tf variable(s) representing the possibly postprocessed value
of placeholder(s).
"""
raise NotImplementedError()
def make_feed_dict(self, data):
"""Given data, input it to the placeholder(s)."""
raise NotImplementedError()
class PlacholderTfInput(TfInput):
def __init__(self, placeholder):
"""Wrapper for regular tensorflow placeholder."""
super().__init__(placeholder.name)
self._placeholder = placeholder
def get(self):
return self._placeholder
def make_feed_dict(self, data):
return {self._placeholder: data}
class BatchInput(PlacholderTfInput):
def __init__(self, shape, dtype=tf.float32, name=None):
"""Creates a placeholder for a batch of tensors of a given shape and dtype
Parameters
----------
shape: [int]
shape of a single element of the batch
dtype: tf.dtype
number representation used for tensor contents
name: str
name of the underlying placeholder
"""
super().__init__(tf.placeholder(dtype, [None] + list(shape), name=name))
class Uint8Input(PlacholderTfInput):
def __init__(self, shape, name=None):
"""Takes input in uint8 format which is cast to float32 and divided by 255
before passing it to the model.
On GPU this ensures lower data transfer times.
Parameters
----------
shape: [int]
shape of the tensor.
name: str
name of the underlying placeholder
"""
super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), name=name))
self._shape = shape
self._output = tf.cast(super().get(), tf.float32) / 255.0
def get(self):
return self._output
def ensure_tf_input(thing):
"""Takes either tf.placeholder of TfInput and outputs equivalent TfInput"""
if isinstance(thing, TfInput):
return thing
elif is_placeholder(thing):
return PlacholderTfInput(thing)
else:
raise ValueError("Must be a placeholder or TfInput")
# ================================================================
# Mathematical utils
# ================================================================
def huber_loss(x, delta=1.0):
"""Reference: https://en.wikipedia.org/wiki/Huber_loss"""
return tf.where(
tf.abs(x) < delta,
tf.square(x) * 0.5,
delta * (tf.abs(x) - 0.5 * delta)
)
# ================================================================
# Optimizer utils
# ================================================================
def minimize_and_clip(optimizer, objective, var_list, clip_val=10):
"""Minimized `objective` using `optimizer` w.r.t. variables in
`var_list` while ensure the norm of the gradients for each
variable is clipped to `clip_val`
"""
gradients = optimizer.compute_gradients(objective, var_list=var_list)
for i, (grad, var) in enumerate(gradients):
if grad is not None:
gradients[i] = (tf.clip_by_norm(grad, clip_val), var)
return optimizer.apply_gradients(gradients)
# ================================================================
# Global session
# ================================================================
def get_session():
"""Returns recently made Tensorflow session"""
return tf.get_default_session()
def make_session(num_cpu):
"""Returns a session that will use <num_cpu> CPU's only"""
tf_config = tf.ConfigProto(
inter_op_parallelism_threads=num_cpu,
intra_op_parallelism_threads=num_cpu)
return tf.Session(config=tf_config)
def single_threaded_session():
"""Returns a session which will only use a single CPU"""
return make_session(1)
ALREADY_INITIALIZED = set()
def initialize():
"""Initialize all the uninitialized variables in the global scope."""
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
get_session().run(tf.variables_initializer(new_variables))
ALREADY_INITIALIZED.update(new_variables)
def eval(expr, feed_dict=None):
if feed_dict is None:
feed_dict = {}
return get_session().run(expr, feed_dict=feed_dict)
VALUE_SETTERS = collections.OrderedDict()
def set_value(v, val):
global VALUE_SETTERS
if v in VALUE_SETTERS:
set_op, set_endpoint = VALUE_SETTERS[v]
else:
set_endpoint = tf.placeholder(v.dtype)
set_op = v.assign(set_endpoint)
VALUE_SETTERS[v] = (set_op, set_endpoint)
get_session().run(set_op, feed_dict={set_endpoint: val})
# ================================================================
# Saving variables
# ================================================================
def load_state(fname):
saver = tf.train.Saver()
saver.restore(get_session(), fname)
def save_state(fname):
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver.save(get_session(), fname)
# ================================================================
# Model components
# ================================================================
def normc_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None,
summary_tag=None):
with tf.variable_scope(name):
stride_shape = [1, stride[0], stride[1], 1]
filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]
# there are "num input feature maps * filter height * filter width"
# inputs to each hidden unit
fan_in = intprod(filter_shape[:3])
# each unit in the lower layer receives a gradient from:
# "num output feature maps * filter height * filter width" /
# pooling size
fan_out = intprod(filter_shape[:2]) * num_filters
# initialize weights with random weights
w_bound = np.sqrt(6. / (fan_in + fan_out))
w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
collections=collections)
b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(),
collections=collections)
if summary_tag is not None:
tf.summary.image(summary_tag,
tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]),
[2, 0, 1, 3]),
max_images=10)
return tf.nn.conv2d(x, w, stride_shape, pad) + b
def dense(x, size, name, weight_init=None, bias=True):
w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=weight_init)
ret = tf.matmul(x, w)
if bias:
b = tf.get_variable(name + "/b", [size], initializer=tf.zeros_initializer())
return ret + b
else:
return ret
def wndense(x, size, name, init_scale=1.0):
v = tf.get_variable(name + "/V", [int(x.get_shape()[1]), size],
initializer=tf.random_normal_initializer(0, 0.05))
g = tf.get_variable(name + "/g", [size], initializer=tf.constant_initializer(init_scale))
b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(0.0))
# use weight normalization (Salimans & Kingma, 2016)
x = tf.matmul(x, v)
scaler = g / tf.sqrt(sum(tf.square(v), axis=0, keepdims=True))
return tf.reshape(scaler, [1, size]) * x + tf.reshape(b, [1, size])
def densenobias(x, size, name, weight_init=None):
return dense(x, size, name, weight_init=weight_init, bias=False)
def dropout(x, pkeep, phase=None, mask=None):
mask = tf.floor(pkeep + tf.random_uniform(tf.shape(x))) if mask is None else mask
if phase is None:
return mask * x
else:
return switch(phase, mask * x, pkeep * x)
def batchnorm(x, name, phase, updates, gamma=0.96):
k = x.get_shape()[1]
runningmean = tf.get_variable(name + "/mean",
shape=[1, k],
initializer=tf.constant_initializer(0.0),
trainable=False)
runningvar = tf.get_variable(name + "/var",
shape=[1, k],
initializer=tf.constant_initializer(1e-4),
trainable=False)
testy = (x - runningmean) / tf.sqrt(runningvar)
mean_ = mean(x, axis=0, keepdims=True)
var_ = mean(tf.square(x), axis=0, keepdims=True)
std = tf.sqrt(var_)
trainy = (x - mean_) / std
updates.extend([
tf.assign(runningmean, runningmean * gamma + mean_ * (1 - gamma)),
tf.assign(runningvar, runningvar * gamma + var_ * (1 - gamma))
])
y = switch(phase, trainy, testy)
scaling = tf.get_variable(name + "/scaling",
shape=[1, k],
initializer=tf.constant_initializer(1.0),
trainable=True)
translation = tf.get_variable(name + "/translation",
shape=[1, k],
initializer=tf.constant_initializer(0.0),
trainable=True)
return y * scaling + translation
# ================================================================
# Theano-like Function
# ================================================================
def function(inputs, outputs, updates=None, givens=None):
"""Just like Theano function. Take a bunch of tensorflow placeholders and expersions
computed based on those placeholders and produces f(inputs) -> outputs. Function f takes
values to be feed to the inputs placeholders and produces the values of the experessions
in outputs.
Input values can be passed in the same order as inputs or can be provided as kwargs based
on placeholder name (passed to constructor or accessible via placeholder.op.name).
Example:
x = tf.placeholder(tf.int32, (), name="x")
y = tf.placeholder(tf.int32, (), name="y")
z = 3 * x + 2 * y
lin = function([x, y], z, givens={y: 0})
with single_threaded_session():
initialize()
assert lin(2) == 6
assert lin(x=3) == 9
assert lin(2, 2) == 10
assert lin(x=2, y=3) == 12
Parameters
----------
inputs: [tf.placeholder or TfInput]
list of input arguments
outputs: [tf.Variable] or tf.Variable
list of outputs or a single output to be returned from function. Returned
value will also have the same shape.
"""
if isinstance(outputs, list):
return _Function(inputs, outputs, updates, givens=givens)
elif isinstance(outputs, (dict, collections.OrderedDict)):
f = _Function(inputs, outputs.values(), updates, givens=givens)
return lambda *args, **kwargs: type(outputs)(zip(outputs.keys(), f(*args, **kwargs)))
else:
f = _Function(inputs, [outputs], updates, givens=givens)
return lambda *args, **kwargs: f(*args, **kwargs)[0]
class _Function(object):
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
for inpt in inputs:
if not issubclass(type(inpt), TfInput):
assert len(inpt.op.inputs) == 0, "inputs should all be placeholders of baselines.common.TfInput"
self.inputs = inputs
updates = updates or []
self.update_group = tf.group(*updates)
self.outputs_update = list(outputs) + [self.update_group]
self.givens = {} if givens is None else givens
self.check_nan = check_nan
def _feed_input(self, feed_dict, inpt, value):
if issubclass(type(inpt), TfInput):
feed_dict.update(inpt.make_feed_dict(value))
elif is_placeholder(inpt):
feed_dict[inpt] = value
def __call__(self, *args, **kwargs):
assert len(args) <= len(self.inputs), "Too many arguments provided"
feed_dict = {}
# Update the args
for inpt, value in zip(self.inputs, args):
self._feed_input(feed_dict, inpt, value)
# Update the kwargs
kwargs_passed_inpt_names = set()
for inpt in self.inputs[len(args):]:
inpt_name = inpt.name.split(':')[0]
inpt_name = inpt_name.split('/')[-1]
assert inpt_name not in kwargs_passed_inpt_names, \
"this function has two arguments with the same name \"{}\", so kwargs cannot be used.".format(inpt_name)
if inpt_name in kwargs:
kwargs_passed_inpt_names.add(inpt_name)
self._feed_input(feed_dict, inpt, kwargs.pop(inpt_name))
else:
assert inpt in self.givens, "Missing argument " + inpt_name
assert len(kwargs) == 0, "Function got extra arguments " + str(list(kwargs.keys()))
# Update feed dict with givens.
for inpt in self.givens:
feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt])
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
if self.check_nan:
if any(np.isnan(r).any() for r in results):
raise RuntimeError("Nan detected")
return results
def mem_friendly_function(nondata_inputs, data_inputs, outputs, batch_size):
if isinstance(outputs, list):
return _MemFriendlyFunction(nondata_inputs, data_inputs, outputs, batch_size)
else:
f = _MemFriendlyFunction(nondata_inputs, data_inputs, [outputs], batch_size)
return lambda *inputs: f(*inputs)[0]
class _MemFriendlyFunction(object):
def __init__(self, nondata_inputs, data_inputs, outputs, batch_size):
self.nondata_inputs = nondata_inputs
self.data_inputs = data_inputs
self.outputs = list(outputs)
self.batch_size = batch_size
def __call__(self, *inputvals):
assert len(inputvals) == len(self.nondata_inputs) + len(self.data_inputs)
nondata_vals = inputvals[0:len(self.nondata_inputs)]
data_vals = inputvals[len(self.nondata_inputs):]
feed_dict = dict(zip(self.nondata_inputs, nondata_vals))
n = data_vals[0].shape[0]
for v in data_vals[1:]:
assert v.shape[0] == n
for i_start in range(0, n, self.batch_size):
slice_vals = [v[i_start:builtins.min(i_start + self.batch_size, n)] for v in data_vals]
for (var, val) in zip(self.data_inputs, slice_vals):
feed_dict[var] = val
results = tf.get_default_session().run(self.outputs, feed_dict=feed_dict)
if i_start == 0:
sum_results = results
else:
for i in range(len(results)):
sum_results[i] = sum_results[i] + results[i]
for i in range(len(results)):
sum_results[i] = sum_results[i] / n
return sum_results
# ================================================================
# Modules
# ================================================================
class Module(object):
def __init__(self, name):
self.name = name
self.first_time = True
self.scope = None
self.cache = {}
def __call__(self, *args):
if args in self.cache:
print("(%s) retrieving value from cache" % (self.name,))
return self.cache[args]
with tf.variable_scope(self.name, reuse=not self.first_time):
scope = tf.get_variable_scope().name
if self.first_time:
self.scope = scope
print("(%s) running function for the first time" % (self.name,))
else:
assert self.scope == scope, "Tried calling function with a different scope"
print("(%s) running function on new inputs" % (self.name,))
self.first_time = False
out = self._call(*args)
self.cache[args] = out
return out
def _call(self, *args):
raise NotImplementedError
@property
def trainable_variables(self):
assert self.scope is not None, "need to call module once before getting variables"
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
@property
def variables(self):
assert self.scope is not None, "need to call module once before getting variables"
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
def module(name):
@functools.wraps
def wrapper(f):
class WrapperModule(Module):
def _call(self, *args):
return f(*args)
return WrapperModule(name)
return wrapper
# ================================================================
# Graph traversal
# ================================================================
VARIABLES = {}
def get_parents(node):
return node.op.inputs
def topsorted(outputs):
"""
Topological sort via non-recursive depth-first search
"""
assert isinstance(outputs, (list, tuple))
marks = {}
out = []
stack = [] # pylint: disable=W0621
# i: node
# jidx = number of children visited so far from that node
# marks: state of each node, which is one of
# 0: haven't visited
# 1: have visited, but not done visiting children
# 2: done visiting children
for x in outputs:
stack.append((x, 0))
while stack:
(i, jidx) = stack.pop()
if jidx == 0:
m = marks.get(i, 0)
if m == 0:
marks[i] = 1
elif m == 1:
raise ValueError("not a dag")
else:
continue
ps = get_parents(i)
if jidx == len(ps):
marks[i] = 2
out.append(i)
else:
stack.append((i, jidx + 1))
j = ps[jidx]
stack.append((j, 0))
return out
# ================================================================
# Flat vectors
# ================================================================
def var_shape(x):
out = x.get_shape().as_list()
assert all(isinstance(a, int) for a in out), \
"shape function assumes that shape is fully known"
return out
def numel(x):
return intprod(var_shape(x))
def intprod(x):
return int(np.prod(x))
def flatgrad(loss, var_list):
grads = tf.gradients(loss, var_list)
return tf.concat(axis=0, values=[
tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)])
for (v, grad) in zip(var_list, grads)
])
class SetFromFlat(object):
def __init__(self, var_list, dtype=tf.float32):
assigns = []
shapes = list(map(var_shape, var_list))
total_size = np.sum([intprod(shape) for shape in shapes])
self.theta = theta = tf.placeholder(dtype, [total_size])
start = 0
assigns = []
for (shape, v) in zip(shapes, var_list):
size = intprod(shape)
assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], shape)))
start += size
self.op = tf.group(*assigns)
def __call__(self, theta):
get_session().run(self.op, feed_dict={self.theta: theta})
class GetFlat(object):
def __init__(self, var_list):
self.op = tf.concat(axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list])
def __call__(self):
return get_session().run(self.op)
# ================================================================
# Misc
# ================================================================
def fancy_slice_2d(X, inds0, inds1):
"""
like numpy X[inds0, inds1]
XXX this implementation is bad
"""
inds0 = tf.cast(inds0, tf.int64)
inds1 = tf.cast(inds1, tf.int64)
shape = tf.cast(tf.shape(X), tf.int64)
ncols = shape[1]
Xflat = tf.reshape(X, [-1])
return tf.gather(Xflat, inds0 * ncols + inds1)
# ================================================================
# Scopes
# ================================================================
def scope_vars(scope, trainable_only=False):
"""
Get variables inside a scope
The scope can be specified as a string
Parameters
----------
scope: str or VariableScope
scope in which the variables reside.
trainable_only: bool
whether or not to return only the variables that were marked as trainable.
Returns
-------
vars: [tf.Variable]
list of variables in `scope`.
"""
return tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.VARIABLES,
scope=scope if isinstance(scope, str) else scope.name
)
def scope_name():
"""Returns the name of current scope as a string, e.g. deepq/q_func"""
return tf.get_variable_scope().name
def absolute_scope_name(relative_scope_name):
"""Appends parent scope name to `relative_scope_name`"""
return scope_name() + "/" + relative_scope_name
def lengths_to_mask(lengths_b, max_length):
"""
Turns a vector of lengths into a boolean mask
Args:
lengths_b: an integer vector of lengths
max_length: maximum length to fill the mask
Returns:
a boolean array of shape (batch_size, max_length)
row[i] consists of True repeated lengths_b[i] times, followed by False
"""
lengths_b = tf.convert_to_tensor(lengths_b)
assert lengths_b.get_shape().ndims == 1
mask_bt = tf.expand_dims(tf.range(max_length), 0) < tf.expand_dims(lengths_b, 1)
return mask_bt
def in_session(f):
@functools.wraps(f)
def newfunc(*args, **kwargs):
with tf.Session():
f(*args, **kwargs)
return newfunc
_PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape)
def get_placeholder(name, dtype, shape):
if name in _PLACEHOLDER_CACHE:
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
assert dtype1 == dtype and shape1 == shape
return out
else:
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
_PLACEHOLDER_CACHE[name] = (out, dtype, shape)
return out
def get_placeholder_cached(name):
return _PLACEHOLDER_CACHE[name][0]
def flattenallbut0(x):
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
def reset():
global _PLACEHOLDER_CACHE
global VARIABLES
_PLACEHOLDER_CACHE = {}
VARIABLES = {}
tf.reset_default_graph()

View File

@@ -0,0 +1,5 @@
from baselines.deepq import models # noqa
from baselines.deepq.build_graph import build_act, build_train # noqa
from baselines.deepq.simple import learn, load # noqa
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer # noqa

View File

@@ -0,0 +1,249 @@
"""Deep Q learning graph
The functions in this file can be used to create the following functions:
======= act ========
Function to choose an action given an observation
Parameters
----------
observation: object
Observation that can be fed into the output of make_obs_ph
stochastic: bool
if set to False all the actions are always deterministic (default True)
update_eps_ph: float
update epsilon to a new value; if negative no update happens
(default: no update)
Returns
-------
Tensor of dtype tf.int64 and shape (BATCH_SIZE,) with an action to be performed for
every element of the batch.
======= train =======
Function that takes a transition (s,a,r,s') and optimizes Bellman equation's error:
td_error = Q(s,a) - (r + gamma * max_a' Q(s', a'))
loss = huber_loss[td_error]
Parameters
----------
obs_t: object
a batch of observations
action: np.array
actions that were selected upon seeing obs_t.
dtype must be int32 and shape must be (batch_size,)
reward: np.array
immediate reward attained after executing those actions
dtype must be float32 and shape must be (batch_size,)
obs_tp1: object
observations that followed obs_t
done: np.array
1 if obs_t was the last observation in the episode and 0 otherwise
in that case obs_tp1 is ignored (but must still have a valid shape).
dtype must be float32 and shape must be (batch_size,)
weight: np.array
importance weights for every element of the batch (gradient is multiplied
by the importance weight) dtype must be float32 and shape must be (batch_size,)
Returns
-------
td_error: np.array
a list of differences between Q(s,a) and the target in Bellman's equation.
dtype is float32 and shape is (batch_size,)
======= update_target ========
copy the parameters from optimized Q function to the target Q function.
In Q learning we actually optimize the following error:
Q(s,a) - (r + gamma * max_a' Q'(s', a'))
Where Q' is lagging behind Q to stabilize the learning. For example for Atari
Q' is set to Q once every 10000 training steps.
"""
import tensorflow as tf
import baselines.common.tf_util as U
def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):
"""Creates the act function:
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that takes a name and creates a placeholder of input with that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of every action.
num_actions: int
number of actions.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the scope must be given.
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select an action given an observation.
See the top of the file for details.
"""
with tf.variable_scope(scope, reuse=reuse):
observations_ph = U.ensure_tf_input(make_obs_ph("observation"))
stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
eps = tf.get_variable("eps", (), initializer=tf.constant_initializer(0))
q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
deterministic_actions = tf.argmax(q_values, axis=1)
batch_size = tf.shape(observations_ph.get())[0]
random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
chose_random = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)
output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions)
update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
act = U.function(inputs=[observations_ph, stochastic_ph, update_eps_ph],
outputs=output_actions,
givens={update_eps_ph: -1.0, stochastic_ph: True},
updates=[update_eps_expr])
return act
def build_train(make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None, gamma=1.0, double_q=True, scope="deepq", reuse=None):
"""Creates the act function:
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that takes a name and creates a placeholder of input with that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of every action.
num_actions: int
number of actions
optimizer: tf.train.Optimizer
optimizer to use for the Q-learning objective.
grad_norm_clipping: float or None
clip gradient norms to this value. If None no clipping is performed.
gamma: float
discount rate.
double_q: bool
if true will use Double Q Learning (https://arxiv.org/abs/1509.06461).
In general it is a good idea to keep it enabled.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the scope must be given.
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select an action given an observation.
See the top of the file for details.
train: (object, np.array, np.array, object, np.array, np.array) -> np.array
optimize the error in Bellman's equation.
See the top of the file for details.
update_target: () -> ()
copy the parameters from optimized Q function to the target Q function.
See the top of the file for details.
debug: {str: function}
a bunch of functions to print debug data like q_values.
"""
act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse)
with tf.variable_scope(scope, reuse=reuse):
# set up placeholders
obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t"))
act_t_ph = tf.placeholder(tf.int32, [None], name="action")
rew_t_ph = tf.placeholder(tf.float32, [None], name="reward")
obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1"))
done_mask_ph = tf.placeholder(tf.float32, [None], name="done")
importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight")
# q network evaluation
q_t = q_func(obs_t_input.get(), num_actions, scope="q_func", reuse=True) # reuse parameters from act
q_func_vars = U.scope_vars(U.absolute_scope_name("q_func"))
# target q network evaluation
q_tp1 = q_func(obs_tp1_input.get(), num_actions, scope="target_q_func")
target_q_func_vars = U.scope_vars(U.absolute_scope_name("target_q_func"))
# q scores for actions which we know were selected in the given state.
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions), 1)
# compute estimate of best possible value starting from state at t + 1
if double_q:
q_tp1_using_online_net = q_func(obs_tp1_input.get(), num_actions, scope="q_func", reuse=True)
q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1)
q_tp1_best = tf.reduce_sum(q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions), 1)
else:
q_tp1_best = tf.reduce_max(q_tp1, 1)
q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked
# compute the error (potentially clipped)
td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
errors = U.huber_loss(td_error)
weighted_error = tf.reduce_mean(importance_weights_ph * errors)
# compute optimization op (potentially with gradient clipping)
if grad_norm_clipping is not None:
optimize_expr = U.minimize_and_clip(optimizer,
weighted_error,
var_list=q_func_vars,
clip_val=grad_norm_clipping)
else:
optimize_expr = optimizer.minimize(weighted_error, var_list=q_func_vars)
# update_target_fn will be called periodically to copy Q network to target Q network
update_target_expr = []
for var, var_target in zip(sorted(q_func_vars, key=lambda v: v.name),
sorted(target_q_func_vars, key=lambda v: v.name)):
update_target_expr.append(var_target.assign(var))
update_target_expr = tf.group(*update_target_expr)
# Create callable functions
train = U.function(
inputs=[
obs_t_input,
act_t_ph,
rew_t_ph,
obs_tp1_input,
done_mask_ph,
importance_weights_ph
],
outputs=td_error,
updates=[optimize_expr]
)
update_target = U.function([], [], updates=[update_target_expr])
q_values = U.function([obs_t_input], q_t)
return act_f, train, update_target, {'q_values': q_values}
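
Together, the callables returned by `build_train` cover the full DQN update cycle: `act` for epsilon-greedy action selection, `train` for one optimization step on a batch of transitions, and `update_target` for syncing the target network. A rough sketch of how they are wired up (this mirrors the CartPole example later in this commit; the tiny Q-network here is only an illustration):

```python
import gym
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
import baselines.common.tf_util as U
from baselines import deepq

env = gym.make("CartPole-v0")

def q_func(inpt, num_actions, scope, reuse=False):
    # any callable with this signature works; here a small fully connected net
    with tf.variable_scope(scope, reuse=reuse):
        out = layers.fully_connected(inpt, num_outputs=64, activation_fn=tf.nn.tanh)
        return layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)

with U.make_session(1):
    act, train, update_target, debug = deepq.build_train(
        make_obs_ph=lambda name: U.BatchInput(env.observation_space.shape, name=name),
        q_func=q_func,
        num_actions=env.action_space.n,
        optimizer=tf.train.AdamOptimizer(learning_rate=5e-4),
    )
    U.initialize()
    update_target()  # copy the freshly initialized weights into the target network

    obs = env.reset()
    # act expects a batch; update_eps=0.1 writes epsilon=0.1 into the graph first
    action = act(np.array(obs)[None], update_eps=0.1)[0]
    # train(...) is then fed batches of (obs_t, action, reward, obs_tp1, done, weights)
```

Calling `act(...)` without `update_eps` leaves epsilon unchanged, and `stochastic=False` always takes the argmax action; those defaults come from the `givens` dictionary above.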

View File

View File

@@ -0,0 +1,51 @@
import argparse
import progressbar
from baselines.common.azure_utils import Container
def parse_args():
parser = argparse.ArgumentParser("Download a pretrained model from Azure.")
# Environment
parser.add_argument("--model-dir", type=str, default=None,
help="save model in this directory this directory. ")
parser.add_argument("--account-name", type=str, default="openaisciszymon",
help="account name for Azure Blob Storage")
parser.add_argument("--account-key", type=str, default=None,
help="account key for Azure Blob Storage")
parser.add_argument("--container", type=str, default="dqn-blogpost",
help="container name and blob name separated by colon serparated by colon")
parser.add_argument("--blob", type=str, default=None, help="blob with the model")
return parser.parse_args()
def main():
args = parse_args()
c = Container(account_name=args.account_name,
account_key=args.account_key,
container_name=args.container)
if args.blob is None:
print("Listing available models:")
print()
for blob in sorted(c.list(prefix="model-")):
print(blob)
else:
print("Downloading {} to {}...".format(args.blob, args.model_dir))
bar = None
def callback(current, total):
nonlocal bar
if bar is None:
bar = progressbar.ProgressBar(max_value=total)
bar.update(current)
assert c.exists(args.blob), "model {} does not exist".format(args.blob)
assert args.model_dir is not None
c.get(args.model_dir, args.blob, callback=callback)
if __name__ == '__main__':
main()
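
The `Container` helper used by this script can also be driven directly from Python. A sketch using the parser defaults above; the account key is a placeholder and `model-50000000` is a hypothetical blob name:

```python
from baselines.common.azure_utils import Container

c = Container(account_name="openaisciszymon",
              account_key="<account-key>",            # placeholder
              container_name="dqn-blogpost")

for blob in sorted(c.list(prefix="model-")):
    print(blob)

if c.exists("model-50000000"):                        # hypothetical blob name
    c.get("/tmp/pong-model", "model-50000000",
          callback=lambda current, total: None)       # download into a local directory
```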

View File

@@ -0,0 +1,70 @@
import argparse
import gym
import os
import numpy as np
from gym.monitoring import VideoRecorder
import baselines.common.tf_util as U
from baselines import deepq
from baselines.common.misc_util import (
boolean_flag,
SimpleMonitor,
)
from baselines.common.atari_wrappers_deprecated import wrap_dqn
from baselines.deepq.experiments.atari.model import model, dueling_model
def parse_args():
parser = argparse.ArgumentParser("Run an already learned DQN model.")
# Environment
parser.add_argument("--env", type=str, required=True, help="name of the game")
parser.add_argument("--model-dir", type=str, default=None, help="load model from this directory. ")
parser.add_argument("--video", type=str, default=None, help="Path to mp4 file where the video of first episode will be recorded.")
boolean_flag(parser, "stochastic", default=True, help="whether or not to use stochastic actions according to models eps value")
boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
return parser.parse_args()
def make_env(game_name):
env = gym.make(game_name + "NoFrameskip-v3")
env = SimpleMonitor(env)
env = wrap_dqn(env)
return env
def play(env, act, stochastic, video_path):
num_episodes = 0
video_recorder = None
video_recorder = VideoRecorder(
env, video_path, enabled=video_path is not None)
obs = env.reset()
while True:
env.unwrapped.render()
video_recorder.capture_frame()
action = act(np.array(obs)[None], stochastic=stochastic)[0]
obs, rew, done, info = env.step(action)
if done:
obs = env.reset()
if len(info["rewards"]) > num_episodes:
if len(info["rewards"]) == 1 and video_recorder.enabled:
# save video of first episode
print("Saved video.")
video_recorder.close()
video_recorder.enabled = False
print(info["rewards"][-1])
num_episodes = len(info["rewards"])
if __name__ == '__main__':
with U.make_session(4) as sess:
args = parse_args()
env = make_env(args.env)
act = deepq.build_act(
make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
q_func=dueling_model if args.dueling else model,
num_actions=env.action_space.n)
U.load_state(os.path.join(args.model_dir, "saved"))
play(env, act, args.stochastic, args.video)

View File

@@ -0,0 +1,43 @@
import tensorflow as tf
import tensorflow.contrib.layers as layers
def model(img_in, num_actions, scope, reuse=False):
"""As described in https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf"""
with tf.variable_scope(scope, reuse=reuse):
out = img_in
with tf.variable_scope("convnet"):
# original architecture
out = layers.convolution2d(out, num_outputs=32, kernel_size=8, stride=4, activation_fn=tf.nn.relu)
out = layers.convolution2d(out, num_outputs=64, kernel_size=4, stride=2, activation_fn=tf.nn.relu)
out = layers.convolution2d(out, num_outputs=64, kernel_size=3, stride=1, activation_fn=tf.nn.relu)
out = layers.flatten(out)
with tf.variable_scope("action_value"):
out = layers.fully_connected(out, num_outputs=512, activation_fn=tf.nn.relu)
out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
return out
def dueling_model(img_in, num_actions, scope, reuse=False):
"""As described in https://arxiv.org/abs/1511.06581"""
with tf.variable_scope(scope, reuse=reuse):
out = img_in
with tf.variable_scope("convnet"):
# original architecture
out = layers.convolution2d(out, num_outputs=32, kernel_size=8, stride=4, activation_fn=tf.nn.relu)
out = layers.convolution2d(out, num_outputs=64, kernel_size=4, stride=2, activation_fn=tf.nn.relu)
out = layers.convolution2d(out, num_outputs=64, kernel_size=3, stride=1, activation_fn=tf.nn.relu)
out = layers.flatten(out)
with tf.variable_scope("state_value"):
state_hidden = layers.fully_connected(out, num_outputs=512, activation_fn=tf.nn.relu)
state_score = layers.fully_connected(state_hidden, num_outputs=1, activation_fn=None)
with tf.variable_scope("action_value"):
actions_hidden = layers.fully_connected(out, num_outputs=512, activation_fn=tf.nn.relu)
action_scores = layers.fully_connected(actions_hidden, num_outputs=num_actions, activation_fn=None)
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores = action_scores - tf.expand_dims(action_scores_mean, 1)
return state_score + action_scores
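
The dueling head combines a scalar state value with mean-centered action advantages, Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), which is exactly what the last three lines above compute. A tiny NumPy sketch of that aggregation on made-up numbers:

```python
import numpy as np

state_score = np.array([[1.5]])              # V(s), shape (batch, 1)
action_scores = np.array([[2.0, 0.0, 1.0]])  # A(s, a), shape (batch, num_actions)

advantages = action_scores - action_scores.mean(axis=1, keepdims=True)  # [[1., -1., 0.]]
q_values = state_score + advantages                                     # [[2.5, 0.5, 1.5]]
print(q_values)
```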

View File

@@ -0,0 +1,229 @@
import argparse
import gym
import numpy as np
import os
import tensorflow as tf
import tempfile
import time
import baselines.common.tf_util as U
from baselines import logger
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from baselines.common.misc_util import (
boolean_flag,
pickle_load,
pretty_eta,
relatively_safe_pickle_dump,
set_global_seeds,
RunningAvg,
SimpleMonitor
)
from baselines.common.schedules import LinearSchedule, PiecewiseSchedule
# when updating this to non-deprecated ones, it is important to
# copy over LazyFrames
from baselines.common.atari_wrappers_deprecated import wrap_dqn
from baselines.common.azure_utils import Container
from .model import model, dueling_model
def parse_args():
parser = argparse.ArgumentParser("DQN experiments for Atari games")
# Environment
parser.add_argument("--env", type=str, default="Pong", help="name of the game")
parser.add_argument("--seed", type=int, default=42, help="which seed to use")
# Core DQN parameters
parser.add_argument("--replay-buffer-size", type=int, default=int(1e6), help="replay buffer size")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for Adam optimizer")
parser.add_argument("--num-steps", type=int, default=int(2e8), help="total number of steps to run the environment for")
parser.add_argument("--batch-size", type=int, default=32, help="number of transitions to optimize at the same time")
parser.add_argument("--learning-freq", type=int, default=4, help="number of iterations between every optimization step")
parser.add_argument("--target-update-freq", type=int, default=40000, help="number of iterations between every target network update")
# Bells and whistles
boolean_flag(parser, "double-q", default=True, help="whether or not to use double q learning")
boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
boolean_flag(parser, "prioritized", default=False, help="whether or not to use prioritized replay buffer")
parser.add_argument("--prioritized-alpha", type=float, default=0.6, help="alpha parameter for prioritized replay buffer")
parser.add_argument("--prioritized-beta0", type=float, default=0.4, help="initial value of beta parameters for prioritized replay")
parser.add_argument("--prioritized-eps", type=float, default=1e-6, help="eps parameter for prioritized replay buffer")
# Checkpointing
parser.add_argument("--save-dir", type=str, default=None, help="directory in which training state and model should be saved.")
parser.add_argument("--save-azure-container", type=str, default=None,
help="It present data will saved/loaded from Azure. Should be in format ACCOUNT_NAME:ACCOUNT_KEY:CONTAINER")
parser.add_argument("--save-freq", type=int, default=1e6, help="save model once every time this many iterations are completed")
boolean_flag(parser, "load-on-start", default=True, help="if true and model was previously saved then training will be resumed")
return parser.parse_args()
def make_env(game_name):
env = gym.make(game_name + "NoFrameskip-v3")
monitored_env = SimpleMonitor(env) # puts rewards and number of steps in info, before environment is wrapped
env = wrap_dqn(monitored_env)  # applies a bunch of modifications to simplify the observation space (downsample, make b/w)
return env, monitored_env
def maybe_save_model(savedir, container, state):
"""This function checkpoints the model and state of the training algorithm."""
if savedir is None:
return
start_time = time.time()
model_dir = "model-{}".format(state["num_iters"])
U.save_state(os.path.join(savedir, model_dir, "saved"))
if container is not None:
container.put(os.path.join(savedir, model_dir), model_dir)
relatively_safe_pickle_dump(state, os.path.join(savedir, 'training_state.pkl.zip'), compression=True)
if container is not None:
container.put(os.path.join(savedir, 'training_state.pkl.zip'), 'training_state.pkl.zip')
relatively_safe_pickle_dump(state["monitor_state"], os.path.join(savedir, 'monitor_state.pkl'))
if container is not None:
container.put(os.path.join(savedir, 'monitor_state.pkl'), 'monitor_state.pkl')
logger.log("Saved model in {} seconds\n".format(time.time() - start_time))
def maybe_load_model(savedir, container):
"""Load model if present at the specified path."""
if savedir is None:
return
state_path = os.path.join(savedir, 'training_state.pkl.zip')
if container is not None:
logger.log("Attempting to download model from Azure")
found_model = container.get(savedir, 'training_state.pkl.zip')
else:
found_model = os.path.exists(state_path)
if found_model:
state = pickle_load(state_path, compression=True)
model_dir = "model-{}".format(state["num_iters"])
if container is not None:
container.get(savedir, model_dir)
U.load_state(os.path.join(savedir, model_dir, "saved"))
logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
return state
if __name__ == '__main__':
args = parse_args()
# Parse savedir and azure container.
savedir = args.save_dir
if args.save_azure_container is not None:
account_name, account_key, container_name = args.save_azure_container.split(":")
container = Container(account_name=account_name,
account_key=account_key,
container_name=container_name,
maybe_create=True)
if savedir is None:
# Careful! This will not get cleaned up. Docker spoils the developers.
savedir = tempfile.TemporaryDirectory().name
else:
container = None
# Create and seed the env.
env, monitored_env = make_env(args.env)
if args.seed > 0:
set_global_seeds(args.seed)
env.unwrapped.seed(args.seed)
with U.make_session(4) as sess:
# Create training graph and replay buffer
act, train, update_target, debug = deepq.build_train(
make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
q_func=dueling_model if args.dueling else model,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=args.lr, epsilon=1e-4),
gamma=0.99,
grad_norm_clipping=10,
double_q=args.double_q
)
approximate_num_iters = args.num_steps / 4
exploration = PiecewiseSchedule([
(0, 1.0),
(approximate_num_iters / 50, 0.1),
(approximate_num_iters / 5, 0.01)
], outside_value=0.01)
if args.prioritized:
replay_buffer = PrioritizedReplayBuffer(args.replay_buffer_size, args.prioritized_alpha)
beta_schedule = LinearSchedule(approximate_num_iters, initial_p=args.prioritized_beta0, final_p=1.0)
else:
replay_buffer = ReplayBuffer(args.replay_buffer_size)
U.initialize()
update_target()
num_iters = 0
# Load the model
state = maybe_load_model(savedir, container)
if state is not None:
num_iters, replay_buffer = state["num_iters"], state["replay_buffer"]
monitored_env.set_state(state["monitor_state"])
start_time, start_steps = None, None
steps_per_iter = RunningAvg(0.999)
iteration_time_est = RunningAvg(0.999)
obs = env.reset()
# Main training loop
while True:
num_iters += 1
# Take action and store transition in the replay buffer.
action = act(np.array(obs)[None], update_eps=exploration.value(num_iters))[0]
new_obs, rew, done, info = env.step(action)
replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
if done:
obs = env.reset()
if (num_iters > max(5 * args.batch_size, args.replay_buffer_size // 20) and
num_iters % args.learning_freq == 0):
# Sample a bunch of transitions from replay buffer
if args.prioritized:
experience = replay_buffer.sample(args.batch_size, beta=beta_schedule.value(num_iters))
(obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
else:
obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(args.batch_size)
weights = np.ones_like(rewards)
# Minimize the error in Bellman's equation and compute TD-error
td_errors = train(obses_t, actions, rewards, obses_tp1, dones, weights)
# Update the priorities in the replay buffer
if args.prioritized:
new_priorities = np.abs(td_errors) + args.prioritized_eps
replay_buffer.update_priorities(batch_idxes, new_priorities)
# Update target network.
if num_iters % args.target_update_freq == 0:
update_target()
if start_time is not None:
steps_per_iter.update(info['steps'] - start_steps)
iteration_time_est.update(time.time() - start_time)
start_time, start_steps = time.time(), info["steps"]
# Save the model and training state.
if num_iters > 0 and (num_iters % args.save_freq == 0 or info["steps"] > args.num_steps):
maybe_save_model(savedir, container, {
'replay_buffer': replay_buffer,
'num_iters': num_iters,
'monitor_state': monitored_env.get_state()
})
if info["steps"] > args.num_steps:
break
if done:
steps_left = args.num_steps - info["steps"]
completion = np.round(info["steps"] / args.num_steps, 1)
logger.record_tabular("% completion", completion)
logger.record_tabular("steps", info["steps"])
logger.record_tabular("iters", num_iters)
logger.record_tabular("episodes", len(info["rewards"]))
logger.record_tabular("reward (100 epi mean)", np.mean(info["rewards"][-100:]))
logger.record_tabular("exploration", exploration.value(num_iters))
if args.prioritized:
logger.record_tabular("max priority", replay_buffer._max_priority)
fps_estimate = (float(steps_per_iter) / (float(iteration_time_est) + 1e-6)
if steps_per_iter._value is not None else "calculating...")
logger.dump_tabular()
logger.log()
logger.log("ETA: " + pretty_eta(int(steps_left / fps_estimate)))
logger.log()
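
The exploration schedule defined above anneals epsilon from 1.0 down to 0.1 over the first 2% of the (approximate) iterations, then to 0.01 by 20%, and stays at 0.01 afterwards. A short sketch of how the schedule is queried, assuming `PiecewiseSchedule` interpolates linearly between the listed endpoints (its default behavior):

```python
from baselines.common.schedules import PiecewiseSchedule

approximate_num_iters = int(2e8) / 4   # num_steps / 4, as in the script above
exploration = PiecewiseSchedule([
    (0, 1.0),
    (approximate_num_iters / 50, 0.1),
    (approximate_num_iters / 5, 0.01)
], outside_value=0.01)

print(exploration.value(0))                           # 1.0  -- fully random at the start
print(exploration.value(approximate_num_iters / 50))  # 0.1
print(exploration.value(approximate_num_iters))       # 0.01 -- outside_value past the last endpoint
```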

View File

@@ -0,0 +1,81 @@
import argparse
import gym
import numpy as np
import os
import baselines.common.tf_util as U
from baselines import deepq
from baselines.common.misc_util import get_wrapper_by_name, SimpleMonitor, boolean_flag, set_global_seeds
from baselines.common.atari_wrappers_deprecated import wrap_dqn
from baselines.deepq.experiments.atari.model import model, dueling_model
def make_env(game_name):
env = gym.make(game_name + "NoFrameskip-v3")
env_monitored = SimpleMonitor(env)
env = wrap_dqn(env_monitored)
return env_monitored, env
def parse_args():
parser = argparse.ArgumentParser("Evaluate an already learned DQN model.")
# Environment
parser.add_argument("--env", type=str, required=True, help="name of the game")
parser.add_argument("--model-dir", type=str, default=None, help="load model from this directory. ")
boolean_flag(parser, "stochastic", default=True, help="whether or not to use stochastic actions according to models eps value")
boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
return parser.parse_args()
def wang2015_eval(game_name, act, stochastic):
print("==================== wang2015 evaluation ====================")
episode_rewards = []
for num_noops in range(1, 31):
env_monitored, eval_env = make_env(game_name)
eval_env.unwrapped.seed(1)
get_wrapper_by_name(eval_env, "NoopResetEnv").override_num_noops = num_noops
eval_episode_steps = 0
done = True
while True:
if done:
obs = eval_env.reset()
eval_episode_steps += 1
action = act(np.array(obs)[None], stochastic=stochastic)[0]
obs, reward, done, info = eval_env.step(action)
if done:
obs = eval_env.reset()
if len(info["rewards"]) > 0:
episode_rewards.append(info["rewards"][0])
break
if info["steps"] > 108000: # 5 minutes of gameplay
episode_rewards.append(env_monitored._current_reward)
break
print("Num steps in episode {} was {} yielding {} reward".format(
num_noops, eval_episode_steps, episode_rewards[-1]), flush=True)
print("Evaluation results: " + str(np.mean(episode_rewards)))
print("=============================================================")
return np.mean(episode_rewards)
def main():
set_global_seeds(1)
args = parse_args()
with U.make_session(4) as sess: # noqa
_, env = make_env(args.env)
act = deepq.build_act(
make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
q_func=dueling_model if args.dueling else model,
num_actions=env.action_space.n)
U.load_state(os.path.join(args.model_dir, "saved"))
wang2015_eval(args.env, act, stochastic=args.stochastic)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,78 @@
import gym
import itertools
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
import baselines.common.tf_util as U
from baselines import logger
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer
from baselines.common.schedules import LinearSchedule
def model(inpt, num_actions, scope, reuse=False):
"""This model takes as input an observation and returns values of all actions."""
with tf.variable_scope(scope, reuse=reuse):
out = inpt
out = layers.fully_connected(out, num_outputs=64, activation_fn=tf.nn.tanh)
out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
return out
if __name__ == '__main__':
with U.make_session(8):
# Create the environment
env = gym.make("CartPole-v0")
# Create all the functions necessary to train the model
act, train, update_target, debug = deepq.build_train(
make_obs_ph=lambda name: U.BatchInput(env.observation_space.shape, name=name),
q_func=model,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=5e-4),
)
# Create the replay buffer
replay_buffer = ReplayBuffer(50000)
# Create the schedule for exploration starting from 1 (every action is random) down to
# 0.02 (98% of actions are selected according to values predicted by the model).
exploration = LinearSchedule(schedule_timesteps=10000, initial_p=1.0, final_p=0.02)
# Initialize the parameters and copy them to the target network.
U.initialize()
update_target()
episode_rewards = [0.0]
obs = env.reset()
for t in itertools.count():
# Take action and update exploration to the newest value
action = act(obs[None], update_eps=exploration.value(t))[0]
new_obs, rew, done, _ = env.step(action)
# Store transition in the replay buffer.
replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
episode_rewards[-1] += rew
if done:
obs = env.reset()
episode_rewards.append(0)
is_solved = t > 100 and np.mean(episode_rewards[-101:-1]) >= 200
if is_solved:
# Show off the result
env.render()
else:
# Minimize the error in Bellman's equation on a batch sampled from replay buffer.
if t > 1000:
obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(32)
train(obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards))
# Update target network periodically.
if t % 1000 == 0:
update_target()
if done and len(episode_rewards) % 10 == 0:
logger.record_tabular("steps", t)
logger.record_tabular("episodes", len(episode_rewards))
logger.record_tabular("mean episode reward", round(np.mean(episode_rewards[-101:-1]), 1))
logger.record_tabular("% time spent exploring", int(100 * exploration.value(t)))
logger.dump_tabular()

View File

@@ -0,0 +1,21 @@
import gym
from baselines import deepq
def main():
env = gym.make("CartPole-v0")
act = deepq.load("cartpole_model.pkl")
while True:
obs, done = env.reset(), False
episode_rew = 0
while not done:
env.render()
obs, rew, done, _ = env.step(act(obs[None])[0])
episode_rew += rew
print("Episode reward", episode_rew)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,23 @@
import gym
from baselines import deepq
from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFrame
def main():
env = gym.make("PongNoFrameskip-v3")
env = ScaledFloatFrame(wrap_dqn(env))
act = deepq.load("pong_model.pkl")
while True:
obs, done = env.reset(), False
episode_rew = 0
while not done:
env.render()
obs, rew, done, _ = env.step(act(obs[None])[0])
episode_rew += rew
print("Episode reward", episode_rew)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,31 @@
import gym
from baselines import deepq
def callback(lcl, glb):
# stop training if reward exceeds 199
is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
return is_solved
def main():
env = gym.make("CartPole-v0")
model = deepq.models.mlp([64])
act = deepq.learn(
env,
q_func=model,
lr=1e-3,
max_timesteps=100000,
buffer_size=50000,
exploration_fraction=0.1,
exploration_final_eps=0.02,
print_freq=10,
callback=callback
)
print("Saving model to cartpole_model.pkl")
act.save("cartpole_model.pkl")
if __name__ == '__main__':
main()
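
The callback receives the local and global namespaces of `deepq.learn` at every step and stops training by returning True; `lcl['t']` and `lcl['episode_rewards']` are the names used inside `learn`. A sketch of an alternative callback that additionally caps wall-clock time (the 15-minute limit is arbitrary):

```python
import time

_start = time.time()

def time_limited_callback(lcl, glb):
    # stop once CartPole is solved or after 15 minutes of training
    solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
    return solved or (time.time() - _start) > 15 * 60
```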

View File

@@ -0,0 +1,34 @@
import gym
from baselines import deepq
from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFrame
def main():
env = gym.make("PongNoFrameskip-v3")
env = ScaledFloatFrame(wrap_dqn(env))
model = deepq.models.cnn_to_mlp(
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256],
dueling=True
)
act = deepq.learn(
env,
q_func=model,
lr=1e-4,
max_timesteps=2000000,
buffer_size=10000,
exploration_fraction=0.1,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
prioritized_replay=True
)
act.save("pong_model.pkl")
env.close()
if __name__ == '__main__':
main()

82
baselines/deepq/models.py Normal file
View File

@@ -0,0 +1,82 @@
import tensorflow as tf
import tensorflow.contrib.layers as layers
def _mlp(hiddens, inpt, num_actions, scope, reuse=False):
with tf.variable_scope(scope, reuse=reuse):
out = inpt
for hidden in hiddens:
out = layers.fully_connected(out, num_outputs=hidden, activation_fn=tf.nn.relu)
out = layers.fully_connected(out, num_outputs=num_actions, activation_fn=None)
return out
def mlp(hiddens=[]):
"""This model takes as input an observation and returns values of all actions.
Parameters
----------
hiddens: [int]
list of sizes of hidden layers
Returns
-------
q_func: function
q_function for DQN algorithm.
"""
return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs)
def _cnn_to_mlp(convs, hiddens, dueling, inpt, num_actions, scope, reuse=False):
with tf.variable_scope(scope, reuse=reuse):
out = inpt
with tf.variable_scope("convnet"):
for num_outputs, kernel_size, stride in convs:
out = layers.convolution2d(out,
num_outputs=num_outputs,
kernel_size=kernel_size,
stride=stride,
activation_fn=tf.nn.relu)
out = layers.flatten(out)
with tf.variable_scope("action_value"):
action_out = out
for hidden in hiddens:
action_out = layers.fully_connected(action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
action_scores = layers.fully_connected(action_out, num_outputs=num_actions, activation_fn=None)
if dueling:
with tf.variable_scope("state_value"):
state_out = out
for hidden in hiddens:
state_out = layers.fully_connected(state_out, num_outputs=hidden, activation_fn=tf.nn.relu)
state_score = layers.fully_connected(state_out, num_outputs=1, activation_fn=None)
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, 1)
return state_score + action_scores_centered
else:
return action_scores
def cnn_to_mlp(convs, hiddens, dueling=False):
"""This model takes as input an observation and returns values of all actions.
Parameters
----------
convs: [(int, int, int)]
list of convolutional layers in form of
(num_outputs, kernel_size, stride)
hiddens: [int]
list of sizes of hidden layers
dueling: bool
if true double the output MLP to compute a baseline
for action scores
Returns
-------
q_func: function
q_function for DQN algorithm.
"""
return lambda *args, **kwargs: _cnn_to_mlp(convs, hiddens, dueling, *args, **kwargs)
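
`mlp` and `cnn_to_mlp` return closures with the `(inpt, num_actions, scope, reuse)` signature expected by `build_act`/`build_train`, so the architecture is fixed once and the graph is only built when the closure is called. A small sketch, assuming the usual 84x84x4 stack of Atari frames and 6 actions (Pong-sized values chosen purely for illustration):

```python
import tensorflow as tf
from baselines.deepq import models

# same architecture as the Pong example: three conv layers, one hidden layer, dueling head
q_func = models.cnn_to_mlp(
    convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
    hiddens=[256],
    dueling=True,
)

obs_ph = tf.placeholder(tf.float32, [None, 84, 84, 4], name="observation")
q_values = q_func(obs_ph, num_actions=6, scope="q_func")  # tensor of shape (batch, 6)
```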

View File

@@ -0,0 +1,190 @@
import numpy as np
import random
from baselines.common.segment_tree import SumSegmentTree, MinSegmentTree
class ReplayBuffer(object):
def __init__(self, size):
"""Create Prioritized Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
"""
self._storage = []
self._maxsize = size
self._next_idx = 0
def __len__(self):
return len(self._storage)
def add(self, obs_t, action, reward, obs_tp1, done):
data = (obs_t, action, reward, obs_tp1, done)
if self._next_idx >= len(self._storage):
self._storage.append(data)
else:
self._storage[self._next_idx] = data
self._next_idx = (self._next_idx + 1) % self._maxsize
def _encode_sample(self, idxes):
obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
for i in idxes:
data = self._storage[i]
obs_t, action, reward, obs_tp1, done = data
obses_t.append(np.array(obs_t, copy=False))
actions.append(np.array(action, copy=False))
rewards.append(reward)
obses_tp1.append(np.array(obs_tp1, copy=False))
dones.append(done)
return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)
def sample(self, batch_size):
"""Sample a batch of experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
"""
idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
return self._encode_sample(idxes)
class PrioritizedReplayBuffer(ReplayBuffer):
def __init__(self, size, alpha):
"""Create Prioritized Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
alpha: float
how much prioritization is used
(0 - no prioritization, 1 - full prioritization)
See Also
--------
ReplayBuffer.__init__
"""
super(PrioritizedReplayBuffer, self).__init__(size)
assert alpha > 0
self._alpha = alpha
it_capacity = 1
while it_capacity < size:
it_capacity *= 2
self._it_sum = SumSegmentTree(it_capacity)
self._it_min = MinSegmentTree(it_capacity)
self._max_priority = 1.0
def add(self, *args, **kwargs):
"""See ReplayBuffer.store_effect"""
idx = self._next_idx
super().add(*args, **kwargs)
self._it_sum[idx] = self._max_priority ** self._alpha
self._it_min[idx] = self._max_priority ** self._alpha
def _sample_proportional(self, batch_size):
res = []
for _ in range(batch_size):
# TODO(szymon): should we ensure no repeats?
mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
idx = self._it_sum.find_prefixsum_idx(mass)
res.append(idx)
return res
def sample(self, batch_size, beta):
"""Sample a batch of experiences.
compared to ReplayBuffer.sample
it also returns importance weights and idxes
of sampled experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
beta: float
To what degree to use importance weights
(0 - no corrections, 1 - full correction)
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
weights: np.array
Array of shape (batch_size,) and dtype np.float32
denoting importance weight of each sampled transition
idxes: np.array
Array of shape (batch_size,) and dtype np.int32
indexes in buffer of sampled experiences
"""
assert beta > 0
idxes = self._sample_proportional(batch_size)
weights = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage)) ** (-beta)
for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage)) ** (-beta)
weights.append(weight / max_weight)
weights = np.array(weights)
encoded_sample = self._encode_sample(idxes)
return tuple(list(encoded_sample) + [weights, idxes])
def update_priorities(self, idxes, priorities):
"""Update priorities of sampled transitions.
sets priority of transition at index idxes[i] in buffer
to priorities[i].
Parameters
----------
idxes: [int]
List of idxes of sampled transitions
priorities: [float]
List of updated priorities corresponding to
transitions at the sampled idxes denoted by
variable `idxes`.
"""
assert len(idxes) == len(priorities)
for idx, priority in zip(idxes, priorities):
assert priority > 0
assert 0 <= idx < len(self._storage)
self._it_sum[idx] = priority ** self._alpha
self._it_min[idx] = priority ** self._alpha
self._max_priority = max(self._max_priority, priority)
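
A compact sketch of the prioritized replay cycle implemented above: store transitions, sample a batch together with importance weights and indexes, then feed the absolute TD errors back as new priorities (the toy transitions and random TD errors are placeholders for real training data):

```python
import numpy as np
from baselines.deepq.replay_buffer import PrioritizedReplayBuffer

buffer = PrioritizedReplayBuffer(size=1000, alpha=0.6)

# store a handful of toy (obs, action, reward, next_obs, done) transitions
for i in range(100):
    buffer.add(np.array([float(i)]), 0, 1.0, np.array([float(i + 1)]), False)

obses_t, actions, rewards, obses_tp1, dones, weights, idxes = buffer.sample(32, beta=0.4)

# after a training step, priorities become the new |TD errors| (plus a small epsilon)
td_errors = np.random.randn(32)
buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)
```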

267
baselines/deepq/simple.py Normal file
View File

@@ -0,0 +1,267 @@
import numpy as np
import os
import dill
import tempfile
import tensorflow as tf
import zipfile
import baselines.common.tf_util as U
from baselines import logger
from baselines.common.schedules import LinearSchedule
from baselines import deepq
from baselines.deepq.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
class ActWrapper(object):
def __init__(self, act, act_params):
self._act = act
self._act_params = act_params
@staticmethod
def load(path, num_cpu=16):
with open(path, "rb") as f:
model_data, act_params = dill.load(f)
act = deepq.build_act(**act_params)
sess = U.make_session(num_cpu=num_cpu)
sess.__enter__()
with tempfile.TemporaryDirectory() as td:
arc_path = os.path.join(td, "packed.zip")
with open(arc_path, "wb") as f:
f.write(model_data)
zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
U.load_state(os.path.join(td, "model"))
return ActWrapper(act, act_params)
def __call__(self, *args, **kwargs):
return self._act(*args, **kwargs)
def save(self, path):
"""Save model to a pickle located at `path`"""
with tempfile.TemporaryDirectory() as td:
U.save_state(os.path.join(td, "model"))
arc_name = os.path.join(td, "packed.zip")
with zipfile.ZipFile(arc_name, 'w') as zipf:
for root, dirs, files in os.walk(td):
for fname in files:
file_path = os.path.join(root, fname)
if file_path != arc_name:
zipf.write(file_path, os.path.relpath(file_path, td))
with open(arc_name, "rb") as f:
model_data = f.read()
with open(path, "wb") as f:
dill.dump((model_data, self._act_params), f)
def load(path, num_cpu=16):
"""Load act function that was returned by learn function.
Parameters
----------
path: str
path to the act function pickle
num_cpu: int
number of cpus to use for executing the policy
Returns
-------
act: ActWrapper
function that takes a batch of observations
and returns actions.
"""
return ActWrapper.load(path, num_cpu=num_cpu)
def learn(env,
q_func,
lr=5e-4,
max_timesteps=100000,
buffer_size=50000,
exploration_fraction=0.1,
exploration_final_eps=0.02,
train_freq=1,
batch_size=32,
print_freq=1,
checkpoint_freq=10000,
learning_starts=1000,
gamma=1.0,
target_network_update_freq=500,
prioritized_replay=False,
prioritized_replay_alpha=0.6,
prioritized_replay_beta0=0.4,
prioritized_replay_beta_iters=None,
prioritized_replay_eps=1e-6,
num_cpu=16,
callback=None):
"""Train a deepq model.
Parameters
-------
env : gym.Env
environment to train on
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of every action.
lr: float
learning rate for adam optimizer
max_timesteps: int
number of env steps to optimize for
buffer_size: int
size of the replay buffer
exploration_fraction: float
fraction of entire training period over which the exploration rate is annealed
exploration_final_eps: float
final value of random action probability
train_freq: int
update the model every `train_freq` steps.
batch_size: int
size of a batch sampled from the replay buffer for training
print_freq: int
how often to print out training progress.
set to None to disable printing.
checkpoint_freq: int
how often to save the model. This is so that the best version is restored
at the end of the training. If you do not wish to restore the best version at
the end of the training set this variable to None.
learning_starts: int
how many steps of the model to collect transitions for before learning starts
gamma: float
discount factor
target_network_update_freq: int
update the target network every `target_network_update_freq` steps.
prioritized_replay: bool
if True prioritized replay buffer will be used.
prioritized_replay_alpha: float
alpha parameter for prioritized replay buffer
prioritized_replay_beta0: float
initial value of beta for prioritized replay buffer
prioritized_replay_beta_iters: int
number of iterations over which beta will be annealed from initial value
to 1.0. If set to None, defaults to max_timesteps.
prioritized_replay_eps: float
epsilon to add to the TD errors when updating priorities.
num_cpu: int
number of cpus to use for training
callback: (locals, globals) -> None
function called at every step with the state of the algorithm.
If callback returns true training stops.
Returns
-------
act: ActWrapper
Wrapper over act function. Adds ability to save it and load it.
See header of baselines/deepq/build_graph.py for details on the act function.
"""
# Create all the functions necessary to train the model
sess = U.make_session(num_cpu=num_cpu)
sess.__enter__()
def make_obs_ph(name):
return U.BatchInput(env.observation_space.shape, name=name)
act, train, update_target, debug = deepq.build_train(
make_obs_ph=make_obs_ph,
q_func=q_func,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=lr),
gamma=gamma,
grad_norm_clipping=10
)
act_params = {
'make_obs_ph': make_obs_ph,
'q_func': q_func,
'num_actions': env.action_space.n,
}
# Create the replay buffer
if prioritized_replay:
replay_buffer = PrioritizedReplayBuffer(buffer_size, alpha=prioritized_replay_alpha)
if prioritized_replay_beta_iters is None:
prioritized_replay_beta_iters = max_timesteps
beta_schedule = LinearSchedule(prioritized_replay_beta_iters,
initial_p=prioritized_replay_beta0,
final_p=1.0)
else:
replay_buffer = ReplayBuffer(buffer_size)
beta_schedule = None
# Create the schedule for exploration starting from 1.
exploration = LinearSchedule(schedule_timesteps=int(exploration_fraction * max_timesteps),
initial_p=1.0,
final_p=exploration_final_eps)
# Initialize the parameters and copy them to the target network.
U.initialize()
update_target()
episode_rewards = [0.0]
saved_mean_reward = None
obs = env.reset()
with tempfile.TemporaryDirectory() as td:
model_saved = False
model_file = os.path.join(td, "model")
for t in range(max_timesteps):
if callback is not None:
if callback(locals(), globals()):
break
# Take action and update exploration to the newest value
action = act(np.array(obs)[None], update_eps=exploration.value(t))[0]
new_obs, rew, done, _ = env.step(action)
# Store transition in the replay buffer.
replay_buffer.add(obs, action, rew, new_obs, float(done))
obs = new_obs
episode_rewards[-1] += rew
if done:
obs = env.reset()
episode_rewards.append(0)
if t > learning_starts and t % train_freq == 0:
# Minimize the error in Bellman's equation on a batch sampled from replay buffer.
if prioritized_replay:
experience = replay_buffer.sample(batch_size, beta=beta_schedule.value(t))
(obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience
else:
obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
weights, batch_idxes = np.ones_like(rewards), None
td_errors = train(obses_t, actions, rewards, obses_tp1, dones, weights)
if prioritized_replay:
new_priorities = np.abs(td_errors) + prioritized_replay_eps
replay_buffer.update_priorities(batch_idxes, new_priorities)
if t > learning_starts and t % target_network_update_freq == 0:
# Update target network periodically.
update_target()
mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 1)
num_episodes = len(episode_rewards)
if done and print_freq is not None and len(episode_rewards) % print_freq == 0:
logger.record_tabular("steps", t)
logger.record_tabular("episodes", num_episodes)
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
logger.record_tabular("% time spent exploring", int(100 * exploration.value(t)))
logger.dump_tabular()
if (checkpoint_freq is not None and t > learning_starts and
num_episodes > 100 and t % checkpoint_freq == 0):
if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward:
if print_freq is not None:
logger.log("Saving model due to mean reward increase: {} -> {}".format(
saved_mean_reward, mean_100ep_reward))
U.save_state(model_file)
model_saved = True
saved_mean_reward = mean_100ep_reward
if model_saved:
if print_freq is not None:
logger.log("Restored model with mean reward: {}".format(saved_mean_reward))
U.load_state(model_file)
return ActWrapper(act, act_params)
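
`ActWrapper.save` writes a dill pickle containing two things: the bytes of a zipped TensorFlow checkpoint and the `act_params` dict needed to rebuild the act function. A small sketch that peeks inside such a file (the path is a placeholder for a model previously saved with `act.save(...)`):

```python
import dill

with open("/tmp/cartpole_model.pkl", "rb") as f:   # placeholder path
    model_data, act_params = dill.load(f)

print(sorted(act_params.keys()))                   # ['make_obs_ph', 'num_actions', 'q_func']
print(len(model_data), "bytes of zipped TensorFlow checkpoint")
```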

292
baselines/logger.py Normal file
View File

@@ -0,0 +1,292 @@
"""
See README.md for a description of the logging API.
OFF state corresponds to having Logger.CURRENT == Logger.DEFAULT
ON state is otherwise
"""
from collections import OrderedDict
import os
import sys
import shutil
import os.path as osp
import json
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json']
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50
class OutputFormat(object):
def writekvs(self, kvs):
"""
Write key-value pairs
"""
raise NotImplementedError
def writeseq(self, args):
"""
Write a sequence of other data (e.g. a logging message)
"""
pass
def close(self):
return
class HumanOutputFormat(OutputFormat):
def __init__(self, file):
self.file = file
def writekvs(self, kvs):
# Create strings for printing
key2str = OrderedDict()
for (key, val) in kvs.items():
valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else val
key2str[self._truncate(key)] = self._truncate(valstr)
# Find max widths
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Write out the data
dashes = '-' * (keywidth + valwidth + 7)
lines = [dashes]
for (key, val) in key2str.items():
lines.append('| %s%s | %s%s |' % (
key,
' ' * (keywidth - len(key)),
val,
' ' * (valwidth - len(val)),
))
lines.append(dashes)
self.file.write('\n'.join(lines) + '\n')
# Flush the output to the file
self.file.flush()
def _truncate(self, s):
return s[:20] + '...' if len(s) > 23 else s
def writeseq(self, args):
for arg in args:
self.file.write(arg)
self.file.write('\n')
self.file.flush()
class JSONOutputFormat(OutputFormat):
def __init__(self, file):
self.file = file
def writekvs(self, kvs):
for k, v in kvs.items():
if hasattr(v, 'dtype'):
v = v.tolist()
kvs[k] = float(v)
self.file.write(json.dumps(kvs) + '\n')
self.file.flush()
def make_output_format(format, ev_dir):
os.makedirs(ev_dir, exist_ok=True)
if format == 'stdout':
return HumanOutputFormat(sys.stdout)
elif format == 'log':
log_file = open(osp.join(ev_dir, 'log.txt'), 'wt')
return HumanOutputFormat(log_file)
elif format == 'json':
json_file = open(osp.join(ev_dir, 'progress.json'), 'wt')
return JSONOutputFormat(json_file)
else:
raise ValueError('Unknown format specified: %s' % (format,))
# ================================================================
# API
# ================================================================
def logkv(key, val):
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
"""
Logger.CURRENT.logkv(key, val)
def dumpkvs():
"""
Write all of the diagnostics from the current iteration
level: int. (see logger.py docs) If the global logger level is higher than
the level argument here, don't print to stdout.
"""
Logger.CURRENT.dumpkvs()
# for backwards compatibility
record_tabular = logkv
dump_tabular = dumpkvs
def log(*args, level=INFO):
"""
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
"""
Logger.CURRENT.log(*args, level=level)
def debug(*args):
log(*args, level=DEBUG)
def info(*args):
log(*args, level=INFO)
def warn(*args):
log(*args, level=WARN)
def error(*args):
log(*args, level=ERROR)
def set_level(level):
"""
Set logging threshold on current logger.
"""
Logger.CURRENT.set_level(level)
def get_dir():
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you are not inside a logger.session)
"""
return Logger.CURRENT.get_dir()
def get_expt_dir():
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),))
return get_dir()
# ================================================================
# Backend
# ================================================================
class Logger(object):
DEFAULT = None # A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output files
CURRENT = None # Current logger being used by the free functions above
def __init__(self, dir, output_formats):
self.name2val = OrderedDict() # values this iteration
self.level = INFO
self.dir = dir
self.output_formats = output_formats
# Logging API, forwarded
# ----------------------------------------
def logkv(self, key, val):
self.name2val[key] = val
def dumpkvs(self):
for fmt in self.output_formats:
fmt.writekvs(self.name2val)
self.name2val.clear()
def log(self, *args, level=INFO):
if self.level <= level:
self._do_log(args)
# Configuration
# ----------------------------------------
def set_level(self, level):
self.level = level
def get_dir(self):
return self.dir
def close(self):
for fmt in self.output_formats:
fmt.close()
# Misc
# ----------------------------------------
def _do_log(self, args):
for fmt in self.output_formats:
fmt.writeseq(args)
# ================================================================
Logger.DEFAULT = Logger(output_formats=[HumanOutputFormat(sys.stdout)], dir=None)
Logger.CURRENT = Logger.DEFAULT
class session(object):
"""
Context manager that sets up the loggers for an experiment.
"""
CURRENT = None # Set to a LoggerContext object using enter/exit or context manager
def __init__(self, dir, format_strs=None):
self.dir = dir
if format_strs is None:
format_strs = LOG_OUTPUT_FORMATS
output_formats = [make_output_format(f, dir) for f in format_strs]
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
def __enter__(self):
os.makedirs(self.evaluation_dir(), exist_ok=True)
output_formats = [make_output_format(f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS]
Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats)
def __exit__(self, *args):
Logger.CURRENT.close()
Logger.CURRENT = Logger.DEFAULT
def evaluation_dir(self):
return self.dir
# ================================================================
def _demo():
info("hi")
debug("shouldn't appear")
set_level(DEBUG)
debug("should appear")
dir = "/tmp/testlogging"
if os.path.exists(dir):
shutil.rmtree(dir)
with session(dir=dir):
record_tabular("a", 3)
record_tabular("b", 2.5)
dump_tabular()
record_tabular("b", -2.5)
record_tabular("a", 5.5)
dump_tabular()
info("^^^ should see a = 5.5")
record_tabular("b", -2.5)
dump_tabular()
record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
dump_tabular()
if __name__ == "__main__":
_demo()
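
Every `dump_tabular()` call inside a `session` appends one JSON object per line to `progress.json` (via `JSONOutputFormat`), so training curves can be read back with the standard library alone. A sketch assuming the `/tmp/testlogging` directory used by the demo above:

```python
import json

rows = []
with open("/tmp/testlogging/progress.json") as f:
    for line in f:
        rows.append(json.loads(line))   # one dict of diagnostics per dump_tabular() call

print(rows[0])                          # e.g. {"a": 3, "b": 2.5}
```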

BIN
data/cartpole.gif Normal file

Binary file not shown.

Size: 327 KiB

BIN
data/logo.jpg Normal file

Binary file not shown.

Size: 119 KiB

25
setup.py Normal file
View File

@@ -0,0 +1,25 @@
from setuptools import setup, find_packages
import os
repo_dir = os.path.dirname(os.path.abspath(__file__))
setup(name='baselines',
packages=[package for package in find_packages()
if package.startswith('baselines')],
install_requires=[
'scipy',
'tqdm',
'joblib',
'zmq',
'dill',
'tensorflow >= 1.0.0',
'azure==1.0.3',
'progressbar2',
],
description="OpenAI baselines: high quality implementations of reinforcement learning algorithms",
author="OpenAI",
url='https://github.com/openai/baselines',
author_email="gym@openai.com",
version="0.1.0")