Merge branch 'master' of github.com:openai/baselines into internal
@@ -89,7 +89,7 @@ python -m baselines.run --alg=ppo2 --env=Humanoid-v2 --network=mlp --num_timeste
will set entropy coefficient to 0.1, and construct fully connected network with 3 layers with 32 hidden units in each, and create a separate network for value function estimation (so that its parameters are not shared with the policy network, but the structure is the same)

See docstrings in [common/models.py](baselines/common/models.py) for description of network parameters for each type of model, and
-docstring for [baselines/ppo2/ppo2.py/learn()](baselines/ppo2/ppo2.py#L152) for the description of the ppo2 hyperparamters.
+docstring for [baselines/ppo2/ppo2.py/learn()](baselines/ppo2/ppo2.py#L152) for the description of the ppo2 hyperparameters.

### Example 2. DQN on Atari
DQN with Atari is at this point a classics of benchmarks. To run the baselines implementation of DQN on Atari Pong:
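For readers following along outside the diff, here is a minimal sketch of the programmatic call the README paragraph above describes (entropy coefficient 0.1, a 3-layer MLP with 32 units, separate value network). The keyword names (`ent_coef`, `num_layers`, `num_hidden`, `value_network='copy'`) and the `make_vec_env` arguments are my assumptions about the corresponding baselines options, not something this hunk shows.

```python
# Hedged sketch only: assumes the usual baselines keyword arguments; not part of this diff.
from baselines.common.cmd_util import make_vec_env
from baselines.ppo2 import ppo2

env = make_vec_env('Humanoid-v2', 'mujoco', num_env=1, seed=0)
model = ppo2.learn(
    network='mlp',
    env=env,
    total_timesteps=int(1e6),
    ent_coef=0.1,           # entropy coefficient of 0.1
    num_layers=3,           # 3 fully connected layers ...
    num_hidden=32,          # ... with 32 hidden units each
    value_network='copy',   # separate value network with the same structure as the policy
)
```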
@@ -11,6 +11,8 @@ from baselines.common.policies import build_policy

from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.runner import Runner
+from baselines.ppo2.ppo2 import safemean
+from collections import deque

from tensorflow import losses

@@ -195,6 +197,7 @@ def learn(

# Instantiate the runner object
runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
+epinfobuf = deque(maxlen=100)

# Calculate the batch_size
nbatch = nenvs*nsteps
@@ -204,7 +207,8 @@ def learn(

for update in range(1, total_timesteps//nbatch+1):
# Get mini batch of experiences
-obs, states, rewards, masks, actions, values = runner.run()
+obs, states, rewards, masks, actions, values, epinfos = runner.run()
+epinfobuf.extend(epinfos)

policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
nseconds = time.time()-tstart
@@ -221,6 +225,8 @@ def learn(
logger.record_tabular("policy_entropy", float(policy_entropy))
logger.record_tabular("value_loss", float(value_loss))
logger.record_tabular("explained_variance", float(ev))
+logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
+logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
logger.dump_tabular()
return model

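Context for the two added logging lines: `epinfobuf` is a `deque(maxlen=100)`, so the reported means are over a rolling window of the last 100 finished episodes, and `safemean` (imported from ppo2 above) makes an empty window log NaN instead of raising. A small sketch of the behaviour being relied on; the `safemean` body below is my paraphrase of the ppo2 helper, not part of this diff.

```python
from collections import deque
import numpy as np

def safemean(xs):
    # paraphrase of baselines.ppo2.ppo2.safemean: NaN on an empty window instead of an error
    return np.nan if len(xs) == 0 else np.mean(xs)

epinfobuf = deque(maxlen=100)                        # rolling window of episode summaries
print(safemean([e['r'] for e in epinfobuf]))         # nan -- no episodes finished yet
epinfobuf.extend([{'r': 1.0, 'l': 10}, {'r': 3.0, 'l': 20}])
print(safemean([e['r'] for e in epinfobuf]))         # 2.0
```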
@@ -22,6 +22,7 @@ class Runner(AbstractEnvRunner):
# We initialize the lists that will contain the mb of experiences
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
mb_states = self.states
+epinfos = []
for n in range(self.nsteps):
# Given observations, take action and value (V(s))
# We already have self.obs because Runner superclass run self.obs[:] = env.reset() on init
@@ -34,7 +35,10 @@ class Runner(AbstractEnvRunner):
mb_dones.append(self.dones)

# Take actions in env and look the results
-obs, rewards, dones, _ = self.env.step(actions)
+obs, rewards, dones, infos = self.env.step(actions)
+for info in infos:
+    maybeepinfo = info.get('episode')
+    if maybeepinfo: epinfos.append(maybeepinfo)
self.states = states
self.dones = dones
self.obs = obs
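The added loop is where episode statistics enter the pipeline: baselines wraps each env in a Monitor, which attaches a summary dict under the `'episode'` key of `info` on the step that ends an episode. A minimal sketch of the consumption pattern with hand-made `infos`; the `'r'`/`'l'` keys are the ones read by the logging code, the values here are made up.

```python
# Hand-made stand-in for the `infos` list returned by a vectorized env step.
infos = [
    {},                                    # env 0: episode still running, no summary
    {'episode': {'r': 123.4, 'l': 200}},   # env 1: episode just ended -> reward/length summary
]

epinfos = []
for info in infos:
    maybeepinfo = info.get('episode')      # None unless the Monitor wrapper added a summary
    if maybeepinfo:
        epinfos.append(maybeepinfo)

print(epinfos)                             # [{'r': 123.4, 'l': 200}]
```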
@@ -69,4 +73,4 @@ class Runner(AbstractEnvRunner):
mb_rewards = mb_rewards.flatten()
mb_values = mb_values.flatten()
mb_masks = mb_masks.flatten()
-return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values
+return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, epinfos
@@ -11,6 +11,8 @@ from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.a2c.runner import Runner
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.acktr import kfac
+from baselines.ppo2.ppo2 import safemean
+from collections import deque


class Model(object):
@@ -118,6 +120,7 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
model.load(load_path)

runner = Runner(env, model, nsteps=nsteps, gamma=gamma)
+epinfobuf = deque(maxlen=100)
nbatch = nenvs*nsteps
tstart = time.time()
coord = tf.train.Coordinator()
@@ -127,7 +130,8 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
enqueue_threads = []

for update in range(1, total_timesteps//nbatch+1):
-obs, states, rewards, masks, actions, values = runner.run()
+obs, states, rewards, masks, actions, values, epinfos = runner.run()
+epinfobuf.extend(epinfos)
policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
model.old_obs = obs
nseconds = time.time()-tstart
@@ -141,6 +145,8 @@ def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interva
logger.record_tabular("policy_loss", float(policy_loss))
logger.record_tabular("value_loss", float(value_loss))
logger.record_tabular("explained_variance", float(ev))
+logger.record_tabular("eprewmean", safemean([epinfo['r'] for epinfo in epinfobuf]))
+logger.record_tabular("eplenmean", safemean([epinfo['l'] for epinfo in epinfobuf]))
logger.dump_tabular()

if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
@@ -87,6 +87,8 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
if env_type == 'atari':
env = wrap_deepmind(env, **wrapper_kwargs)
elif env_type == 'retro':
+if 'frame_stack' not in wrapper_kwargs:
+    wrapper_kwargs['frame_stack'] = 1
env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)

if isinstance(env.action_space, gym.spaces.Box):
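The two added lines give the retro wrappers a default of one stacked frame when the caller supplied no `frame_stack`; `dict.setdefault` is the one-line equivalent, shown here only to make the intent explicit.

```python
wrapper_kwargs = {}                           # caller passed no wrapper options
wrapper_kwargs.setdefault('frame_stack', 1)   # same effect as the added `if 'frame_stack' not in ...` guard
print(wrapper_kwargs)                         # {'frame_stack': 1}
```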
@@ -123,7 +123,7 @@ def mpi_weighted_mean(comm, local_name2valcount):
val = float(val)
except ValueError:
if comm.rank == 0:
-warnings.warn(f'WARNING: tried to compute mean on non-float {name}={val}')
+warnings.warn('WARNING: tried to compute mean on non-float {}={}'.format(name, val))
else:
name2sum[name] += val * count
name2count[name] += count
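This hunk, and several below (the profiler, the MPI test, run.py), replace f-strings with `str.format`. The diff itself does not say why, but f-strings require Python 3.6+, so this reads as keeping the code importable on Python 3.5; the two forms produce identical output where both are available.

```python
name, val = 'lr', 'three'
print('WARNING: tried to compute mean on non-float {}={}'.format(name, val))
# print(f'WARNING: tried to compute mean on non-float {name}={val}')  # same output, but 3.6+ only
```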
@@ -248,7 +248,7 @@ def plot_results(
figsize=None,
legend_outside=False,
resample=0,
-smooth_step=1.0,
+smooth_step=1.0
):
'''
Plot multiple Results objects
@@ -177,7 +177,7 @@ def profile_tf_runningmeanstd():
outfile = '/tmp/timeline.json'
with open(outfile, 'wt') as f:
f.write(chrome_trace)
-print(f'Successfully saved profile to {outfile}. Exiting.')
+print('Successfully saved profile to {}. Exiting.'.format(outfile))
exit(0)
'''

@@ -16,7 +16,7 @@ def test_mpi_weighted_mean():
d = mpi_util.mpi_weighted_mean(comm, name2valcount)
correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
if comm.rank == 0:
-assert d == correctval, f'{d} != {correctval}'
+assert d == correctval, '{} != {}'.format(d, correctval)

for name, (val, count) in name2valcount.items():
for _ in range(count):
@@ -305,12 +305,17 @@ def display_var_info(vars):
logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))


-def get_available_gpus():
-# recipe from here:
-# https://stackoverflow.com/questions/38559755/how-to-get-current-available-gpus-in-tensorflow?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
+def get_available_gpus(session_config=None):
+# based on recipe from https://stackoverflow.com/a/38580201
+# Unless we allocate a session here, subsequent attempts to create one
+# will ignore our custom config (in particular, allow_growth=True will have
+# no effect).
+if session_config is None:
+    session_config = get_session()._config

from tensorflow.python.client import device_lib
-local_device_protos = device_lib.list_local_devices()
+local_device_protos = device_lib.list_local_devices(session_config)
return [x.name for x in local_device_protos if x.device_type == 'GPU']

# ================================================================
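The new `session_config` parameter lets the caller pass the same `ConfigProto` their training session uses, so that listing devices does not itself spin up a session with default options first (which is what the added comment about `allow_growth` warns about). A usage sketch under TF 1.x assumptions:

```python
import tensorflow as tf
from baselines.common.tf_util import get_available_gpus

# Build the config the training session will use and hand it to the device query,
# so GPU discovery does not allocate with default options first.
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True

print(get_available_gpus(session_config=config))   # e.g. ['/device:GPU:0'] on a single-GPU box
```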
@@ -23,7 +23,7 @@ def model(inpt, num_actions, scope, reuse=False):


if __name__ == '__main__':
-with U.make_session(8):
+with U.make_session(num_cpu=8):
# Create the environment
env = gym.make("CartPole-v0")
# Create all the functions necessary to train the model
@@ -20,7 +20,7 @@ class TfInput(object):
"""
raise NotImplementedError

-def make_feed_dict(data):
+def make_feed_dict(self, data):
"""Given data input it to the placeholder(s)."""
raise NotImplementedError

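The abstract stub was missing `self`, so calling it through an instance would raise a TypeError about the argument count rather than the intended NotImplementedError. A hypothetical subclass against the corrected signature:

```python
# Hypothetical subclass, only to show the corrected instance-method signature.
class MyTfInput:                       # stand-in for baselines.deepq.utils.TfInput
    def __init__(self, placeholder):
        self._placeholder = placeholder

    def make_feed_dict(self, data):
        """Given data input it to the placeholder(s)."""
        return {self._placeholder: data}

inp = MyTfInput('obs_ph')              # a string stands in for a real tf placeholder here
print(inp.make_feed_dict([1, 2, 3]))   # {'obs_ph': [1, 2, 3]}
```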
@@ -12,13 +12,13 @@ Download the expert data into `./data`, [download link](https://drive.google.com

### Step 2: Run GAIL

-Run with single thread:
+Run with single rank:

```bash
python -m baselines.gail.run_mujoco
```

-Run with multiple threads:
+Run with multiple ranks:

```bash
mpirun -np 16 python -m baselines.gail.run_mujoco
@@ -66,7 +66,7 @@ class TransitionClassifier(object):

with tf.variable_scope("obfilter"):
self.obs_rms = RunningMeanStd(shape=self.observation_shape)
-obs = (obs_ph - self.obs_rms.mean / self.obs_rms.std)
+obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std
_input = tf.concat([obs, acs_ph], axis=1) # concatenate the two input -> form a transition
p_h1 = tf.contrib.layers.fully_connected(_input, self.hidden_size, activation_fn=tf.nn.tanh)
p_h2 = tf.contrib.layers.fully_connected(p_h1, self.hidden_size, activation_fn=tf.nn.tanh)
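The one-character move of the closing parenthesis is a real fix: division binds tighter than subtraction, so the old expression subtracted `mean/std` from the raw observation instead of standardizing it. A tiny numeric check:

```python
import numpy as np

obs, mean, std = np.array([4.0]), np.array([2.0]), np.array([4.0])

old = obs - mean / std      # 4 - (2 / 4) = 3.5  -- observation never divided by std
new = (obs - mean) / std    # (4 - 2) / 4 = 0.5  -- properly standardized
print(old, new)             # [3.5] [0.5]
```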
@@ -50,8 +50,12 @@ class Mujoco_Dset(object):
# obs, acs: shape (N, L, ) + S where N = # episodes, L = episode length
# and S is the environment observation/action space.
# Flatten to (N * L, prod(S))
-self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
-self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
+if len(obs.shape) > 2:
+    self.obs = np.reshape(obs, [-1, np.prod(obs.shape[2:])])
+    self.acs = np.reshape(acs, [-1, np.prod(acs.shape[2:])])
+else:
+    self.obs = np.vstack(obs)
+    self.acs = np.vstack(acs)

self.rets = traj_data['ep_rets'][:traj_limitation]
self.avg_ret = sum(self.rets)/len(self.rets)
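The old loader assumed a dense `(N, L, ...)` array, which only holds when every episode has the same length; the new branch falls back to `np.vstack` for ragged data stored as a 1-D object array of per-episode arrays. A shape-only illustration (array contents are dummy zeros):

```python
import numpy as np

# Dense case: 4 episodes of exactly 100 steps with 11-dim observations.
obs_dense = np.zeros((4, 100, 11))
print(np.reshape(obs_dense, [-1, np.prod(obs_dense.shape[2:])]).shape)   # (400, 11)

# Ragged case: episodes of different lengths end up as a 1-D object array,
# so len(obs.shape) is not > 2 and reshape cannot flatten it meaningfully;
# vstack concatenates the per-episode arrays along the time axis instead.
obs_ragged = np.empty(2, dtype=object)
obs_ragged[0] = np.zeros((100, 11))
obs_ragged[1] = np.zeros((73, 11))
print(np.vstack(obs_ragged).shape)                                       # (173, 11)
```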
@@ -119,13 +119,13 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
eval_epinfobuf = deque(maxlen=100)

# Start total timer
-tfirststart = time.time()
+tfirststart = time.perf_counter()

nupdates = total_timesteps//nbatch
for update in range(1, nupdates+1):
assert nbatch % nminibatches == 0
# Start timer
-tstart = time.time()
+tstart = time.perf_counter()
frac = 1.0 - (update - 1.0) / nupdates
# Calculate the learning rate
lrnow = lr(frac)
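Swapping `time.time()` for `time.perf_counter()` is about measuring durations: `perf_counter` is a monotonic, high-resolution clock, whereas `time.time()` follows the wall clock and can jump if the system clock is adjusted, which would corrupt the fps numbers. Only differences between readings are meaningful:

```python
import time

t0 = time.perf_counter()
time.sleep(0.1)                        # stand-in for one training update
elapsed = time.perf_counter() - t0     # monotonic: immune to wall-clock adjustments
print('update took %.3f s' % elapsed)
```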
@@ -173,7 +173,7 @@ def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2
# Feedforward --> get losses --> update
lossvals = np.mean(mblossvals, axis=0)
# End timer
-tnow = time.time()
+tnow = time.perf_counter()
# Calculate the fps (frame per second)
fps = int(nbatch / (tnow - tstart))
if update % log_interval == 0 or update == 1:
@@ -6,7 +6,7 @@ from collections import defaultdict
import tensorflow as tf
import numpy as np

-from baselines.common.vec_env import VecFrameStack, VecNormalize
+from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session
@@ -228,11 +228,11 @@ def main(args):
actions, _, _, _ = model.step(obs)

obs, rew, done, _ = env.step(actions)
-episode_rew += rew[0]
+episode_rew += rew[0] if isinstance(env, VecEnv) else rew
env.render()
done = done.any() if isinstance(done, np.ndarray) else done
if done:
-print(f'episode_rew={episode_rew}')
+print('episode_rew={}'.format(episode_rew))
episode_rew = 0
obs = env.reset()

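The reward fix above accounts for the two kinds of env that `main()` can end up stepping: a `VecEnv` returns batched arrays (one entry per worker), while a plain gym env returns a scalar reward. A minimal illustration with made-up values:

```python
import numpy as np

def accumulate(episode_rew, rew, is_vec_env):
    # mirrors the fixed line: index into the batch only for vectorized envs
    return episode_rew + (rew[0] if is_vec_env else rew)

print(accumulate(0.0, np.array([1.5]), True))   # VecEnv with one worker -> rewards arrive as array([1.5])
print(accumulate(0.0, 1.5, False))              # plain gym.Env -> reward is already a scalar
```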