ppo and trpo
This commit is contained in:
9
.gitignore
vendored
9
.gitignore
vendored
@@ -7,12 +7,11 @@
|
||||
# Setuptools distribution and build folders.
|
||||
/dist/
|
||||
/build
|
||||
keys/
|
||||
|
||||
# Virtualenv
|
||||
/env
|
||||
|
||||
# Python egg metadata, regenerated from source files by setuptools.
|
||||
/*.egg-info
|
||||
|
||||
*.sublime-project
|
||||
*.sublime-workspace
|
||||
@@ -26,4 +25,8 @@ ghostdriver.log
|
||||
|
||||
htmlcov
|
||||
|
||||
junk
|
||||
junk
|
||||
src
|
||||
|
||||
*.egg-info
|
||||
.cache
|
||||
|
60
README.md
60
README.md
@@ -1,8 +1,8 @@
|
||||
<img src="data/logo.jpg" width=25% align="right" />
|
||||
|
||||
# Baselines
|
||||
# BASELINES
|
||||
|
||||
We're releasing OpenAI Baselines, a set of high-quality implementations of reinforcement learning algorithms. To start, we're making available an open source version of Deep Q-Learning and three of its variants.
|
||||
We're releasing OpenAI Baselines, a set of high-quality implementations of reinforcement learning algorithms.
|
||||
|
||||
These algorithms will make it easier for the research community to replicate, refine, and identify new ideas, and will create good baselines to build research on top of. Our DQN implementation and its variants are roughly on par with the scores in published papers. We expect they will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones.
|
||||
|
||||
@@ -12,56 +12,6 @@ You can install it by typing:
|
||||
pip install baselines
|
||||
```
|
||||
|
||||
|
||||
## If you are curious.
|
||||
|
||||
##### Train a Cartpole agent and watch it play once it converges!
|
||||
|
||||
Here's a list of commands to run to quickly get a working example:
|
||||
|
||||
<img src="data/cartpole.gif" width="25%" />
|
||||
|
||||
|
||||
```bash
|
||||
# Train model and save the results to cartpole_model.pkl
|
||||
python -m baselines.deepq.experiments.train_cartpole
|
||||
# Load the model saved in cartpole_model.pkl and visualize the learned policy
|
||||
python -m baselines.deepq.experiments.enjoy_cartpole
|
||||
```
|
||||
|
||||
|
||||
Be sure to check out the source code of [both](baselines/deepq/experiments/train_cartpole.py) [files](baselines/deepq/experiments/enjoy_cartpole.py)!
|
||||
|
||||
## If you wish to apply DQN to solve a problem.
|
||||
|
||||
Check out our simple agent trained with one stop shop `deepq.learn` function.
|
||||
|
||||
- `baselines/deepq/experiments/train_cartpole.py` - train a Cartpole agent.
|
||||
- `baselines/deepq/experiments/train_pong.py` - train a Pong agent using convolutional neural networks.
|
||||
|
||||
In particular notice that once `deepq.learn` finishes training it returns `act` function which can be used to select actions in the environment. Once trained you can easily save it and load at later time. For both of the files listed above there are complimentary files `enjoy_cartpole.py` and `enjoy_pong.py` respectively, that load and visualize the learned policy.
|
||||
|
||||
## If you wish to experiment with the algorithm
|
||||
|
||||
##### Check out the examples
|
||||
|
||||
|
||||
- `baselines/deepq/experiments/custom_cartpole.py` - Cartpole training with more fine grained control over the internals of DQN algorithm.
|
||||
- `baselines/deepq/experiments/atari/train.py` - more robust setup for training at scale.
|
||||
|
||||
|
||||
##### Download a pretrained Atari agent
|
||||
|
||||
For some research projects it is sometimes useful to have an already trained agent handy. There's a variety of models to choose from. You can list them all by running:
|
||||
|
||||
```bash
|
||||
python -m baselines.deepq.experiments.atari.download_model
|
||||
```
|
||||
|
||||
Once you pick a model, you can download it and visualize the learned policy. Be sure to pass `--dueling` flag to visualization script when using dueling models.
|
||||
|
||||
```bash
|
||||
python -m baselines.deepq.experiments.atari.download_model --blob model-atari-duel-pong-1 --model-dir /tmp/models
|
||||
python -m baselines.deepq.experiments.atari.enjoy --model-dir /tmp/models/model-atari-duel-pong-1 --env Pong --dueling
|
||||
|
||||
```
|
||||
- [DQN](baselines/deepq)
|
||||
- [PPO](baselines/pposgd)
|
||||
- [TRPO](baselines/trpo_mpi)
|
||||
|
3
baselines/bench/__init__.py
Normal file
3
baselines/bench/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from baselines.bench.benchmarks import *
|
||||
from baselines.bench.monitor import *
|
||||
|
93
baselines/bench/benchmarks.py
Normal file
93
baselines/bench/benchmarks.py
Normal file
@@ -0,0 +1,93 @@
|
||||
_atari7 = ['BeamRider', 'Breakout', 'Enduro', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders']
|
||||
_atariexpl7 = ['Freeway', 'Gravitar', 'MontezumaRevenge', 'Pitfall', 'PrivateEye', 'Solaris', 'Venture']
|
||||
|
||||
_BENCHMARKS = []
|
||||
|
||||
def register_benchmark(benchmark):
|
||||
for b in _BENCHMARKS:
|
||||
if b['name'] == benchmark['name']:
|
||||
raise ValueError('Benchmark with name %s already registered!'%b['name'])
|
||||
_BENCHMARKS.append(benchmark)
|
||||
|
||||
def list_benchmarks():
|
||||
return [b['name'] for b in _BENCHMARKS]
|
||||
|
||||
def get_benchmark(benchmark_name):
|
||||
for b in _BENCHMARKS:
|
||||
if b['name'] == benchmark_name:
|
||||
return b
|
||||
raise ValueError('%s not found! Known benchmarks: %s' % (benchmark_name, list_benchmarks()))
|
||||
|
||||
def get_task(benchmark, env_id):
|
||||
"""Get a task by env_id. Return None if the benchmark doesn't have the env"""
|
||||
return next(filter(lambda task: task['env_id'] == env_id, benchmark['tasks']), None)
|
||||
|
||||
_ATARI_SUFFIX = 'NoFrameskip-v4'
|
||||
|
||||
register_benchmark({
|
||||
'name' : 'Atari200M',
|
||||
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 200M frames',
|
||||
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(200e6)} for _game in _atari7]
|
||||
})
|
||||
|
||||
register_benchmark({
|
||||
'name' : 'Atari40M',
|
||||
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 40M frames',
|
||||
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(40e6)} for _game in _atari7]
|
||||
})
|
||||
|
||||
register_benchmark({
|
||||
'name' : 'Atari1Hr',
|
||||
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 1 hour of walltime',
|
||||
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_seconds' : 60*60} for _game in _atari7]
|
||||
})
|
||||
|
||||
register_benchmark({
|
||||
'name' : 'AtariExploration40M',
|
||||
'description' :'7 Atari games emphasizing exploration, with pixel observations, 40M frames',
|
||||
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(40e6)} for _game in _atariexpl7]
|
||||
})
|
||||
|
||||
|
||||
_mujocosmall = [
|
||||
'InvertedDoublePendulum-v1', 'InvertedPendulum-v1',
|
||||
'HalfCheetah-v1', 'Hopper-v1', 'Walker2d-v1',
|
||||
'Reacher-v1', 'Swimmer-v1']
|
||||
|
||||
register_benchmark({
|
||||
'name' : 'Mujoco1M',
|
||||
'description' : 'Some small 2D MuJoCo tasks, run for 1M timesteps',
|
||||
'tasks' : [{'env_id' : _envid, 'trials' : 3, 'num_timesteps' : int(1e6)} for _envid in _mujocosmall]
|
||||
})
|
||||
|
||||
_roboschool_mujoco = [
|
||||
'RoboschoolInvertedDoublePendulum-v0', 'RoboschoolInvertedPendulum-v0', # cartpole
|
||||
'RoboschoolHalfCheetah-v0', 'RoboschoolHopper-v0', 'RoboschoolWalker2d-v0', # forward walkers
|
||||
'RoboschoolReacher-v0'
|
||||
]
|
||||
|
||||
register_benchmark({
|
||||
'name' : 'RoboschoolMujoco2M',
|
||||
'description' : 'Same small 2D tasks, still improving up to 2M',
|
||||
'tasks' : [{'env_id' : _envid, 'trials' : 3, 'num_timesteps' : int(2e6)} for _envid in _roboschool_mujoco]
|
||||
})
|
||||
|
||||
|
||||
_atari50 = [ # actually 49
|
||||
'Alien', 'Amidar', 'Assault', 'Asterix', 'Asteroids',
|
||||
'Atlantis', 'BankHeist', 'BattleZone', 'BeamRider', 'Bowling',
|
||||
'Boxing', 'Breakout', 'Centipede', 'ChopperCommand', 'CrazyClimber',
|
||||
'DemonAttack', 'DoubleDunk', 'Enduro', 'FishingDerby', 'Freeway',
|
||||
'Frostbite', 'Gopher', 'Gravitar', 'IceHockey', 'Jamesbond',
|
||||
'Kangaroo', 'Krull', 'KungFuMaster', 'MontezumaRevenge', 'MsPacman',
|
||||
'NameThisGame', 'Pitfall', 'Pong', 'PrivateEye', 'Qbert',
|
||||
'Riverraid', 'RoadRunner', 'Robotank', 'Seaquest', 'SpaceInvaders',
|
||||
'StarGunner', 'Tennis', 'TimePilot', 'Tutankham', 'UpNDown',
|
||||
'Venture', 'VideoPinball', 'WizardOfWor', 'Zaxxon',
|
||||
]
|
||||
|
||||
register_benchmark({
|
||||
'name' : 'Atari50_40M',
|
||||
'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 40M frames',
|
||||
'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 3, 'num_timesteps' : int(40e6)} for _game in _atari50]
|
||||
})
|
146
baselines/bench/monitor.py
Normal file
146
baselines/bench/monitor.py
Normal file
@@ -0,0 +1,146 @@
|
||||
__all__ = ['Monitor', 'get_monitor_files', 'load_results']
|
||||
|
||||
import gym
|
||||
from gym.core import Wrapper
|
||||
from os import path
|
||||
import time
|
||||
from glob import glob
|
||||
|
||||
try:
|
||||
import ujson as json # Not necessary for monitor writing, but very useful for monitor loading
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
class Monitor(Wrapper):
|
||||
EXT = "monitor.json"
|
||||
f = None
|
||||
|
||||
def __init__(self, env, filename, allow_early_resets=False):
|
||||
Wrapper.__init__(self, env=env)
|
||||
self.tstart = time.time()
|
||||
if filename is None:
|
||||
self.f = None
|
||||
self.logger = None
|
||||
else:
|
||||
if not filename.endswith(Monitor.EXT):
|
||||
filename = filename + "." + Monitor.EXT
|
||||
self.f = open(filename, "wt")
|
||||
self.logger = JSONLogger(self.f)
|
||||
self.logger.writekvs({"t_start": self.tstart, "gym_version": gym.__version__,
|
||||
"env_id": env.spec.id if env.spec else 'Unknown'})
|
||||
self.allow_early_resets = allow_early_resets
|
||||
self.rewards = None
|
||||
self.needs_reset = True
|
||||
self.episode_rewards = []
|
||||
self.episode_lengths = []
|
||||
self.total_steps = 0
|
||||
self.current_metadata = {} # extra info that gets injected into each log entry
|
||||
# Useful for metalearning where we're modifying the environment externally
|
||||
# But want our logs to know about these modifications
|
||||
|
||||
def __getstate__(self): # XXX
|
||||
d = self.__dict__.copy()
|
||||
if self.f:
|
||||
del d['f'], d['logger']
|
||||
d['_filename'] = self.f.name
|
||||
d['_num_episodes'] = len(self.episode_rewards)
|
||||
else:
|
||||
d['_filename'] = None
|
||||
return d
|
||||
def __setstate__(self, d):
|
||||
filename = d.pop('_filename')
|
||||
self.__dict__ = d
|
||||
if filename is not None:
|
||||
nlines = d.pop('_num_episodes') + 1
|
||||
self.f = open(filename, "r+t")
|
||||
for _ in range(nlines):
|
||||
self.f.readline()
|
||||
self.f.truncate()
|
||||
self.logger = JSONLogger(self.f)
|
||||
|
||||
|
||||
def reset(self):
|
||||
if not self.allow_early_resets and not self.needs_reset:
|
||||
raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)")
|
||||
self.rewards = []
|
||||
self.needs_reset = False
|
||||
return self.env.reset()
|
||||
|
||||
def step(self, action):
|
||||
if self.needs_reset:
|
||||
raise RuntimeError("Tried to step environment that needs reset")
|
||||
ob, rew, done, info = self.env.step(action)
|
||||
self.rewards.append(rew)
|
||||
if done:
|
||||
self.needs_reset = True
|
||||
eprew = sum(self.rewards)
|
||||
eplen = len(self.rewards)
|
||||
epinfo = {"r": eprew, "l": eplen, "t": round(time.time() - self.tstart, 6)}
|
||||
epinfo.update(self.current_metadata)
|
||||
if self.logger:
|
||||
self.logger.writekvs(epinfo)
|
||||
self.episode_rewards.append(eprew)
|
||||
self.episode_lengths.append(eplen)
|
||||
info['episode'] = epinfo
|
||||
self.total_steps += 1
|
||||
return (ob, rew, done, info)
|
||||
|
||||
def close(self):
|
||||
if self.f is not None:
|
||||
self.f.close()
|
||||
|
||||
def get_total_steps(self):
|
||||
return self.total_steps
|
||||
|
||||
def get_episode_rewards(self):
|
||||
return self.episode_rewards
|
||||
|
||||
def get_episode_lengths(self):
|
||||
return self.episode_lengths
|
||||
|
||||
class JSONLogger(object):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def writekvs(self, kvs):
|
||||
for k,v in kvs.items():
|
||||
if hasattr(v, 'dtype'):
|
||||
v = v.tolist()
|
||||
kvs[k] = float(v)
|
||||
self.file.write(json.dumps(kvs) + '\n')
|
||||
self.file.flush()
|
||||
|
||||
|
||||
class LoadMonitorResultsError(Exception):
|
||||
pass
|
||||
|
||||
def get_monitor_files(dir):
|
||||
return glob(path.join(dir, "*" + Monitor.EXT))
|
||||
|
||||
def load_results(dir):
|
||||
fnames = get_monitor_files(dir)
|
||||
if not fnames:
|
||||
raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir))
|
||||
episodes = []
|
||||
headers = []
|
||||
for fname in fnames:
|
||||
with open(fname, 'rt') as fh:
|
||||
lines = fh.readlines()
|
||||
header = json.loads(lines[0])
|
||||
headers.append(header)
|
||||
for line in lines[1:]:
|
||||
episode = json.loads(line)
|
||||
episode['abstime'] = header['t_start'] + episode['t']
|
||||
del episode['t']
|
||||
episodes.append(episode)
|
||||
header0 = headers[0]
|
||||
for header in headers[1:]:
|
||||
assert header['env_id'] == header0['env_id'], "mixing data from two envs"
|
||||
episodes = sorted(episodes, key=lambda e: e['abstime'])
|
||||
return {
|
||||
'env_info': {'env_id': header0['env_id'], 'gym_version': header0['gym_version']},
|
||||
'episode_end_times': [e['abstime'] for e in episodes],
|
||||
'episode_lengths': [e['l'] for e in episodes],
|
||||
'episode_rewards': [e['r'] for e in episodes],
|
||||
'initial_reset_time': min([min(header['t_start'] for header in headers)])
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
|
||||
|
||||
|
||||
from baselines.common.console_util import *
|
||||
from baselines.common.dataset import Dataset
|
||||
from baselines.common.math_util import *
|
||||
from baselines.common.misc_util import *
|
||||
|
172
baselines/common/atari_wrappers.py
Normal file
172
baselines/common/atari_wrappers.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
from PIL import Image
|
||||
import gym
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
def __init__(self, env, noop_max=30):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
No-op is assumed to be action 0.
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.noop_max = noop_max
|
||||
self.override_num_noops = None
|
||||
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
|
||||
|
||||
def _reset(self):
|
||||
""" Do no-op action for a number of steps in [1, noop_max]."""
|
||||
self.env.reset()
|
||||
if self.override_num_noops is not None:
|
||||
noops = self.override_num_noops
|
||||
else:
|
||||
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
|
||||
assert noops > 0
|
||||
obs = None
|
||||
for _ in range(noops):
|
||||
obs, _, done, _ = self.env.step(0)
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
return obs
|
||||
|
||||
class FireResetEnv(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
"""Take action on reset for environments that are fixed until firing."""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||
|
||||
def _reset(self):
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(1)
|
||||
if done:
|
||||
self.env.reset()
|
||||
obs, _, done, _ = self.env.step(2)
|
||||
if done:
|
||||
self.env.reset()
|
||||
return obs
|
||||
|
||||
class EpisodicLifeEnv(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
"""Make end-of-life == end-of-episode, but only reset on true game over.
|
||||
Done by DeepMind for the DQN and co. since it helps value estimation.
|
||||
"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.lives = 0
|
||||
self.was_real_done = True
|
||||
|
||||
def _step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.was_real_done = done
|
||||
# check current lives, make loss of life terminal,
|
||||
# then update lives to handle bonus lives
|
||||
lives = self.env.unwrapped.ale.lives()
|
||||
if lives < self.lives and lives > 0:
|
||||
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
|
||||
# so its important to keep lives > 0, so that we only reset once
|
||||
# the environment advertises done.
|
||||
done = True
|
||||
self.lives = lives
|
||||
return obs, reward, done, info
|
||||
|
||||
def _reset(self):
|
||||
"""Reset only when lives are exhausted.
|
||||
This way all states are still reachable even though lives are episodic,
|
||||
and the learner need not know about any of this behind-the-scenes.
|
||||
"""
|
||||
if self.was_real_done:
|
||||
obs = self.env.reset()
|
||||
else:
|
||||
# no-op step to advance from terminal/lost life state
|
||||
obs, _, _, _ = self.env.step(0)
|
||||
self.lives = self.env.unwrapped.ale.lives()
|
||||
return obs
|
||||
|
||||
class MaxAndSkipEnv(gym.Wrapper):
|
||||
def __init__(self, env, skip=4):
|
||||
"""Return only every `skip`-th frame"""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
# most recent raw observations (for max pooling across time steps)
|
||||
self._obs_buffer = deque(maxlen=2)
|
||||
self._skip = skip
|
||||
|
||||
def _step(self, action):
|
||||
"""Repeat action, sum reward, and max over last observations."""
|
||||
total_reward = 0.0
|
||||
done = None
|
||||
for _ in range(self._skip):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._obs_buffer.append(obs)
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
|
||||
|
||||
return max_frame, total_reward, done, info
|
||||
|
||||
def _reset(self):
|
||||
"""Clear past frame buffer and init. to first obs. from inner env."""
|
||||
self._obs_buffer.clear()
|
||||
obs = self.env.reset()
|
||||
self._obs_buffer.append(obs)
|
||||
return obs
|
||||
|
||||
class ClipRewardEnv(gym.RewardWrapper):
|
||||
def _reward(self, reward):
|
||||
"""Bin reward to {+1, 0, -1} by its sign."""
|
||||
return np.sign(reward)
|
||||
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
||||
gym.ObservationWrapper.__init__(self, env)
|
||||
self.res = 84
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(self.res, self.res, 1))
|
||||
|
||||
def _observation(self, obs):
|
||||
frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
|
||||
frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
|
||||
resample=Image.BILINEAR), dtype=np.uint8)
|
||||
return frame.reshape((self.res, self.res, 1))
|
||||
|
||||
class FrameStack(gym.Wrapper):
|
||||
def __init__(self, env, k):
|
||||
"""Buffer observations and stack across channels (last axis)."""
|
||||
gym.Wrapper.__init__(self, env)
|
||||
self.k = k
|
||||
self.frames = deque([], maxlen=k)
|
||||
shp = env.observation_space.shape
|
||||
assert shp[2] == 1 # can only stack 1-channel frames
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k))
|
||||
|
||||
def _reset(self):
|
||||
"""Clear buffer and re-fill by duplicating the first observation."""
|
||||
ob = self.env.reset()
|
||||
for _ in range(self.k): self.frames.append(ob)
|
||||
return self._observation()
|
||||
|
||||
def _step(self, action):
|
||||
ob, reward, done, info = self.env.step(action)
|
||||
self.frames.append(ob)
|
||||
return self._observation(), reward, done, info
|
||||
|
||||
def _observation(self):
|
||||
assert len(self.frames) == self.k
|
||||
return np.concatenate(self.frames, axis=2)
|
||||
|
||||
def wrap_deepmind(env, episode_life=True, clip_rewards=True):
|
||||
"""Configure environment for DeepMind-style Atari.
|
||||
|
||||
Note: this does not include frame stacking!"""
|
||||
assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
|
||||
if episode_life:
|
||||
env = EpisodicLifeEnv(env)
|
||||
# env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
if 'FIRE' in env.unwrapped.get_action_meanings():
|
||||
env = FireResetEnv(env)
|
||||
env = WarpFrame(env)
|
||||
if clip_rewards:
|
||||
env = ClipRewardEnv(env)
|
||||
return env
|
34
baselines/common/cg.py
Normal file
34
baselines/common/cg.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import numpy as np
|
||||
def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
|
||||
"""
|
||||
Demmel p 312
|
||||
"""
|
||||
p = b.copy()
|
||||
r = b.copy()
|
||||
x = np.zeros_like(b)
|
||||
rdotr = r.dot(r)
|
||||
|
||||
fmtstr = "%10i %10.3g %10.3g"
|
||||
titlestr = "%10s %10s %10s"
|
||||
if verbose: print(titlestr % ("iter", "residual norm", "soln norm"))
|
||||
|
||||
for i in range(cg_iters):
|
||||
if callback is not None:
|
||||
callback(x)
|
||||
if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x)))
|
||||
z = f_Ax(p)
|
||||
v = rdotr / p.dot(z)
|
||||
x += v*p
|
||||
r -= v*z
|
||||
newrdotr = r.dot(r)
|
||||
mu = newrdotr/rdotr
|
||||
p = r + mu*p
|
||||
|
||||
rdotr = newrdotr
|
||||
if rdotr < residual_tol:
|
||||
break
|
||||
|
||||
if callback is not None:
|
||||
callback(x)
|
||||
if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631
|
||||
return x
|
54
baselines/common/console_util.py
Normal file
54
baselines/common/console_util.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import print_function
|
||||
from contextlib import contextmanager
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
# ================================================================
|
||||
# Misc
|
||||
# ================================================================
|
||||
|
||||
def fmt_row(width, row, header=False):
|
||||
out = " | ".join(fmt_item(x, width) for x in row)
|
||||
if header: out = out + "\n" + "-"*len(out)
|
||||
return out
|
||||
|
||||
def fmt_item(x, l):
|
||||
if isinstance(x, np.ndarray):
|
||||
assert x.ndim==0
|
||||
x = x.item()
|
||||
if isinstance(x, float): rep = "%g"%x
|
||||
else: rep = str(x)
|
||||
return " "*(l - len(rep)) + rep
|
||||
|
||||
color2num = dict(
|
||||
gray=30,
|
||||
red=31,
|
||||
green=32,
|
||||
yellow=33,
|
||||
blue=34,
|
||||
magenta=35,
|
||||
cyan=36,
|
||||
white=37,
|
||||
crimson=38
|
||||
)
|
||||
|
||||
def colorize(string, color, bold=False, highlight=False):
|
||||
attr = []
|
||||
num = color2num[color]
|
||||
if highlight: num += 10
|
||||
attr.append(str(num))
|
||||
if bold: attr.append('1')
|
||||
return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
|
||||
|
||||
|
||||
MESSAGE_DEPTH = 0
|
||||
|
||||
@contextmanager
|
||||
def timed(msg):
|
||||
global MESSAGE_DEPTH #pylint: disable=W0603
|
||||
print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta'))
|
||||
tstart = time.time()
|
||||
MESSAGE_DEPTH += 1
|
||||
yield
|
||||
MESSAGE_DEPTH -= 1
|
||||
print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta'))
|
60
baselines/common/dataset.py
Normal file
60
baselines/common/dataset.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import numpy as np
|
||||
|
||||
class Dataset(object):
|
||||
def __init__(self, data_map, deterministic=False, shuffle=True):
|
||||
self.data_map = data_map
|
||||
self.deterministic = deterministic
|
||||
self.enable_shuffle = shuffle
|
||||
self.n = next(iter(data_map.values())).shape[0]
|
||||
self._next_id = 0
|
||||
self.shuffle()
|
||||
|
||||
def shuffle(self):
|
||||
if self.deterministic:
|
||||
return
|
||||
perm = np.arange(self.n)
|
||||
np.random.shuffle(perm)
|
||||
|
||||
for key in self.data_map:
|
||||
self.data_map[key] = self.data_map[key][perm]
|
||||
|
||||
self._next_id = 0
|
||||
|
||||
def next_batch(self, batch_size):
|
||||
if self._next_id >= self.n and self.enable_shuffle:
|
||||
self.shuffle()
|
||||
|
||||
cur_id = self._next_id
|
||||
cur_batch_size = min(batch_size, self.n - self._next_id)
|
||||
self._next_id += cur_batch_size
|
||||
|
||||
data_map = dict()
|
||||
for key in self.data_map:
|
||||
data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size]
|
||||
return data_map
|
||||
|
||||
def iterate_once(self, batch_size):
|
||||
if self.enable_shuffle: self.shuffle()
|
||||
|
||||
while self._next_id <= self.n - batch_size:
|
||||
yield self.next_batch(batch_size)
|
||||
self._next_id = 0
|
||||
|
||||
def subset(self, num_elements, deterministic=True):
|
||||
data_map = dict()
|
||||
for key in self.data_map:
|
||||
data_map[key] = self.data_map[key][:num_elements]
|
||||
return Dataset(data_map, deterministic)
|
||||
|
||||
|
||||
def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
|
||||
assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
|
||||
arrays = tuple(map(np.asarray, arrays))
|
||||
n = arrays[0].shape[0]
|
||||
assert all(a.shape[0] == n for a in arrays[1:])
|
||||
inds = np.arange(n)
|
||||
if shuffle: np.random.shuffle(inds)
|
||||
sections = np.arange(0, n, batch_size)[1:] if num_batches is None else num_batches
|
||||
for batch_inds in np.array_split(inds, sections):
|
||||
if include_final_partial_batch or len(batch_inds) == batch_size:
|
||||
yield tuple(a[batch_inds] for a in arrays)
|
289
baselines/common/distributions.py
Normal file
289
baselines/common/distributions.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import baselines.common.tf_util as U
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
|
||||
class Pd(object):
|
||||
"""
|
||||
A particular probability distribution
|
||||
"""
|
||||
def flatparam(self):
|
||||
raise NotImplementedError
|
||||
def mode(self):
|
||||
raise NotImplementedError
|
||||
def neglogp(self, x):
|
||||
# Usually it's easier to define the negative logprob
|
||||
raise NotImplementedError
|
||||
def kl(self, other):
|
||||
raise NotImplementedError
|
||||
def entropy(self):
|
||||
raise NotImplementedError
|
||||
def sample(self):
|
||||
raise NotImplementedError
|
||||
def logp(self, x):
|
||||
return - self.neglogp(x)
|
||||
|
||||
class PdType(object):
|
||||
"""
|
||||
Parametrized family of probability distributions
|
||||
"""
|
||||
def pdclass(self):
|
||||
raise NotImplementedError
|
||||
def pdfromflat(self, flat):
|
||||
return self.pdclass()(flat)
|
||||
def param_shape(self):
|
||||
raise NotImplementedError
|
||||
def sample_shape(self):
|
||||
raise NotImplementedError
|
||||
def sample_dtype(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def param_placeholder(self, prepend_shape, name=None):
|
||||
return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name)
|
||||
def sample_placeholder(self, prepend_shape, name=None):
|
||||
return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name)
|
||||
|
||||
class CategoricalPdType(PdType):
|
||||
def __init__(self, ncat):
|
||||
self.ncat = ncat
|
||||
def pdclass(self):
|
||||
return CategoricalPd
|
||||
def param_shape(self):
|
||||
return [self.ncat]
|
||||
def sample_shape(self):
|
||||
return []
|
||||
def sample_dtype(self):
|
||||
return tf.int32
|
||||
|
||||
|
||||
class MultiCategoricalPdType(PdType):
|
||||
def __init__(self, low, high):
|
||||
self.low = low
|
||||
self.high = high
|
||||
self.ncats = high - low + 1
|
||||
def pdclass(self):
|
||||
return MultiCategoricalPd
|
||||
def pdfromflat(self, flat):
|
||||
return MultiCategoricalPd(self.low, self.high, flat)
|
||||
def param_shape(self):
|
||||
return [sum(self.ncats)]
|
||||
def sample_shape(self):
|
||||
return [len(self.ncats)]
|
||||
def sample_dtype(self):
|
||||
return tf.int32
|
||||
|
||||
class DiagGaussianPdType(PdType):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
def pdclass(self):
|
||||
return DiagGaussianPd
|
||||
def param_shape(self):
|
||||
return [2*self.size]
|
||||
def sample_shape(self):
|
||||
return [self.size]
|
||||
def sample_dtype(self):
|
||||
return tf.float32
|
||||
|
||||
class BernoulliPdType(PdType):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
def pdclass(self):
|
||||
return BernoulliPd
|
||||
def param_shape(self):
|
||||
return [self.size]
|
||||
def sample_shape(self):
|
||||
return [self.size]
|
||||
def sample_dtype(self):
|
||||
return tf.int32
|
||||
|
||||
# WRONG SECOND DERIVATIVES
|
||||
# class CategoricalPd(Pd):
|
||||
# def __init__(self, logits):
|
||||
# self.logits = logits
|
||||
# self.ps = tf.nn.softmax(logits)
|
||||
# @classmethod
|
||||
# def fromflat(cls, flat):
|
||||
# return cls(flat)
|
||||
# def flatparam(self):
|
||||
# return self.logits
|
||||
# def mode(self):
|
||||
# return U.argmax(self.logits, axis=1)
|
||||
# def logp(self, x):
|
||||
# return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
|
||||
# def kl(self, other):
|
||||
# return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
|
||||
# - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
|
||||
# def entropy(self):
|
||||
# return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
|
||||
# def sample(self):
|
||||
# u = tf.random_uniform(tf.shape(self.logits))
|
||||
# return U.argmax(self.logits - tf.log(-tf.log(u)), axis=1)
|
||||
|
||||
class CategoricalPd(Pd):
|
||||
def __init__(self, logits):
|
||||
self.logits = logits
|
||||
def flatparam(self):
|
||||
return self.logits
|
||||
def mode(self):
|
||||
return U.argmax(self.logits, axis=1)
|
||||
def neglogp(self, x):
|
||||
return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
|
||||
def kl(self, other):
|
||||
a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
|
||||
a1 = other.logits - U.max(other.logits, axis=1, keepdims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
ea1 = tf.exp(a1)
|
||||
z0 = U.sum(ea0, axis=1, keepdims=True)
|
||||
z1 = U.sum(ea1, axis=1, keepdims=True)
|
||||
p0 = ea0 / z0
|
||||
return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=1)
|
||||
def entropy(self):
|
||||
a0 = self.logits - U.max(self.logits, axis=1, keepdims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
z0 = U.sum(ea0, axis=1, keepdims=True)
|
||||
p0 = ea0 / z0
|
||||
return U.sum(p0 * (tf.log(z0) - a0), axis=1)
|
||||
def sample(self):
|
||||
u = tf.random_uniform(tf.shape(self.logits))
|
||||
return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=1)
|
||||
@classmethod
|
||||
def fromflat(cls, flat):
|
||||
return cls(flat)
|
||||
|
||||
class MultiCategoricalPd(Pd):
|
||||
def __init__(self, low, high, flat):
|
||||
self.flat = flat
|
||||
self.low = tf.constant(low, dtype=tf.int32)
|
||||
self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1)))
|
||||
def flatparam(self):
|
||||
return self.flat
|
||||
def mode(self):
|
||||
return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
|
||||
def neglogp(self, x):
|
||||
return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))])
|
||||
def kl(self, other):
|
||||
return tf.add_n([
|
||||
p.kl(q) for p, q in zip(self.categoricals, other.categoricals)
|
||||
])
|
||||
def entropy(self):
|
||||
return tf.add_n([p.entropy() for p in self.categoricals])
|
||||
def sample(self):
|
||||
return self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
|
||||
@classmethod
|
||||
def fromflat(cls, flat):
|
||||
raise NotImplementedError
|
||||
|
||||
class DiagGaussianPd(Pd):
|
||||
def __init__(self, flat):
|
||||
self.flat = flat
|
||||
mean, logstd = tf.split(axis=len(flat.get_shape()) - 1, num_or_size_splits=2, value=flat)
|
||||
self.mean = mean
|
||||
self.logstd = logstd
|
||||
self.std = tf.exp(logstd)
|
||||
def flatparam(self):
|
||||
return self.flat
|
||||
def mode(self):
|
||||
return self.mean
|
||||
def neglogp(self, x):
|
||||
return 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=len(x.get_shape()) - 1) \
|
||||
+ 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
|
||||
+ U.sum(self.logstd, axis=len(x.get_shape()) - 1)
|
||||
def kl(self, other):
|
||||
assert isinstance(other, DiagGaussianPd)
|
||||
return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
|
||||
def entropy(self):
|
||||
return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), -1)
|
||||
def sample(self):
|
||||
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
@classmethod
|
||||
def fromflat(cls, flat):
|
||||
return cls(flat)
|
||||
|
||||
class BernoulliPd(Pd):
|
||||
def __init__(self, logits):
|
||||
self.logits = logits
|
||||
self.ps = tf.sigmoid(logits)
|
||||
def flatparam(self):
|
||||
return self.logits
|
||||
def mode(self):
|
||||
return tf.round(self.ps)
|
||||
def neglogp(self, x):
|
||||
return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=1)
|
||||
def kl(self, other):
|
||||
return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)
|
||||
def entropy(self):
|
||||
return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=1)
|
||||
def sample(self):
|
||||
u = tf.random_uniform(tf.shape(self.ps))
|
||||
return tf.to_float(math_ops.less(u, self.ps))
|
||||
@classmethod
|
||||
def fromflat(cls, flat):
|
||||
return cls(flat)
|
||||
|
||||
def make_pdtype(ac_space):
|
||||
from gym import spaces
|
||||
if isinstance(ac_space, spaces.Box):
|
||||
assert len(ac_space.shape) == 1
|
||||
return DiagGaussianPdType(ac_space.shape[0])
|
||||
elif isinstance(ac_space, spaces.Discrete):
|
||||
return CategoricalPdType(ac_space.n)
|
||||
elif isinstance(ac_space, spaces.MultiDiscrete):
|
||||
return MultiCategoricalPdType(ac_space.low, ac_space.high)
|
||||
elif isinstance(ac_space, spaces.MultiBinary):
|
||||
return BernoulliPdType(ac_space.n)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def shape_el(v, i):
|
||||
maybe = v.get_shape()[i]
|
||||
if maybe is not None:
|
||||
return maybe
|
||||
else:
|
||||
return tf.shape(v)[i]
|
||||
|
||||
@U.in_session
|
||||
def test_probtypes():
|
||||
np.random.seed(0)
|
||||
|
||||
pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
|
||||
diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
|
||||
validate_probtype(diag_gauss, pdparam_diag_gauss)
|
||||
|
||||
pdparam_categorical = np.array([-.2, .3, .5])
|
||||
categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
|
||||
validate_probtype(categorical, pdparam_categorical)
|
||||
|
||||
pdparam_bernoulli = np.array([-.2, .3, .5])
|
||||
bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
|
||||
validate_probtype(bernoulli, pdparam_bernoulli)
|
||||
|
||||
|
||||
def validate_probtype(probtype, pdparam):
|
||||
N = 100000
|
||||
# Check to see if mean negative log likelihood == differential entropy
|
||||
Mval = np.repeat(pdparam[None, :], N, axis=0)
|
||||
M = probtype.param_placeholder([N])
|
||||
X = probtype.sample_placeholder([N])
|
||||
pd = probtype.pdclass()(M)
|
||||
calcloglik = U.function([X, M], pd.logp(X))
|
||||
calcent = U.function([M], pd.entropy())
|
||||
Xval = U.eval(pd.sample(), feed_dict={M:Mval})
|
||||
logliks = calcloglik(Xval, Mval)
|
||||
entval_ll = - logliks.mean() #pylint: disable=E1101
|
||||
entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
|
||||
entval = calcent(Mval).mean() #pylint: disable=E1101
|
||||
assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas
|
||||
|
||||
# Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
|
||||
M2 = probtype.param_placeholder([N])
|
||||
pd2 = probtype.pdclass()(M2)
|
||||
q = pdparam + np.random.randn(pdparam.size) * 0.1
|
||||
Mval2 = np.repeat(q[None, :], N, axis=0)
|
||||
calckl = U.function([M, M2], pd.kl(pd2))
|
||||
klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
|
||||
logliks = calcloglik(Xval, Mval2)
|
||||
klval_ll = - entval - logliks.mean() #pylint: disable=E1101
|
||||
klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
|
||||
assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
|
||||
|
85
baselines/common/math_util.py
Normal file
85
baselines/common/math_util.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
|
||||
|
||||
def discount(x, gamma):
|
||||
"""
|
||||
computes discounted sums along 0th dimension of x.
|
||||
|
||||
inputs
|
||||
------
|
||||
x: ndarray
|
||||
gamma: float
|
||||
|
||||
outputs
|
||||
-------
|
||||
y: ndarray with same shape as x, satisfying
|
||||
|
||||
y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k],
|
||||
where k = len(x) - t - 1
|
||||
|
||||
"""
|
||||
assert x.ndim >= 1
|
||||
return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1]
|
||||
|
||||
def explained_variance(ypred,y):
|
||||
"""
|
||||
Computes fraction of variance that ypred explains about y.
|
||||
Returns 1 - Var[y-ypred] / Var[y]
|
||||
|
||||
interpretation:
|
||||
ev=0 => might as well have predicted zero
|
||||
ev=1 => perfect prediction
|
||||
ev<0 => worse than just predicting zero
|
||||
|
||||
"""
|
||||
assert y.ndim == 1 and ypred.ndim == 1
|
||||
vary = np.var(y)
|
||||
return np.nan if vary==0 else 1 - np.var(y-ypred)/vary
|
||||
|
||||
def explained_variance_2d(ypred, y):
|
||||
assert y.ndim == 2 and ypred.ndim == 2
|
||||
vary = np.var(y, axis=0)
|
||||
out = 1 - np.var(y-ypred)/vary
|
||||
out[vary < 1e-10] = 0
|
||||
return out
|
||||
|
||||
def ncc(ypred, y):
|
||||
return np.corrcoef(ypred, y)[1,0]
|
||||
|
||||
def flatten_arrays(arrs):
|
||||
return np.concatenate([arr.flat for arr in arrs])
|
||||
|
||||
def unflatten_vector(vec, shapes):
|
||||
i=0
|
||||
arrs = []
|
||||
for shape in shapes:
|
||||
size = np.prod(shape)
|
||||
arr = vec[i:i+size].reshape(shape)
|
||||
arrs.append(arr)
|
||||
i += size
|
||||
return arrs
|
||||
|
||||
def discount_with_boundaries(X, New, gamma):
|
||||
"""
|
||||
X: 2d array of floats, time x features
|
||||
New: 2d array of bools, indicating when a new episode has started
|
||||
"""
|
||||
Y = np.zeros_like(X)
|
||||
T = X.shape[0]
|
||||
Y[T-1] = X[T-1]
|
||||
for t in range(T-2, -1, -1):
|
||||
Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1])
|
||||
return Y
|
||||
|
||||
def test_discount_with_boundaries():
|
||||
gamma=0.9
|
||||
x = np.array([1.0, 2.0, 3.0, 4.0], 'float32')
|
||||
starts = [1.0, 0.0, 0.0, 1.0]
|
||||
y = discount_with_boundaries(x, starts, gamma)
|
||||
assert np.allclose(y, [
|
||||
1 + gamma * 2 + gamma**2 * 3,
|
||||
2 + gamma * 3,
|
||||
3,
|
||||
4
|
||||
])
|
@@ -155,7 +155,7 @@ class RunningAvg(object):
|
||||
|
||||
|
||||
class SimpleMonitor(gym.Wrapper):
|
||||
def __init__(self, env=None):
|
||||
def __init__(self, env):
|
||||
"""Adds two qunatities to info returned by every step:
|
||||
|
||||
num_steps: int
|
||||
|
79
baselines/common/mpi_adam.py
Normal file
79
baselines/common/mpi_adam.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from mpi4py import MPI
|
||||
import baselines.common.tf_util as U
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
class MpiAdam(object):
|
||||
def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
|
||||
self.var_list = var_list
|
||||
self.beta1 = beta1
|
||||
self.beta2 = beta2
|
||||
self.epsilon = epsilon
|
||||
self.scale_grad_by_procs = scale_grad_by_procs
|
||||
size = sum(U.numel(v) for v in var_list)
|
||||
self.m = np.zeros(size, 'float32')
|
||||
self.v = np.zeros(size, 'float32')
|
||||
self.t = 0
|
||||
self.setfromflat = U.SetFromFlat(var_list)
|
||||
self.getflat = U.GetFlat(var_list)
|
||||
self.comm = MPI.COMM_WORLD if comm is None else comm
|
||||
|
||||
def update(self, localg, stepsize):
|
||||
if self.t % 100 == 0:
|
||||
self.check_synced()
|
||||
localg = localg.astype('float32')
|
||||
globalg = np.zeros_like(localg)
|
||||
self.comm.Allreduce(localg, globalg, op=MPI.SUM)
|
||||
if self.scale_grad_by_procs:
|
||||
globalg /= self.comm.Get_size()
|
||||
|
||||
self.t += 1
|
||||
a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
|
||||
self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
|
||||
self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
|
||||
step = (- a) * self.m / (np.sqrt(self.v) + self.epsilon)
|
||||
self.setfromflat(self.getflat() + step)
|
||||
|
||||
def sync(self):
|
||||
theta = self.getflat()
|
||||
self.comm.Bcast(theta, root=0)
|
||||
self.setfromflat(theta)
|
||||
|
||||
def check_synced(self):
|
||||
if self.comm.Get_rank() == 0: # this is root
|
||||
theta = self.getflat()
|
||||
self.comm.Bcast(theta, root=0)
|
||||
else:
|
||||
thetalocal = self.getflat()
|
||||
thetaroot = np.empty_like(thetalocal)
|
||||
self.comm.Bcast(thetaroot, root=0)
|
||||
assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
|
||||
|
||||
@U.in_session
|
||||
def test_MpiAdam():
|
||||
np.random.seed(0)
|
||||
tf.set_random_seed(0)
|
||||
|
||||
a = tf.Variable(np.random.randn(3).astype('float32'))
|
||||
b = tf.Variable(np.random.randn(2,5).astype('float32'))
|
||||
loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
|
||||
|
||||
stepsize = 1e-2
|
||||
update_op = tf.train.AdamOptimizer(stepsize).minimize(loss)
|
||||
do_update = U.function([], loss, updates=[update_op])
|
||||
|
||||
tf.get_default_session().run(tf.global_variables_initializer())
|
||||
for i in range(10):
|
||||
print(i,do_update())
|
||||
|
||||
tf.set_random_seed(0)
|
||||
tf.get_default_session().run(tf.global_variables_initializer())
|
||||
|
||||
var_list = [a,b]
|
||||
lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op])
|
||||
adam = MpiAdam(var_list)
|
||||
|
||||
for i in range(10):
|
||||
l,g = lossandgrad()
|
||||
adam.update(g, stepsize)
|
||||
print(i,l)
|
19
baselines/common/mpi_fork.py
Normal file
19
baselines/common/mpi_fork.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import os, subprocess, sys
|
||||
|
||||
def mpi_fork(n):
|
||||
"""Re-launches the current script with workers
|
||||
Returns "parent" for original parent, "child" for MPI children
|
||||
"""
|
||||
if n<=1:
|
||||
return "child"
|
||||
if os.getenv("IN_MPI") is None:
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
MKL_NUM_THREADS="1",
|
||||
OMP_NUM_THREADS="1",
|
||||
IN_MPI="1"
|
||||
)
|
||||
subprocess.check_call(["mpirun", "-np", str(n), sys.executable] + sys.argv, env=env)
|
||||
return "parent"
|
||||
else:
|
||||
return "child"
|
50
baselines/common/mpi_moments.py
Normal file
50
baselines/common/mpi_moments.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from mpi4py import MPI
|
||||
import numpy as np
|
||||
from baselines.common import zipsame
|
||||
|
||||
def mpi_moments(x, axis=0):
|
||||
x = np.asarray(x, dtype='float64')
|
||||
newshape = list(x.shape)
|
||||
newshape.pop(axis)
|
||||
n = np.prod(newshape,dtype=int)
|
||||
totalvec = np.zeros(n*2+1, 'float64')
|
||||
addvec = np.concatenate([x.sum(axis=axis).ravel(),
|
||||
np.square(x).sum(axis=axis).ravel(),
|
||||
np.array([x.shape[axis]],dtype='float64')])
|
||||
MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
|
||||
sum = totalvec[:n]
|
||||
sumsq = totalvec[n:2*n]
|
||||
count = totalvec[2*n]
|
||||
if count == 0:
|
||||
mean = np.empty(newshape); mean[:] = np.nan
|
||||
std = np.empty(newshape); std[:] = np.nan
|
||||
else:
|
||||
mean = sum/count
|
||||
std = np.sqrt(np.maximum(sumsq/count - np.square(mean),0))
|
||||
return mean, std, count
|
||||
|
||||
|
||||
def test_runningmeanstd():
|
||||
comm = MPI.COMM_WORLD
|
||||
np.random.seed(0)
|
||||
for (triple,axis) in [
|
||||
((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0),
|
||||
((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0),
|
||||
((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1),
|
||||
]:
|
||||
|
||||
|
||||
x = np.concatenate(triple, axis=axis)
|
||||
ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]]
|
||||
|
||||
|
||||
ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis)
|
||||
|
||||
for (a1,a2) in zipsame(ms1, ms2):
|
||||
print(a1, a2)
|
||||
assert np.allclose(a1, a2)
|
||||
print("ok!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
#mpirun -np 3 python <script>
|
||||
test_runningmeanstd()
|
107
baselines/common/mpi_running_mean_std.py
Normal file
107
baselines/common/mpi_running_mean_std.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from mpi4py import MPI
|
||||
import tensorflow as tf, baselines.common.tf_util as U, numpy as np
|
||||
|
||||
class RunningMeanStd(object):
|
||||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
def __init__(self, epsilon=1e-2, shape=()):
|
||||
|
||||
self._sum = tf.get_variable(
|
||||
dtype=tf.float64,
|
||||
shape=shape,
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
name="runningsum", trainable=False)
|
||||
self._sumsq = tf.get_variable(
|
||||
dtype=tf.float64,
|
||||
shape=shape,
|
||||
initializer=tf.constant_initializer(epsilon),
|
||||
name="runningsumsq", trainable=False)
|
||||
self._count = tf.get_variable(
|
||||
dtype=tf.float64,
|
||||
shape=(),
|
||||
initializer=tf.constant_initializer(epsilon),
|
||||
name="count", trainable=False)
|
||||
self.shape = shape
|
||||
|
||||
self.mean = tf.to_float(self._sum / self._count)
|
||||
self.std = tf.sqrt( tf.maximum( tf.to_float(self._sumsq / self._count) - tf.square(self.mean) , 1e-2 ))
|
||||
|
||||
newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
|
||||
newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
|
||||
newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
|
||||
self.incfiltparams = U.function([newsum, newsumsq, newcount], [],
|
||||
updates=[tf.assign_add(self._sum, newsum),
|
||||
tf.assign_add(self._sumsq, newsumsq),
|
||||
tf.assign_add(self._count, newcount)])
|
||||
|
||||
|
||||
def update(self, x):
|
||||
x = x.astype('float64')
|
||||
n = int(np.prod(self.shape))
|
||||
totalvec = np.zeros(n*2+1, 'float64')
|
||||
addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
|
||||
MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
|
||||
self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
|
||||
|
||||
@U.in_session
|
||||
def test_runningmeanstd():
|
||||
for (x1, x2, x3) in [
|
||||
(np.random.randn(3), np.random.randn(4), np.random.randn(5)),
|
||||
(np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
|
||||
]:
|
||||
|
||||
rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
|
||||
U.initialize()
|
||||
|
||||
x = np.concatenate([x1, x2, x3], axis=0)
|
||||
ms1 = [x.mean(axis=0), x.std(axis=0)]
|
||||
rms.update(x1)
|
||||
rms.update(x2)
|
||||
rms.update(x3)
|
||||
ms2 = U.eval([rms.mean, rms.std])
|
||||
|
||||
assert np.allclose(ms1, ms2)
|
||||
|
||||
@U.in_session
|
||||
def test_dist():
|
||||
np.random.seed(0)
|
||||
p1,p2,p3=(np.random.randn(3,1), np.random.randn(4,1), np.random.randn(5,1))
|
||||
q1,q2,q3=(np.random.randn(6,1), np.random.randn(7,1), np.random.randn(8,1))
|
||||
|
||||
# p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5))
|
||||
# q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8))
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
assert comm.Get_size()==2
|
||||
if comm.Get_rank()==0:
|
||||
x1,x2,x3 = p1,p2,p3
|
||||
elif comm.Get_rank()==1:
|
||||
x1,x2,x3 = q1,q2,q3
|
||||
else:
|
||||
assert False
|
||||
|
||||
rms = RunningMeanStd(epsilon=0.0, shape=(1,))
|
||||
U.initialize()
|
||||
|
||||
rms.update(x1)
|
||||
rms.update(x2)
|
||||
rms.update(x3)
|
||||
|
||||
bigvec = np.concatenate([p1,p2,p3,q1,q2,q3])
|
||||
|
||||
def checkallclose(x,y):
|
||||
print(x,y)
|
||||
return np.allclose(x,y)
|
||||
|
||||
assert checkallclose(
|
||||
bigvec.mean(axis=0),
|
||||
U.eval(rms.mean)
|
||||
)
|
||||
assert checkallclose(
|
||||
bigvec.std(axis=0),
|
||||
U.eval(rms.std)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run with mpirun -np 2 python <filename>
|
||||
test_dist()
|
@@ -606,8 +606,10 @@ def intprod(x):
|
||||
return int(np.prod(x))
|
||||
|
||||
|
||||
def flatgrad(loss, var_list):
|
||||
def flatgrad(loss, var_list, clip_norm=None):
|
||||
grads = tf.gradients(loss, var_list)
|
||||
if clip_norm is not None:
|
||||
grads = [tf.clip_by_norm(grad, clip_norm=clip_norm) for grad in grads]
|
||||
return tf.concat(axis=0, values=[
|
||||
tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)])
|
||||
for (v, grad) in zip(var_list, grads)
|
||||
|
52
baselines/deepq/README.md
Normal file
52
baselines/deepq/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
## If you are curious.
|
||||
|
||||
##### Train a Cartpole agent and watch it play once it converges!
|
||||
|
||||
Here's a list of commands to run to quickly get a working example:
|
||||
|
||||
<img src="../../data/cartpole.gif" width="25%" />
|
||||
|
||||
|
||||
```bash
|
||||
# Train model and save the results to cartpole_model.pkl
|
||||
python -m baselines.deepq.experiments.train_cartpole
|
||||
# Load the model saved in cartpole_model.pkl and visualize the learned policy
|
||||
python -m baselines.deepq.experiments.enjoy_cartpole
|
||||
```
|
||||
|
||||
|
||||
Be sure to check out the source code of [both](baselines/deepq/experiments/train_cartpole.py) [files](baselines/deepq/experiments/enjoy_cartpole.py)!
|
||||
|
||||
## If you wish to apply DQN to solve a problem.
|
||||
|
||||
Check out our simple agent trained with one stop shop `deepq.learn` function.
|
||||
|
||||
- `baselines/deepq/experiments/train_cartpole.py` - train a Cartpole agent.
|
||||
- `baselines/deepq/experiments/train_pong.py` - train a Pong agent using convolutional neural networks.
|
||||
|
||||
In particular notice that once `deepq.learn` finishes training it returns `act` function which can be used to select actions in the environment. Once trained you can easily save it and load at later time. For both of the files listed above there are complimentary files `enjoy_cartpole.py` and `enjoy_pong.py` respectively, that load and visualize the learned policy.
|
||||
|
||||
## If you wish to experiment with the algorithm
|
||||
|
||||
##### Check out the examples
|
||||
|
||||
|
||||
- `baselines/deepq/experiments/custom_cartpole.py` - Cartpole training with more fine grained control over the internals of DQN algorithm.
|
||||
- `baselines/deepq/experiments/atari/train.py` - more robust setup for training at scale.
|
||||
|
||||
|
||||
##### Download a pretrained Atari agent
|
||||
|
||||
For some research projects it is sometimes useful to have an already trained agent handy. There's a variety of models to choose from. You can list them all by running:
|
||||
|
||||
```bash
|
||||
python -m baselines.deepq.experiments.atari.download_model
|
||||
```
|
||||
|
||||
Once you pick a model, you can download it and visualize the learned policy. Be sure to pass `--dueling` flag to visualization script when using dueling models.
|
||||
|
||||
```bash
|
||||
python -m baselines.deepq.experiments.atari.download_model --blob model-atari-duel-pong-1 --model-dir /tmp/models
|
||||
python -m baselines.deepq.experiments.atari.enjoy --model-dir /tmp/models/model-atari-duel-pong-1 --env Pong --dueling
|
||||
|
||||
```
|
@@ -99,7 +99,7 @@ def learn(env,
|
||||
|
||||
Parameters
|
||||
-------
|
||||
env : gym.Env
|
||||
env: gym.Env
|
||||
environment to train on
|
||||
q_func: (tf.Variable, int, str, bool) -> tf.Variable
|
||||
the model that takes the following inputs:
|
||||
@@ -123,6 +123,7 @@ def learn(env,
|
||||
final value of random action probability
|
||||
train_freq: int
|
||||
update the model every `train_freq` steps.
|
||||
set to None to disable printing
|
||||
batch_size: int
|
||||
size of a batched sampled from replay buffer for training
|
||||
print_freq: int
|
||||
|
@@ -13,6 +13,9 @@ import sys
|
||||
import shutil
|
||||
import os.path as osp
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import tempfile
|
||||
|
||||
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json']
|
||||
|
||||
@@ -23,7 +26,6 @@ ERROR = 40
|
||||
|
||||
DISABLED = 50
|
||||
|
||||
|
||||
class OutputFormat(object):
|
||||
def writekvs(self, kvs):
|
||||
"""
|
||||
@@ -81,7 +83,6 @@ class HumanOutputFormat(OutputFormat):
|
||||
self.file.write('\n')
|
||||
self.file.flush()
|
||||
|
||||
|
||||
class JSONOutputFormat(OutputFormat):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
@@ -94,6 +95,41 @@ class JSONOutputFormat(OutputFormat):
|
||||
self.file.write(json.dumps(kvs) + '\n')
|
||||
self.file.flush()
|
||||
|
||||
class TensorBoardOutputFormat(OutputFormat):
|
||||
"""
|
||||
Dumps key/value pairs into TensorBoard's numeric format.
|
||||
"""
|
||||
def __init__(self, dir):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
self.dir = dir
|
||||
self.step = 1
|
||||
prefix = 'events'
|
||||
path = osp.join(osp.abspath(dir), prefix)
|
||||
import tensorflow as tf
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.util import compat
|
||||
self.tf = tf
|
||||
self.event_pb2 = event_pb2
|
||||
self.pywrap_tensorflow = pywrap_tensorflow
|
||||
self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
|
||||
|
||||
def writekvs(self, kvs):
|
||||
def summary_val(k, v):
|
||||
kwargs = {'tag': k, 'simple_value': float(v)}
|
||||
return self.tf.Summary.Value(**kwargs)
|
||||
summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
|
||||
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
|
||||
event.step = self.step # is there any reason why you'd want to specify the step?
|
||||
self.writer.WriteEvent(event)
|
||||
self.writer.Flush()
|
||||
self.step += 1
|
||||
|
||||
def close(self):
|
||||
if self.writer:
|
||||
self.writer.Close()
|
||||
self.writer = None
|
||||
|
||||
|
||||
def make_output_format(format, ev_dir):
|
||||
os.makedirs(ev_dir, exist_ok=True)
|
||||
@@ -105,6 +141,8 @@ def make_output_format(format, ev_dir):
|
||||
elif format == 'json':
|
||||
json_file = open(osp.join(ev_dir, 'progress.json'), 'wt')
|
||||
return JSONOutputFormat(json_file)
|
||||
elif format == 'tensorboard':
|
||||
return TensorBoardOutputFormat(osp.join(ev_dir, 'tb'))
|
||||
else:
|
||||
raise ValueError('Unknown format specified: %s' % (format,))
|
||||
|
||||
@@ -173,12 +211,6 @@ def get_dir():
|
||||
"""
|
||||
return Logger.CURRENT.get_dir()
|
||||
|
||||
|
||||
def get_expt_dir():
|
||||
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),))
|
||||
return get_dir()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Backend
|
||||
# ================================================================
|
||||
@@ -241,17 +273,25 @@ class session(object):
|
||||
|
||||
CURRENT = None # Set to a LoggerContext object using enter/exit or context manager
|
||||
|
||||
def __init__(self, dir, format_strs=None):
|
||||
def __init__(self, dir=None, format_strs=None):
|
||||
if dir is None:
|
||||
dir = os.getenv('OPENAI_LOGDIR')
|
||||
if dir is None:
|
||||
dir = osp.join(tempfile.gettempdir(),
|
||||
datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
|
||||
self.dir = dir
|
||||
if format_strs is None:
|
||||
format_strs = LOG_OUTPUT_FORMATS
|
||||
output_formats = [make_output_format(f, dir) for f in format_strs]
|
||||
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
|
||||
print('Logging to', dir)
|
||||
|
||||
def __enter__(self):
|
||||
os.makedirs(self.evaluation_dir(), exist_ok=True)
|
||||
output_formats = [make_output_format(f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS]
|
||||
output_formats = [make_output_format(f, self.evaluation_dir())
|
||||
for f in LOG_OUTPUT_FORMATS]
|
||||
Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats)
|
||||
os.environ['OPENAI_LOGDIR'] = self.evaluation_dir()
|
||||
|
||||
def __exit__(self, *args):
|
||||
Logger.CURRENT.close()
|
||||
@@ -260,6 +300,12 @@ class session(object):
|
||||
def evaluation_dir(self):
|
||||
return self.dir
|
||||
|
||||
def _setup():
|
||||
logdir = os.getenv('OPENAI_LOGDIR')
|
||||
if logdir:
|
||||
session(logdir).__enter__()
|
||||
|
||||
_setup()
|
||||
|
||||
# ================================================================
|
||||
|
||||
|
57
baselines/pposgd/cnn_policy.py
Normal file
57
baselines/pposgd/cnn_policy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
||||
import baselines.common.tf_util as U
|
||||
import tensorflow as tf
|
||||
import gym
|
||||
from baselines.common.distributions import make_pdtype
|
||||
|
||||
class CnnPolicy(object):
|
||||
recurrent = False
|
||||
def __init__(self, name, ob_space, ac_space, kind='large'):
|
||||
with tf.variable_scope(name):
|
||||
self._init(ob_space, ac_space, kind)
|
||||
self.scope = tf.get_variable_scope().name
|
||||
|
||||
def _init(self, ob_space, ac_space, kind):
|
||||
assert isinstance(ob_space, gym.spaces.Box)
|
||||
|
||||
self.pdtype = pdtype = make_pdtype(ac_space)
|
||||
sequence_length = None
|
||||
|
||||
ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[sequence_length] + list(ob_space.shape))
|
||||
|
||||
x = ob / 255.0
|
||||
if kind == 'small': # from A3C paper
|
||||
x = tf.nn.relu(U.conv2d(x, 16, "l1", [8, 8], [4, 4], pad="VALID"))
|
||||
x = tf.nn.relu(U.conv2d(x, 32, "l2", [4, 4], [2, 2], pad="VALID"))
|
||||
x = U.flattenallbut0(x)
|
||||
x = tf.nn.relu(U.dense(x, 256, 'lin', U.normc_initializer(1.0)))
|
||||
elif kind == 'large': # Nature DQN
|
||||
x = tf.nn.relu(U.conv2d(x, 32, "l1", [8, 8], [4, 4], pad="VALID"))
|
||||
x = tf.nn.relu(U.conv2d(x, 64, "l2", [4, 4], [2, 2], pad="VALID"))
|
||||
x = tf.nn.relu(U.conv2d(x, 64, "l3", [3, 3], [1, 1], pad="VALID"))
|
||||
x = U.flattenallbut0(x)
|
||||
x = tf.nn.relu(U.dense(x, 512, 'lin', U.normc_initializer(1.0)))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
logits = U.dense(x, pdtype.param_shape()[0], "logits", U.normc_initializer(0.01))
|
||||
self.pd = pdtype.pdfromflat(logits)
|
||||
self.vpred = U.dense(x, 1, "value", U.normc_initializer(1.0))[:,0]
|
||||
|
||||
self.state_in = []
|
||||
self.state_out = []
|
||||
|
||||
stochastic = tf.placeholder(dtype=tf.bool, shape=())
|
||||
ac = self.pd.sample() # XXX
|
||||
self._act = U.function([stochastic, ob], [ac, self.vpred])
|
||||
|
||||
def act(self, stochastic, ob):
|
||||
ac1, vpred1 = self._act(stochastic, ob[None])
|
||||
return ac1[0], vpred1[0]
|
||||
def get_variables(self):
|
||||
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
|
||||
def get_trainable_variables(self):
|
||||
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
|
||||
def get_initial_state(self):
|
||||
return []
|
||||
|
59
baselines/pposgd/mlp_policy.py
Normal file
59
baselines/pposgd/mlp_policy.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
||||
import baselines.common.tf_util as U
|
||||
import tensorflow as tf
|
||||
import gym
|
||||
from baselines.common.distributions import make_pdtype
|
||||
|
||||
class MlpPolicy(object):
|
||||
recurrent = False
|
||||
def __init__(self, name, *args, **kwargs):
|
||||
with tf.variable_scope(name):
|
||||
self._init(*args, **kwargs)
|
||||
self.scope = tf.get_variable_scope().name
|
||||
|
||||
def _init(self, ob_space, ac_space, hid_size, num_hid_layers, gaussian_fixed_var=True):
|
||||
assert isinstance(ob_space, gym.spaces.Box)
|
||||
|
||||
self.pdtype = pdtype = make_pdtype(ac_space)
|
||||
sequence_length = None
|
||||
|
||||
ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[sequence_length] + list(ob_space.shape))
|
||||
|
||||
with tf.variable_scope("obfilter"):
|
||||
self.ob_rms = RunningMeanStd(shape=ob_space.shape)
|
||||
|
||||
obz = tf.clip_by_value((ob - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)
|
||||
last_out = obz
|
||||
for i in range(num_hid_layers):
|
||||
last_out = tf.nn.tanh(U.dense(last_out, hid_size, "vffc%i"%(i+1), weight_init=U.normc_initializer(1.0)))
|
||||
self.vpred = U.dense(last_out, 1, "vffinal", weight_init=U.normc_initializer(1.0))[:,0]
|
||||
|
||||
last_out = obz
|
||||
for i in range(num_hid_layers):
|
||||
last_out = tf.nn.tanh(U.dense(last_out, hid_size, "polfc%i"%(i+1), weight_init=U.normc_initializer(1.0)))
|
||||
if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
|
||||
mean = U.dense(last_out, pdtype.param_shape()[0]//2, "polfinal", U.normc_initializer(0.01))
|
||||
logstd = tf.get_variable(name="logstd", shape=[1, pdtype.param_shape()[0]//2], initializer=tf.zeros_initializer())
|
||||
pdparam = U.concatenate([mean, mean * 0.0 + logstd], axis=1)
|
||||
else:
|
||||
pdparam = U.dense(last_out, pdtype.param_shape()[0], "polfinal", U.normc_initializer(0.01))
|
||||
|
||||
self.pd = pdtype.pdfromflat(pdparam)
|
||||
|
||||
self.state_in = []
|
||||
self.state_out = []
|
||||
|
||||
stochastic = tf.placeholder(dtype=tf.bool, shape=())
|
||||
ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
|
||||
self._act = U.function([stochastic, ob], [ac, self.vpred])
|
||||
|
||||
def act(self, stochastic, ob):
|
||||
ac1, vpred1 = self._act(stochastic, ob[None])
|
||||
return ac1[0], vpred1[0]
|
||||
def get_variables(self):
|
||||
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
|
||||
def get_trainable_variables(self):
|
||||
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
|
||||
def get_initial_state(self):
|
||||
return []
|
||||
|
218
baselines/pposgd/pposgd_simple.py
Normal file
218
baselines/pposgd/pposgd_simple.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from baselines.common import Dataset, explained_variance, fmt_row, zipsame
|
||||
from baselines import logger
|
||||
import baselines.common.tf_util as U
|
||||
import tensorflow as tf, numpy as np
|
||||
import time
|
||||
from baselines.common.mpi_adam import MpiAdam
|
||||
from baselines.common.mpi_moments import mpi_moments
|
||||
from mpi4py import MPI
|
||||
from collections import deque
|
||||
|
||||
def traj_segment_generator(pi, env, horizon, stochastic):
|
||||
t = 0
|
||||
ac = env.action_space.sample() # not used, just so we have the datatype
|
||||
new = True # marks if we're on first timestep of an episode
|
||||
ob = env.reset()
|
||||
|
||||
cur_ep_ret = 0 # return in current episode
|
||||
cur_ep_len = 0 # len of current episode
|
||||
ep_rets = [] # returns of completed episodes in this segment
|
||||
ep_lens = [] # lengths of ...
|
||||
|
||||
# Initialize history arrays
|
||||
obs = np.array([ob for _ in range(horizon)])
|
||||
rews = np.zeros(horizon, 'float32')
|
||||
vpreds = np.zeros(horizon, 'float32')
|
||||
news = np.zeros(horizon, 'int32')
|
||||
acs = np.array([ac for _ in range(horizon)])
|
||||
prevacs = acs.copy()
|
||||
|
||||
while True:
|
||||
prevac = ac
|
||||
ac, vpred = pi.act(stochastic, ob)
|
||||
# Slight weirdness here because we need value function at time T
|
||||
# before returning segment [0, T-1] so we get the correct
|
||||
# terminal value
|
||||
if t > 0 and t % horizon == 0:
|
||||
yield {"ob" : obs, "rew" : rews, "vpred" : vpreds, "new" : news,
|
||||
"ac" : acs, "prevac" : prevacs, "nextvpred": vpred * (1 - new),
|
||||
"ep_rets" : ep_rets, "ep_lens" : ep_lens}
|
||||
# Be careful!!! if you change the downstream algorithm to aggregate
|
||||
# several of these batches, then be sure to do a deepcopy
|
||||
ep_rets = []
|
||||
ep_lens = []
|
||||
i = t % horizon
|
||||
obs[i] = ob
|
||||
vpreds[i] = vpred
|
||||
news[i] = new
|
||||
acs[i] = ac
|
||||
prevacs[i] = prevac
|
||||
|
||||
ob, rew, new, _ = env.step(ac)
|
||||
rews[i] = rew
|
||||
|
||||
cur_ep_ret += rew
|
||||
cur_ep_len += 1
|
||||
if new:
|
||||
ep_rets.append(cur_ep_ret)
|
||||
ep_lens.append(cur_ep_len)
|
||||
cur_ep_ret = 0
|
||||
cur_ep_len = 0
|
||||
ob = env.reset()
|
||||
t += 1
|
||||
|
||||
def add_vtarg_and_adv(seg, gamma, lam):
|
||||
"""
|
||||
Compute target value using TD(lambda) estimator, and advantage with GAE(lambda)
|
||||
"""
|
||||
new = np.append(seg["new"], 0) # last element is only used for last vtarg, but we already zeroed it if last new = 1
|
||||
vpred = np.append(seg["vpred"], seg["nextvpred"])
|
||||
T = len(seg["rew"])
|
||||
seg["adv"] = gaelam = np.empty(T, 'float32')
|
||||
rew = seg["rew"]
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(T)):
|
||||
nonterminal = 1-new[t+1]
|
||||
delta = rew[t] + gamma * vpred[t+1] * nonterminal - vpred[t]
|
||||
gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
|
||||
seg["tdlamret"] = seg["adv"] + seg["vpred"]
|
||||
|
||||
def learn(env, policy_func, *,
|
||||
timesteps_per_batch, # timesteps per actor per update
|
||||
clip_param, entcoeff, # clipping parameter epsilon, entropy coeff
|
||||
optim_epochs, optim_stepsize, optim_batchsize,# optimization hypers
|
||||
gamma, lam, # advantage estimation
|
||||
max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0, # time constraint
|
||||
callback=None, # you can do anything in the callback, since it takes locals(), globals()
|
||||
schedule='constant' # annealing for stepsize parameters (epsilon and adam)
|
||||
):
|
||||
# Setup losses and stuff
|
||||
# ----------------------------------------
|
||||
ob_space = env.observation_space
|
||||
ac_space = env.action_space
|
||||
pi = policy_func("pi", ob_space, ac_space) # Construct network for new policy
|
||||
oldpi = policy_func("oldpi", ob_space, ac_space) # Network for old policy
|
||||
atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
|
||||
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
|
||||
|
||||
lrmult = tf.placeholder(name='lrmult', dtype=tf.float32, shape=[]) # learning rate multiplier, updated with schedule
|
||||
clip_param = clip_param * lrmult # Annealed cliping parameter epislon
|
||||
|
||||
ob = U.get_placeholder_cached(name="ob")
|
||||
ac = pi.pdtype.sample_placeholder([None])
|
||||
|
||||
kloldnew = oldpi.pd.kl(pi.pd)
|
||||
ent = pi.pd.entropy()
|
||||
meankl = U.mean(kloldnew)
|
||||
meanent = U.mean(ent)
|
||||
pol_entpen = (-entcoeff) * meanent
|
||||
|
||||
ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # pnew / pold
|
||||
surr1 = ratio * atarg # surrogate from conservative policy iteration
|
||||
surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
|
||||
pol_surr = - U.mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
|
||||
vfloss1 = tf.square(pi.vpred - ret)
|
||||
vpredclipped = oldpi.vpred + tf.clip_by_value(pi.vpred - oldpi.vpred, -clip_param, clip_param)
|
||||
vfloss2 = tf.square(vpredclipped - ret)
|
||||
vf_loss = .5 * U.mean(tf.maximum(vfloss1, vfloss2)) # we do the same clipping-based trust region for the value function
|
||||
total_loss = pol_surr + pol_entpen + vf_loss
|
||||
losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
|
||||
loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]
|
||||
|
||||
var_list = pi.get_trainable_variables()
|
||||
lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
|
||||
adam = MpiAdam(var_list)
|
||||
|
||||
assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
|
||||
for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
|
||||
compute_losses = U.function([ob, ac, atarg, ret, lrmult], losses)
|
||||
|
||||
U.initialize()
|
||||
adam.sync()
|
||||
|
||||
# Prepare for rollouts
|
||||
# ----------------------------------------
|
||||
seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)
|
||||
|
||||
episodes_so_far = 0
|
||||
timesteps_so_far = 0
|
||||
iters_so_far = 0
|
||||
tstart = time.time()
|
||||
lenbuffer = deque(maxlen=100) # rolling buffer for episode lengths
|
||||
rewbuffer = deque(maxlen=100) # rolling buffer for episode rewards
|
||||
|
||||
assert sum([max_iters>0, max_timesteps>0, max_episodes>0, max_seconds>0])==1, "Only one time constraint permitted"
|
||||
|
||||
while True:
|
||||
if callback: callback(locals(), globals())
|
||||
if max_timesteps and timesteps_so_far >= max_timesteps:
|
||||
break
|
||||
elif max_episodes and episodes_so_far >= max_episodes:
|
||||
break
|
||||
elif max_iters and iters_so_far >= max_iters:
|
||||
break
|
||||
elif max_seconds and time.time() - tstart >= max_seconds:
|
||||
break
|
||||
|
||||
if schedule == 'constant':
|
||||
cur_lrmult = 1.0
|
||||
elif schedule == 'linear':
|
||||
cur_lrmult = max(1.0 - float(timesteps_so_far) / max_timesteps, 0)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
logger.log("********** Iteration %i ************"%iters_so_far)
|
||||
|
||||
seg = seg_gen.__next__()
|
||||
add_vtarg_and_adv(seg, gamma, lam)
|
||||
|
||||
# ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
|
||||
ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
|
||||
vpredbefore = seg["vpred"] # predicted value function before udpate
|
||||
atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
|
||||
d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret), shuffle=not pi.recurrent)
|
||||
optim_batchsize = optim_batchsize or ob.shape[0]
|
||||
|
||||
if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy
|
||||
|
||||
assign_old_eq_new() # set old parameter values to new parameter values
|
||||
logger.log("Optimizing...")
|
||||
logger.log(fmt_row(13, loss_names))
|
||||
# Here we do a bunch of optimization epochs over the data
|
||||
for _ in range(optim_epochs):
|
||||
losses = [] # list of tuples, each of which gives the loss for a minibatch
|
||||
for batch in d.iterate_once(optim_batchsize):
|
||||
*newlosses, g = lossandgrad(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
|
||||
adam.update(g, optim_stepsize * cur_lrmult)
|
||||
losses.append(newlosses)
|
||||
logger.log(fmt_row(13, np.mean(losses, axis=0)))
|
||||
|
||||
logger.log("Evaluating losses...")
|
||||
losses = []
|
||||
for batch in d.iterate_once(optim_batchsize):
|
||||
newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"], cur_lrmult)
|
||||
losses.append(newlosses)
|
||||
meanlosses,_,_ = mpi_moments(losses, axis=0)
|
||||
logger.log(fmt_row(13, meanlosses))
|
||||
for (lossval, name) in zipsame(meanlosses, loss_names):
|
||||
logger.record_tabular("loss_"+name, lossval)
|
||||
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
|
||||
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
|
||||
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
|
||||
lens, rews = map(flatten_lists, zip(*listoflrpairs))
|
||||
lenbuffer.extend(lens)
|
||||
rewbuffer.extend(rews)
|
||||
logger.record_tabular("EpLenMean", np.mean(lenbuffer))
|
||||
logger.record_tabular("EpRewMean", np.mean(rewbuffer))
|
||||
logger.record_tabular("EpThisIter", len(lens))
|
||||
episodes_so_far += len(lens)
|
||||
timesteps_so_far += sum(lens)
|
||||
iters_so_far += 1
|
||||
logger.record_tabular("EpisodesSoFar", episodes_so_far)
|
||||
logger.record_tabular("TimestepsSoFar", timesteps_so_far)
|
||||
logger.record_tabular("TimeElapsed", time.time() - tstart)
|
||||
if MPI.COMM_WORLD.Get_rank()==0:
|
||||
logger.dump_tabular()
|
||||
|
||||
def flatten_lists(listoflists):
|
||||
return [el for list_ in listoflists for el in list_]
|
54
baselines/pposgd/run_atari.py
Normal file
54
baselines/pposgd/run_atari.py
Normal file
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
from mpi4py import MPI
|
||||
from baselines.common import set_global_seeds
|
||||
from baselines import bench
|
||||
from baselines.common.mpi_fork import mpi_fork
|
||||
import os.path as osp
|
||||
import gym, logging
|
||||
from baselines import logger
|
||||
import sys
|
||||
|
||||
def wrap_train(env):
|
||||
from baselines.common.atari_wrappers import (wrap_deepmind, FrameStack)
|
||||
env = wrap_deepmind(env, clip_rewards=True)
|
||||
env = FrameStack(env, 4)
|
||||
return env
|
||||
|
||||
def train(env_id, num_timesteps, seed, num_cpu):
|
||||
from baselines.pposgd import pposgd_simple, cnn_policy
|
||||
import baselines.common.tf_util as U
|
||||
whoami = mpi_fork(num_cpu)
|
||||
if whoami == "parent": return
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
sess = U.single_threaded_session()
|
||||
sess.__enter__()
|
||||
logger.session().__enter__()
|
||||
if rank != 0: logger.set_level(logger.DISABLED)
|
||||
workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
|
||||
set_global_seeds(workerseed)
|
||||
env = gym.make(env_id)
|
||||
def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
|
||||
return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)
|
||||
env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json" % rank))
|
||||
env.seed(workerseed)
|
||||
gym.logger.setLevel(logging.WARN)
|
||||
|
||||
env = wrap_train(env)
|
||||
num_timesteps /= 4 # because we're wrapping the envs to do frame skip
|
||||
env.seed(workerseed)
|
||||
|
||||
pposgd_simple.learn(env, policy_fn,
|
||||
max_timesteps=num_timesteps,
|
||||
timesteps_per_batch=256,
|
||||
clip_param=0.2, entcoeff=0.01,
|
||||
optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
|
||||
gamma=0.99, lam=0.95,
|
||||
schedule='linear'
|
||||
)
|
||||
env.close()
|
||||
|
||||
def main():
|
||||
train('PongNoFrameskip-v4', num_timesteps=40e6, seed=0, num_cpu=8)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
35
baselines/pposgd/run_mujoco.py
Normal file
35
baselines/pposgd/run_mujoco.py
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
from baselines.common import set_global_seeds, tf_util as U
|
||||
from baselines import bench
|
||||
import os.path as osp
|
||||
import gym, logging
|
||||
from baselines import logger
|
||||
import sys
|
||||
|
||||
def train(env_id, num_timesteps, seed):
|
||||
from baselines.pposgd import mlp_policy, pposgd_simple
|
||||
U.make_session(num_cpu=1).__enter__()
|
||||
logger.session().__enter__()
|
||||
set_global_seeds(seed)
|
||||
env = gym.make(env_id)
|
||||
def policy_fn(name, ob_space, ac_space):
|
||||
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
|
||||
hid_size=64, num_hid_layers=2)
|
||||
env = bench.Monitor(env, osp.join(logger.get_dir(), "monitor.json"))
|
||||
env.seed(seed)
|
||||
gym.logger.setLevel(logging.WARN)
|
||||
pposgd_simple.learn(env, policy_fn,
|
||||
max_timesteps=num_timesteps,
|
||||
timesteps_per_batch=2048,
|
||||
clip_param=0.2, entcoeff=0.0,
|
||||
optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
|
||||
gamma=0.99, lam=0.95,
|
||||
)
|
||||
env.close()
|
||||
|
||||
def main():
|
||||
train('Hopper-v1', num_timesteps=1e6, seed=0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
57
baselines/trpo_mpi/nosharing_cnn_policy.py
Normal file
57
baselines/trpo_mpi/nosharing_cnn_policy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from baselines.common.mpi_running_mean_std import RunningMeanStd
|
||||
import baselines.common.tf_util as U
|
||||
import tensorflow as tf
|
||||
import gym
|
||||
from baselines.common.distributions import make_pdtype
|
||||
|
||||
class CnnPolicy(object):
|
||||
recurrent = False
|
||||
def __init__(self, name, ob_space, ac_space):
|
||||
with tf.variable_scope(name):
|
||||
self._init(ob_space, ac_space)
|
||||
self.scope = tf.get_variable_scope().name
|
||||
|
||||
def _init(self, ob_space, ac_space):
|
||||
assert isinstance(ob_space, gym.spaces.Box)
|
||||
|
||||
self.pdtype = pdtype = make_pdtype(ac_space)
|
||||
sequence_length = None
|
||||
|
||||
ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[sequence_length] + list(ob_space.shape))
|
||||
|
||||
obscaled = ob / 255.0
|
||||
|
||||
with tf.variable_scope("pol"):
|
||||
x = obscaled
|
||||
x = tf.nn.relu(U.conv2d(x, 8, "l1", [8, 8], [4, 4], pad="VALID"))
|
||||
x = tf.nn.relu(U.conv2d(x, 16, "l2", [4, 4], [2, 2], pad="VALID"))
|
||||
x = U.flattenallbut0(x)
|
||||
x = tf.nn.relu(U.dense(x, 128, 'lin', U.normc_initializer(1.0)))
|
||||
logits = U.dense(x, pdtype.param_shape()[0], "logits", U.normc_initializer(0.01))
|
||||
self.pd = pdtype.pdfromflat(logits)
|
||||
with tf.variable_scope("vf"):
|
||||
x = obscaled
|
||||
x = tf.nn.relu(U.conv2d(x, 8, "l1", [8, 8], [4, 4], pad="VALID"))
|
||||
x = tf.nn.relu(U.conv2d(x, 16, "l2", [4, 4], [2, 2], pad="VALID"))
|
||||
x = U.flattenallbut0(x)
|
||||
x = tf.nn.relu(U.dense(x, 128, 'lin', U.normc_initializer(1.0)))
|
||||
self.vpred = U.dense(x, 1, "value", U.normc_initializer(1.0))
|
||||
self.vpredz = self.vpred
|
||||
|
||||
self.state_in = []
|
||||
self.state_out = []
|
||||
|
||||
stochastic = tf.placeholder(dtype=tf.bool, shape=())
|
||||
ac = self.pd.sample() # XXX
|
||||
self._act = U.function([stochastic, ob], [ac, self.vpred])
|
||||
|
||||
def act(self, stochastic, ob):
|
||||
ac1, vpred1 = self._act(stochastic, ob[None])
|
||||
return ac1[0], vpred1[0]
|
||||
def get_variables(self):
|
||||
return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)
|
||||
def get_trainable_variables(self):
|
||||
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)
|
||||
def get_initial_state(self):
|
||||
return []
|
||||
|
53
baselines/trpo_mpi/run_atari.py
Normal file
53
baselines/trpo_mpi/run_atari.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python
|
||||
from mpi4py import MPI
|
||||
from baselines.common import set_global_seeds
|
||||
import os.path as osp
|
||||
import gym, logging
|
||||
from baselines import logger
|
||||
from baselines import bench
|
||||
from baselines.common.mpi_fork import mpi_fork
|
||||
import sys
|
||||
|
||||
def wrap_train(env):
|
||||
from baselines.common.atari_wrappers import (wrap_deepmind, FrameStack)
|
||||
env = wrap_deepmind(env, clip_rewards=False)
|
||||
env = FrameStack(env, 3)
|
||||
return env
|
||||
|
||||
def train(env_id, num_timesteps, seed, num_cpu):
|
||||
from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
|
||||
from baselines.trpo_mpi import trpo_mpi
|
||||
import baselines.common.tf_util as U
|
||||
whoami = mpi_fork(num_cpu)
|
||||
if whoami == "parent":
|
||||
return
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
sess = U.single_threaded_session()
|
||||
sess.__enter__()
|
||||
logger.session().__enter__()
|
||||
if rank != 0:
|
||||
logger.set_level(logger.DISABLED)
|
||||
|
||||
|
||||
workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
|
||||
set_global_seeds(workerseed)
|
||||
env = gym.make(env_id)
|
||||
def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
|
||||
return CnnPolicy(name=name, ob_space=env.observation_space, ac_space=env.action_space)
|
||||
env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json"%rank))
|
||||
env.seed(workerseed)
|
||||
gym.logger.setLevel(logging.WARN)
|
||||
|
||||
env = wrap_train(env)
|
||||
num_timesteps /= 4 # because we're wrapping the envs to do frame skip
|
||||
env.seed(workerseed)
|
||||
|
||||
trpo_mpi.learn(env, policy_fn, timesteps_per_batch=512, max_kl=0.001, cg_iters=10, cg_damping=1e-3,
|
||||
max_timesteps=num_timesteps, gamma=0.98, lam=1.0, vf_iters=3, vf_stepsize=1e-4, entcoeff=0.00)
|
||||
env.close()
|
||||
|
||||
def main():
|
||||
train('PongNoFrameskip-v4', num_timesteps=40e6, seed=0, num_cpu=8)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
47
baselines/trpo_mpi/run_mujoco.py
Normal file
47
baselines/trpo_mpi/run_mujoco.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
# noinspection PyUnresolvedReferences
|
||||
import mujoco_py # Mujoco must come before other imports. https://openai.slack.com/archives/C1H6P3R7B/p1492828680631850
|
||||
from mpi4py import MPI
|
||||
from baselines.common import set_global_seeds
|
||||
import os.path as osp
|
||||
import gym
|
||||
import logging
|
||||
from baselines import logger
|
||||
from baselines.pposgd.mlp_policy import MlpPolicy
|
||||
from baselines.common.mpi_fork import mpi_fork
|
||||
from baselines import bench
|
||||
from baselines.trpo_mpi import trpo_mpi
|
||||
import sys
|
||||
num_cpu=1
|
||||
|
||||
def train(env_id, num_timesteps, seed):
|
||||
whoami = mpi_fork(num_cpu)
|
||||
if whoami == "parent":
|
||||
return
|
||||
import baselines.common.tf_util as U
|
||||
logger.session().__enter__()
|
||||
sess = U.single_threaded_session()
|
||||
sess.__enter__()
|
||||
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
if rank != 0:
|
||||
logger.set_level(logger.DISABLED)
|
||||
workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
|
||||
set_global_seeds(workerseed)
|
||||
env = gym.make(env_id)
|
||||
def policy_fn(name, ob_space, ac_space):
|
||||
return MlpPolicy(name=name, ob_space=env.observation_space, ac_space=env.action_space,
|
||||
hid_size=32, num_hid_layers=2)
|
||||
env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json" % rank))
|
||||
env.seed(workerseed)
|
||||
gym.logger.setLevel(logging.WARN)
|
||||
|
||||
trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
|
||||
max_timesteps=num_timesteps, gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
|
||||
env.close()
|
||||
|
||||
def main():
|
||||
train('Hopper-v1', num_timesteps=1e6, seed=0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
291
baselines/trpo_mpi/trpo_mpi.py
Normal file
291
baselines/trpo_mpi/trpo_mpi.py
Normal file
@@ -0,0 +1,291 @@
|
||||
from baselines.common import explained_variance, zipsame, dataset
|
||||
from baselines import logger
|
||||
import baselines.common.tf_util as U
|
||||
import tensorflow as tf, numpy as np
|
||||
import time
|
||||
from baselines.common import colorize
|
||||
from mpi4py import MPI
|
||||
from collections import deque
|
||||
from baselines.common.mpi_adam import MpiAdam
|
||||
from baselines.common.cg import cg
|
||||
from contextlib import contextmanager
|
||||
|
||||
def traj_segment_generator(pi, env, horizon, stochastic):
|
||||
# Initialize state variables
|
||||
t = 0
|
||||
ac = env.action_space.sample()
|
||||
new = True
|
||||
rew = 0.0
|
||||
ob = env.reset()
|
||||
|
||||
cur_ep_ret = 0
|
||||
cur_ep_len = 0
|
||||
ep_rets = []
|
||||
ep_lens = []
|
||||
|
||||
# Initialize history arrays
|
||||
obs = np.array([ob for _ in range(horizon)])
|
||||
rews = np.zeros(horizon, 'float32')
|
||||
vpreds = np.zeros(horizon, 'float32')
|
||||
news = np.zeros(horizon, 'int32')
|
||||
acs = np.array([ac for _ in range(horizon)])
|
||||
prevacs = acs.copy()
|
||||
|
||||
while True:
|
||||
prevac = ac
|
||||
ac, vpred = pi.act(stochastic, ob)
|
||||
# Slight weirdness here because we need value function at time T
|
||||
# before returning segment [0, T-1] so we get the correct
|
||||
# terminal value
|
||||
if t > 0 and t % horizon == 0:
|
||||
yield {"ob" : obs, "rew" : rews, "vpred" : vpreds, "new" : news,
|
||||
"ac" : acs, "prevac" : prevacs, "nextvpred": vpred * (1 - new),
|
||||
"ep_rets" : ep_rets, "ep_lens" : ep_lens}
|
||||
_, vpred = pi.act(stochastic, ob)
|
||||
# Be careful!!! if you change the downstream algorithm to aggregate
|
||||
# several of these batches, then be sure to do a deepcopy
|
||||
ep_rets = []
|
||||
ep_lens = []
|
||||
i = t % horizon
|
||||
obs[i] = ob
|
||||
vpreds[i] = vpred
|
||||
news[i] = new
|
||||
acs[i] = ac
|
||||
prevacs[i] = prevac
|
||||
|
||||
ob, rew, new, _ = env.step(ac)
|
||||
rews[i] = rew
|
||||
|
||||
cur_ep_ret += rew
|
||||
cur_ep_len += 1
|
||||
if new:
|
||||
ep_rets.append(cur_ep_ret)
|
||||
ep_lens.append(cur_ep_len)
|
||||
cur_ep_ret = 0
|
||||
cur_ep_len = 0
|
||||
ob = env.reset()
|
||||
t += 1
|
||||
|
||||
def add_vtarg_and_adv(seg, gamma, lam):
|
||||
new = np.append(seg["new"], 0) # last element is only used for last vtarg, but we already zeroed it if last new = 1
|
||||
vpred = np.append(seg["vpred"], seg["nextvpred"])
|
||||
T = len(seg["rew"])
|
||||
seg["adv"] = gaelam = np.empty(T, 'float32')
|
||||
rew = seg["rew"]
|
||||
lastgaelam = 0
|
||||
for t in reversed(range(T)):
|
||||
nonterminal = 1-new[t+1]
|
||||
delta = rew[t] + gamma * vpred[t+1] * nonterminal - vpred[t]
|
||||
gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
|
||||
seg["tdlamret"] = seg["adv"] + seg["vpred"]
|
||||
|
||||
def learn(env, policy_func, *,
|
||||
timesteps_per_batch, # what to train on
|
||||
max_kl, cg_iters,
|
||||
gamma, lam, # advantage estimation
|
||||
entcoeff=0.0,
|
||||
cg_damping=1e-2,
|
||||
vf_stepsize=3e-4,
|
||||
vf_iters =3,
|
||||
max_timesteps=0, max_episodes=0, max_iters=0, # time constraint
|
||||
callback=None
|
||||
):
|
||||
nworkers = MPI.COMM_WORLD.Get_size()
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
np.set_printoptions(precision=3)
|
||||
# Setup losses and stuff
|
||||
# ----------------------------------------
|
||||
ob_space = env.observation_space
|
||||
ac_space = env.action_space
|
||||
pi = policy_func("pi", ob_space, ac_space)
|
||||
oldpi = policy_func("oldpi", ob_space, ac_space)
|
||||
atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
|
||||
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
|
||||
|
||||
ob = U.get_placeholder_cached(name="ob")
|
||||
ac = pi.pdtype.sample_placeholder([None])
|
||||
|
||||
kloldnew = oldpi.pd.kl(pi.pd)
|
||||
ent = pi.pd.entropy()
|
||||
meankl = U.mean(kloldnew)
|
||||
meanent = U.mean(ent)
|
||||
entbonus = entcoeff * meanent
|
||||
|
||||
vferr = U.mean(tf.square(pi.vpred - ret))
|
||||
|
||||
ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold
|
||||
surrgain = U.mean(ratio * atarg)
|
||||
|
||||
optimgain = surrgain + entbonus
|
||||
losses = [optimgain, meankl, entbonus, surrgain, meanent]
|
||||
loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]
|
||||
|
||||
dist = meankl
|
||||
|
||||
all_var_list = pi.get_trainable_variables()
|
||||
var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
|
||||
vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
|
||||
vfadam = MpiAdam(vf_var_list)
|
||||
|
||||
get_flat = U.GetFlat(var_list)
|
||||
set_from_flat = U.SetFromFlat(var_list)
|
||||
klgrads = tf.gradients(dist, var_list)
|
||||
flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
|
||||
shapes = [var.get_shape().as_list() for var in var_list]
|
||||
start = 0
|
||||
tangents = []
|
||||
for shape in shapes:
|
||||
sz = U.intprod(shape)
|
||||
tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
|
||||
start += sz
|
||||
gvp = tf.add_n([U.sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111
|
||||
fvp = U.flatgrad(gvp, var_list)
|
||||
|
||||
assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
|
||||
for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
|
||||
compute_losses = U.function([ob, ac, atarg], losses)
|
||||
compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
|
||||
compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
|
||||
compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))
|
||||
|
||||
@contextmanager
|
||||
def timed(msg):
|
||||
if rank == 0:
|
||||
print(colorize(msg, color='magenta'))
|
||||
tstart = time.time()
|
||||
yield
|
||||
print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
|
||||
else:
|
||||
yield
|
||||
|
||||
def allmean(x):
|
||||
assert isinstance(x, np.ndarray)
|
||||
out = np.empty_like(x)
|
||||
MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
|
||||
out /= nworkers
|
||||
return out
|
||||
|
||||
U.initialize()
|
||||
th_init = get_flat()
|
||||
MPI.COMM_WORLD.Bcast(th_init, root=0)
|
||||
set_from_flat(th_init)
|
||||
vfadam.sync()
|
||||
print("Init param sum", th_init.sum(), flush=True)
|
||||
|
||||
# Prepare for rollouts
|
||||
# ----------------------------------------
|
||||
seg_gen = traj_segment_generator(pi, env, timesteps_per_batch, stochastic=True)
|
||||
|
||||
episodes_so_far = 0
|
||||
timesteps_so_far = 0
|
||||
iters_so_far = 0
|
||||
tstart = time.time()
|
||||
lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths
|
||||
rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards
|
||||
|
||||
assert sum([max_iters>0, max_timesteps>0, max_episodes>0])==1
|
||||
|
||||
while True:
|
||||
if callback: callback(locals(), globals())
|
||||
if max_timesteps and timesteps_so_far >= max_timesteps:
|
||||
break
|
||||
elif max_episodes and episodes_so_far >= max_episodes:
|
||||
break
|
||||
elif max_iters and iters_so_far >= max_iters:
|
||||
break
|
||||
logger.log("********** Iteration %i ************"%iters_so_far)
|
||||
|
||||
with timed("sampling"):
|
||||
seg = seg_gen.__next__()
|
||||
add_vtarg_and_adv(seg, gamma, lam)
|
||||
|
||||
# ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
|
||||
ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
|
||||
vpredbefore = seg["vpred"] # predicted value function before udpate
|
||||
atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate
|
||||
|
||||
if hasattr(pi, "ret_rms"): pi.ret_rms.update(tdlamret)
|
||||
if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy
|
||||
|
||||
args = seg["ob"], seg["ac"], seg["adv"]
|
||||
fvpargs = [arr[::5] for arr in args]
|
||||
def fisher_vector_product(p):
|
||||
return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p
|
||||
|
||||
assign_old_eq_new() # set old parameter values to new parameter values
|
||||
with timed("computegrad"):
|
||||
*lossbefore, g = compute_lossandgrad(*args)
|
||||
lossbefore = allmean(np.array(lossbefore))
|
||||
g = allmean(g)
|
||||
if np.allclose(g, 0):
|
||||
logger.log("Got zero gradient. not updating")
|
||||
else:
|
||||
with timed("cg"):
|
||||
stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0)
|
||||
assert np.isfinite(stepdir).all()
|
||||
shs = .5*stepdir.dot(fisher_vector_product(stepdir))
|
||||
lm = np.sqrt(shs / max_kl)
|
||||
# logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
|
||||
fullstep = stepdir / lm
|
||||
expectedimprove = g.dot(fullstep)
|
||||
surrbefore = lossbefore[0]
|
||||
stepsize = 1.0
|
||||
thbefore = get_flat()
|
||||
for _ in range(10):
|
||||
thnew = thbefore + fullstep * stepsize
|
||||
set_from_flat(thnew)
|
||||
meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
|
||||
improve = surr - surrbefore
|
||||
logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve))
|
||||
if not np.isfinite(meanlosses).all():
|
||||
logger.log("Got non-finite value of losses -- bad!")
|
||||
elif kl > max_kl * 1.5:
|
||||
logger.log("violated KL constraint. shrinking step.")
|
||||
elif improve < 0:
|
||||
logger.log("surrogate didn't improve. shrinking step.")
|
||||
else:
|
||||
logger.log("Stepsize OK!")
|
||||
break
|
||||
stepsize *= .5
|
||||
else:
|
||||
logger.log("couldn't compute a good step")
|
||||
set_from_flat(thbefore)
|
||||
if nworkers > 1 and iters_so_far % 20 == 0:
|
||||
paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples
|
||||
assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
|
||||
|
||||
for (lossname, lossval) in zip(loss_names, meanlosses):
|
||||
logger.record_tabular(lossname, lossval)
|
||||
|
||||
with timed("vf"):
|
||||
|
||||
for _ in range(vf_iters):
|
||||
for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
|
||||
include_final_partial_batch=False, batch_size=64):
|
||||
g = allmean(compute_vflossandgrad(mbob, mbret))
|
||||
vfadam.update(g, vf_stepsize)
|
||||
|
||||
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
|
||||
|
||||
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
|
||||
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
|
||||
lens, rews = map(flatten_lists, zip(*listoflrpairs))
|
||||
lenbuffer.extend(lens)
|
||||
rewbuffer.extend(rews)
|
||||
|
||||
logger.record_tabular("EpLenMean", np.mean(lenbuffer))
|
||||
logger.record_tabular("EpRewMean", np.mean(rewbuffer))
|
||||
logger.record_tabular("EpThisIter", len(lens))
|
||||
episodes_so_far += len(lens)
|
||||
timesteps_so_far += sum(lens)
|
||||
iters_so_far += 1
|
||||
|
||||
logger.record_tabular("EpisodesSoFar", episodes_so_far)
|
||||
logger.record_tabular("TimestepsSoFar", timesteps_so_far)
|
||||
logger.record_tabular("TimeElapsed", time.time() - tstart)
|
||||
|
||||
if rank==0:
|
||||
logger.dump_tabular()
|
||||
|
||||
def flatten_lists(listoflists):
|
||||
return [el for list_ in listoflists for el in list_]
|
1
setup.py
1
setup.py
@@ -5,6 +5,7 @@ if sys.version_info.major != 3:
|
||||
print("This Python is only compatible with Python 3, but you are running "
|
||||
"Python {}. The installation will likely fail.".format(sys.version_info.major))
|
||||
|
||||
|
||||
setup(name='baselines',
|
||||
packages=[package for package in find_packages()
|
||||
if package.startswith('baselines')],
|
||||
|
Reference in New Issue
Block a user