@@ -173,6 +173,15 @@ def huber_loss(x, delta=1.0):
        delta * (tf.abs(x) - 0.5 * delta)
    )

def logsigmoid(a):
    '''Equivalent to tf.log(tf.sigmoid(a))'''
    return -tf.nn.softplus(-a)

""" Reference: https://github.com/openai/imitation/blob/99fbccf3e060b6e6c739bdf209758620fcdefd3c/policyopt/thutil.py#L48-L51"""
def logit_bernoulli_entropy(logits):
    ent = (1. - tf.nn.sigmoid(logits)) * logits - logsigmoid(logits)
    return ent

# ================================================================
# Optimizer utils
# ================================================================
@@ -232,17 +241,26 @@ def set_value(v, val):
        VALUE_SETTERS[v] = (set_op, set_endpoint)
    get_session().run(set_op, feed_dict={set_endpoint: val})

# ================================================================
# Save tensorflow summary
# ================================================================

def file_writer(dir_path):
    os.makedirs(dir_path, exist_ok=True)
    return tf.summary.FileWriter(dir_path, get_session().graph)

# ================================================================
# Saving variables
# ================================================================

-def load_state(fname):
-    saver = tf.train.Saver()
+def load_state(fname, var_list=None):
+    saver = tf.train.Saver(var_list=var_list)
     saver.restore(get_session(), fname)

-def save_state(fname):
+def save_state(fname, var_list=None):
     os.makedirs(os.path.dirname(fname), exist_ok=True)
-    saver = tf.train.Saver()
+    saver = tf.train.Saver(var_list=var_list)
     saver.save(get_session(), fname)

# ================================================================
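`logit_bernoulli_entropy` above relies on the identity that, for p = sigmoid(x), the Bernoulli entropy -p·log(p) - (1-p)·log(1-p) simplifies to (1-p)·x - log(sigmoid(x)). A minimal NumPy check of that identity (illustration only, not part of the diff):

```python
# Sketch (illustration only): numerically verify the identity used by
# logit_bernoulli_entropy, i.e. H(sigmoid(x)) == (1 - sigmoid(x)) * x + softplus(-x),
# where -log(sigmoid(x)) == softplus(-x).
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softplus(x):
    # numerically stable log(1 + exp(x))
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

x = np.linspace(-5, 5, 11)
p = sigmoid(x)
entropy_direct = -p * np.log(p) - (1 - p) * np.log(1 - p)
entropy_from_logits = (1 - p) * x + softplus(-x)
assert np.allclose(entropy_direct, entropy_from_logits)
```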

baselines/gail/README.md (new file, 44 lines)
@@ -0,0 +1,44 @@
# Generative Adversarial Imitation Learning (GAIL)

- Original paper: https://arxiv.org/abs/1606.03476

For benchmark results on MuJoCo, please see [gail-result.md](result/gail-result.md).

## If you want to train an imitation learning agent

### Step 1: Download expert data

Download the expert data into `./data` ([download link](https://drive.google.com/drive/folders/1h3H4AY_ZBx08hz-Ct0Nxxus-V1melu1U?usp=sharing)).
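The expected layout of each downloaded `.npz` file (keys `obs`, `acs`, `rews`, `ep_rets`, one entry per expert trajectory, as documented in `dataset/mujoco_dset.py`) can be sanity-checked with a short script; the file name below is only an example:

```python
# Sketch: inspect one of the downloaded expert files (file name is an example).
import numpy as np

data = np.load("data/deterministic.trpo.Hopper.0.00.npz", allow_pickle=True)
print(sorted(data.keys()))     # expected: ['acs', 'ep_rets', 'obs', 'rews']
print(len(data["obs"]))        # number of expert trajectories
print(data["obs"][0].shape)    # (episode_length, observation_dim)
print(data["acs"][0].shape)    # (episode_length, action_dim)
print(data["ep_rets"][:5])     # per-trajectory returns
```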
### Step 2: Run GAIL

Run with a single thread:

```bash
python -m baselines.gail.run_mujoco
```

Run with multiple threads:

```bash
mpirun -np 16 python -m baselines.gail.run_mujoco
```

See help (`-h`) for more options.

#### In case you want to run Behavior Cloning (BC)

```bash
python -m baselines.gail.behavior_clone
```

See help (`-h`) for more options.

## Others

Thanks to the following open source projects:

- @openai/imitation
- @carpedm20/deep-rl-tensorflow

Also, thanks to [Ryan Julian](https://github.com/ryanjulian) for reviewing the code.

baselines/gail/__init__.py (new, empty file)

baselines/gail/adversary.py (new file, 79 lines)
@@ -0,0 +1,79 @@
'''
Reference: https://github.com/openai/imitation
The network architecture follows the official repository.
'''
import tensorflow as tf
import numpy as np

from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common import tf_util as U


class TransitionClassifier(object):
    def __init__(self, env, hidden_size, entcoeff=0.001, lr_rate=1e-3, scope="adversary"):
        self.scope = scope
        self.observation_shape = env.observation_space.shape
        self.actions_shape = env.action_space.shape
        self.input_shape = tuple([o+a for o, a in zip(self.observation_shape, self.actions_shape)])
        self.num_actions = env.action_space.shape[0]
        self.hidden_size = hidden_size
        self.build_ph()
        # Build graph
        generator_logits = self.build_graph(self.generator_obs_ph, self.generator_acs_ph, reuse=False)
        expert_logits = self.build_graph(self.expert_obs_ph, self.expert_acs_ph, reuse=True)
        # Build accuracy
        generator_acc = tf.reduce_mean(tf.to_float(tf.nn.sigmoid(generator_logits) < 0.5))
        expert_acc = tf.reduce_mean(tf.to_float(tf.nn.sigmoid(expert_logits) > 0.5))
        # Build regression loss
        # let x = logits, z = targets.
        # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
        generator_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=generator_logits, labels=tf.zeros_like(generator_logits))
        generator_loss = tf.reduce_mean(generator_loss)
        expert_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=expert_logits, labels=tf.ones_like(expert_logits))
        expert_loss = tf.reduce_mean(expert_loss)
        # Build entropy loss
        logits = tf.concat([generator_logits, expert_logits], 0)
        entropy = tf.reduce_mean(U.logit_bernoulli_entropy(logits))
        entropy_loss = -entcoeff*entropy
        # Loss + Accuracy terms
        self.losses = [generator_loss, expert_loss, entropy, entropy_loss, generator_acc, expert_acc]
        self.loss_name = ["generator_loss", "expert_loss", "entropy", "entropy_loss", "generator_acc", "expert_acc"]
        self.total_loss = generator_loss + expert_loss + entropy_loss
        # Build Reward for policy
        self.reward_op = -tf.log(1-tf.nn.sigmoid(generator_logits)+1e-8)
        var_list = self.get_trainable_variables()
        self.lossandgrad = U.function([self.generator_obs_ph, self.generator_acs_ph, self.expert_obs_ph, self.expert_acs_ph],
                                      self.losses + [U.flatgrad(self.total_loss, var_list)])

    def build_ph(self):
        self.generator_obs_ph = tf.placeholder(tf.float32, (None, ) + self.observation_shape, name="observations_ph")
        self.generator_acs_ph = tf.placeholder(tf.float32, (None, ) + self.actions_shape, name="actions_ph")
        self.expert_obs_ph = tf.placeholder(tf.float32, (None, ) + self.observation_shape, name="expert_observations_ph")
        self.expert_acs_ph = tf.placeholder(tf.float32, (None, ) + self.actions_shape, name="expert_actions_ph")

    def build_graph(self, obs_ph, acs_ph, reuse=False):
        with tf.variable_scope(self.scope):
            if reuse:
                tf.get_variable_scope().reuse_variables()

            with tf.variable_scope("obfilter"):
                self.obs_rms = RunningMeanStd(shape=self.observation_shape)
            obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std
            _input = tf.concat([obs, acs_ph], axis=1)  # concatenate observation and action to form a transition
            p_h1 = tf.contrib.layers.fully_connected(_input, self.hidden_size, activation_fn=tf.nn.tanh)
            p_h2 = tf.contrib.layers.fully_connected(p_h1, self.hidden_size, activation_fn=tf.nn.tanh)
            logits = tf.contrib.layers.fully_connected(p_h2, 1, activation_fn=tf.identity)
        return logits

    def get_trainable_variables(self):
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    def get_reward(self, obs, acs):
        sess = U.get_session()
        if len(obs.shape) == 1:
            obs = np.expand_dims(obs, 0)
        if len(acs.shape) == 1:
            acs = np.expand_dims(acs, 0)
        feed_dict = {self.generator_obs_ph: obs, self.generator_acs_ph: acs}
        reward = sess.run(self.reward_op, feed_dict)
        return reward
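A rough usage sketch of the classifier above (TensorFlow 1.x, gym with MuJoCo, and Hopper-v1 are assumed; this mirrors how run_mujoco.py wires it up):

```python
# Sketch: build the discriminator and query its surrogate reward for one transition.
import gym
import numpy as np
from baselines.common import tf_util as U
from baselines.gail.adversary import TransitionClassifier

U.make_session(num_cpu=1).__enter__()
env = gym.make("Hopper-v1")
reward_giver = TransitionClassifier(env, hidden_size=100, entcoeff=1e-3)
U.initialize()

ob = env.reset()
ac = env.action_space.sample()
# get_reward adds the batch dimension itself for 1-D inputs; returns -log(1 - D(s, a) + 1e-8).
print(reward_giver.get_reward(np.asarray(ob), np.asarray(ac)))
```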

baselines/gail/behavior_clone.py (new file, 124 lines)
@@ -0,0 +1,124 @@
'''
This code trains the BC imitator; it is also used to pretrain the GAIL imitator.
'''

import argparse
import tempfile
import os.path as osp
import gym
import logging
from tqdm import tqdm

import tensorflow as tf

from baselines.gail import mlp_policy
from baselines import bench
from baselines import logger
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines.common.mpi_adam import MpiAdam
from baselines.gail.run_mujoco import runner
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset


def argsparser():
    parser = argparse.ArgumentParser("Tensorflow Implementation of Behavior Cloning")
    parser.add_argument('--env_id', help='environment ID', default='Hopper-v1')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz')
    parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint')
    parser.add_argument('--log_dir', help='the directory to save log file', default='log')
    # Mujoco Dataset Configuration
    parser.add_argument('--traj_limitation', type=int, default=-1)
    # Network Configuration (Using MLP Policy)
    parser.add_argument('--policy_hidden_size', type=int, default=100)
    # for evaluation
    boolean_flag(parser, 'stochastic_policy', default=False, help='use stochastic/deterministic policy to evaluate')
    boolean_flag(parser, 'save_sample', default=False, help='save the trajectories or not')
    parser.add_argument('--BC_max_iter', help='Max iteration for training BC', type=int, default=1e5)
    return parser.parse_args()


def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
          adam_epsilon=1e-5, optim_stepsize=3e-4,
          ckpt_dir=None, log_dir=None, task_name=None,
          verbose=False):

    val_per_iter = int(max_iters/10)
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)  # Construct network for new policy
    # placeholder
    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])
    stochastic = U.get_placeholder_cached(name="stochastic")
    loss = tf.reduce_mean(tf.square(ac-pi.ac))
    var_list = pi.get_trainable_variables()
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    lossandgrad = U.function([ob, ac, stochastic], [loss]+[U.flatgrad(loss, var_list)])

    U.initialize()
    adam.sync()
    logger.log("Pretraining with Behavior Cloning...")
    for iter_so_far in tqdm(range(int(max_iters))):
        ob_expert, ac_expert = dataset.get_next_batch(optim_batch_size, 'train')
        train_loss, g = lossandgrad(ob_expert, ac_expert, True)
        adam.update(g, optim_stepsize)
        if verbose and iter_so_far % val_per_iter == 0:
            ob_expert, ac_expert = dataset.get_next_batch(-1, 'val')
            val_loss, _ = lossandgrad(ob_expert, ac_expert, True)
            logger.log("Training loss: {}, Validation loss: {}".format(train_loss, val_loss))

    if ckpt_dir is None:
        savedir_fname = tempfile.TemporaryDirectory().name
    else:
        savedir_fname = osp.join(ckpt_dir, task_name)
    U.save_state(savedir_fname, var_list=pi.get_variables())
    return savedir_fname


def get_task_name(args):
    task_name = 'BC'
    task_name += '.{}'.format(args.env_id.split("-")[0])
    task_name += '.traj_limitation_{}'.format(args.traj_limitation)
    task_name += ".seed_{}".format(args.seed)
    return task_name


def main(args):
    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=args.policy_hidden_size, num_hid_layers=2)
    env = bench.Monitor(env, logger.get_dir() and
                        osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)
    dataset = Mujoco_Dset(expert_path=args.expert_path, traj_limitation=args.traj_limitation)
    savedir_fname = learn(env,
                          policy_fn,
                          dataset,
                          max_iters=args.BC_max_iter,
                          ckpt_dir=args.checkpoint_dir,
                          log_dir=args.log_dir,
                          task_name=task_name,
                          verbose=True)
    avg_len, avg_ret = runner(env,
                              policy_fn,
                              savedir_fname,
                              timesteps_per_batch=1024,
                              number_trajs=10,
                              stochastic_policy=args.stochastic_policy,
                              save=args.save_sample,
                              reuse=True)


if __name__ == '__main__':
    args = argsparser()
    main(args)
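The training loop above is plain supervised regression of expert actions on expert observations. A toy NumPy sketch of the same objective, with a linear policy and synthetic data that are purely illustrative:

```python
# Sketch: behavior cloning as least-squares regression on expert (obs, act) pairs.
# The linear policy and random data here are illustrative, not part of the diff.
import numpy as np

rng = np.random.RandomState(0)
obs = rng.randn(1000, 11)                              # fake expert observations
acts = obs[:, :3] * 0.5 + 0.1 * rng.randn(1000, 3)     # fake expert actions

W = np.zeros((11, 3))
lr = 1e-2
for _ in range(500):
    pred = obs @ W
    grad = obs.T @ (pred - acts) / len(obs)   # gradient of the MSE up to a constant factor
    W -= lr * grad

print("final MSE:", np.mean((obs @ W - acts) ** 2))
```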

baselines/gail/dataset/__init__.py (new, empty file)

baselines/gail/dataset/mujoco_dset.py (new file, 116 lines)
@@ -0,0 +1,116 @@
'''
Data structure of the input .npz:
the data is saved as a python dictionary with keys: 'acs', 'ep_rets', 'rews', 'obs'
the value of each item is a list storing the expert trajectories sequentially
a transition is (data['obs'][t], data['acs'][t], data['obs'][t+1]) with reward data['rews'][t]
'''

from baselines import logger
import numpy as np


class Dset(object):
    def __init__(self, inputs, labels, randomize):
        self.inputs = inputs
        self.labels = labels
        assert len(self.inputs) == len(self.labels)
        self.randomize = randomize
        self.num_pairs = len(inputs)
        self.init_pointer()

    def init_pointer(self):
        self.pointer = 0
        if self.randomize:
            idx = np.arange(self.num_pairs)
            np.random.shuffle(idx)
            self.inputs = self.inputs[idx, :]
            self.labels = self.labels[idx, :]

    def get_next_batch(self, batch_size):
        # if batch_size is negative -> return all
        if batch_size < 0:
            return self.inputs, self.labels
        if self.pointer + batch_size >= self.num_pairs:
            self.init_pointer()
        end = self.pointer + batch_size
        inputs = self.inputs[self.pointer:end, :]
        labels = self.labels[self.pointer:end, :]
        self.pointer = end
        return inputs, labels


class Mujoco_Dset(object):
    def __init__(self, expert_path, train_fraction=0.7, traj_limitation=-1, randomize=True):
        traj_data = np.load(expert_path)
        if traj_limitation < 0:
            traj_limitation = len(traj_data['obs'])
        obs = traj_data['obs'][:traj_limitation]
        acs = traj_data['acs'][:traj_limitation]

        def flatten(x):
            # x.shape = (E,), or (E, L, D)
            _, size = x[0].shape
            episode_length = [len(i) for i in x]
            y = np.zeros((sum(episode_length), size))
            start_idx = 0
            for l, x_i in zip(episode_length, x):
                y[start_idx:(start_idx+l)] = x_i
                start_idx += l
            return y
        self.obs = np.array(flatten(obs))
        self.acs = np.array(flatten(acs))
        self.rets = traj_data['ep_rets'][:traj_limitation]
        self.avg_ret = sum(self.rets)/len(self.rets)
        self.std_ret = np.std(np.array(self.rets))
        if len(self.acs) > 2:
            self.acs = np.squeeze(self.acs)
        assert len(self.obs) == len(self.acs)
        self.num_traj = min(traj_limitation, len(traj_data['obs']))
        self.num_transition = len(self.obs)
        self.randomize = randomize
        self.dset = Dset(self.obs, self.acs, self.randomize)
        # for behavior cloning
        self.train_set = Dset(self.obs[:int(self.num_transition*train_fraction), :],
                              self.acs[:int(self.num_transition*train_fraction), :],
                              self.randomize)
        self.val_set = Dset(self.obs[int(self.num_transition*train_fraction):, :],
                            self.acs[int(self.num_transition*train_fraction):, :],
                            self.randomize)
        self.log_info()

    def log_info(self):
        logger.log("Total trajectories: %d" % self.num_traj)
        logger.log("Total transitions: %d" % self.num_transition)
        logger.log("Average returns: %f" % self.avg_ret)
        logger.log("Std for returns: %f" % self.std_ret)

    def get_next_batch(self, batch_size, split=None):
        if split is None:
            return self.dset.get_next_batch(batch_size)
        elif split == 'train':
            return self.train_set.get_next_batch(batch_size)
        elif split == 'val':
            return self.val_set.get_next_batch(batch_size)
        else:
            raise NotImplementedError

    def plot(self):
        import matplotlib.pyplot as plt
        plt.hist(self.rets)
        plt.savefig("histogram_rets.png")
        plt.close()


def test(expert_path, traj_limitation, plot):
    dset = Mujoco_Dset(expert_path, traj_limitation=traj_limitation)
    if plot:
        dset.plot()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--expert_path", type=str, default="../data/deterministic.trpo.Hopper.0.00.npz")
    parser.add_argument("--traj_limitation", type=int, default=None)
    parser.add_argument("--plot", type=bool, default=False)
    args = parser.parse_args()
    test(args.expert_path, args.traj_limitation, args.plot)
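One way to see the documented `.npz` layout end to end is to write a tiny synthetic expert file and load it back; all shapes, values, and the file path below are made up for illustration:

```python
# Sketch: create a synthetic expert file with the documented keys and load it.
import numpy as np
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset

rng = np.random.RandomState(0)
n_traj, ep_len, ob_dim, ac_dim = 4, 50, 11, 3
obs = np.array([rng.randn(ep_len, ob_dim) for _ in range(n_traj)])
acs = np.array([rng.randn(ep_len, ac_dim) for _ in range(n_traj)])
rews = np.array([rng.rand(ep_len) for _ in range(n_traj)])
ep_rets = rews.sum(axis=1)

np.savez("/tmp/fake_expert.npz", obs=obs, acs=acs, rews=rews, ep_rets=ep_rets)

dset = Mujoco_Dset(expert_path="/tmp/fake_expert.npz", traj_limitation=2)
ob_batch, ac_batch = dset.get_next_batch(32, 'train')
print(ob_batch.shape, ac_batch.shape)   # (32, 11) (32, 3)
```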

baselines/gail/gail-eval.py (new file, 147 lines)
@@ -0,0 +1,147 @@
'''
This code is used to evaluate the imitators trained with different numbers of trajectories
and plot the results in the same figure for easy comparison.
'''

import argparse
import os
import glob
import gym

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from baselines.gail import run_mujoco
from baselines.gail import mlp_policy
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset


plt.style.use('ggplot')
CONFIG = {
    'traj_limitation': [1, 5, 10, 50],
}


def load_dataset(expert_path):
    dataset = Mujoco_Dset(expert_path=expert_path)
    return dataset


def argsparser():
    parser = argparse.ArgumentParser('Do evaluation')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--policy_hidden_size', type=int, default=100)
    parser.add_argument('--env', type=str, choices=['Hopper', 'Walker2d', 'HalfCheetah',
                                                    'Humanoid', 'HumanoidStandup'])
    boolean_flag(parser, 'stochastic_policy', default=False, help='use stochastic/deterministic policy to evaluate')
    return parser.parse_args()


def evaluate_env(env_name, seed, policy_hidden_size, stochastic, reuse, prefix):

    def get_checkpoint_dir(checkpoint_list, limit, prefix):
        for checkpoint in checkpoint_list:
            if ('limitation_'+str(limit) in checkpoint) and (prefix in checkpoint):
                return checkpoint
        return None

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=policy_hidden_size, num_hid_layers=2)

    data_path = os.path.join('data', 'deterministic.trpo.' + env_name + '.0.00.npz')
    dataset = load_dataset(data_path)
    checkpoint_list = glob.glob(os.path.join('checkpoint', '*' + env_name + ".*"))
    log = {
        'traj_limitation': [],
        'upper_bound': [],
        'avg_ret': [],
        'avg_len': [],
        'normalized_ret': []
    }
    for i, limit in enumerate(CONFIG['traj_limitation']):
        # Do one evaluation
        upper_bound = sum(dataset.rets[:limit])/limit
        checkpoint_dir = get_checkpoint_dir(checkpoint_list, limit, prefix=prefix)
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
        env = gym.make(env_name + '-v1')
        env.seed(seed)
        print('Trajectory limitation: {}, Load checkpoint: {}'.format(limit, checkpoint_path))
        avg_len, avg_ret = run_mujoco.runner(env,
                                             policy_fn,
                                             checkpoint_path,
                                             timesteps_per_batch=1024,
                                             number_trajs=10,
                                             stochastic_policy=stochastic,
                                             reuse=((i != 0) or reuse))
        normalized_ret = avg_ret/upper_bound
        print('Upper bound: {}, evaluation returns: {}, normalized scores: {}'.format(
            upper_bound, avg_ret, normalized_ret))
        log['traj_limitation'].append(limit)
        log['upper_bound'].append(upper_bound)
        log['avg_ret'].append(avg_ret)
        log['avg_len'].append(avg_len)
        log['normalized_ret'].append(normalized_ret)
        env.close()
    return log


def plot(env_name, bc_log, gail_log, stochastic):
    upper_bound = bc_log['upper_bound']
    bc_avg_ret = bc_log['avg_ret']
    gail_avg_ret = gail_log['avg_ret']
    plt.plot(CONFIG['traj_limitation'], upper_bound)
    plt.plot(CONFIG['traj_limitation'], bc_avg_ret)
    plt.plot(CONFIG['traj_limitation'], gail_avg_ret)
    plt.xlabel('Number of expert trajectories')
    plt.ylabel('Accumulated reward')
    plt.title('{} unnormalized scores'.format(env_name))
    plt.legend(['expert', 'bc-imitator', 'gail-imitator'], loc='lower right')
    plt.grid(b=True, which='major', color='gray', linestyle='--')
    if stochastic:
        title_name = 'result/{}-unnormalized-stochastic-scores.png'.format(env_name)
    else:
        title_name = 'result/{}-unnormalized-deterministic-scores.png'.format(env_name)
    plt.savefig(title_name)
    plt.close()

    bc_normalized_ret = bc_log['normalized_ret']
    gail_normalized_ret = gail_log['normalized_ret']
    plt.plot(CONFIG['traj_limitation'], np.ones(len(CONFIG['traj_limitation'])))
    plt.plot(CONFIG['traj_limitation'], bc_normalized_ret)
    plt.plot(CONFIG['traj_limitation'], gail_normalized_ret)
    plt.xlabel('Number of expert trajectories')
    plt.ylabel('Normalized performance')
    plt.title('{} normalized scores'.format(env_name))
    plt.legend(['expert', 'bc-imitator', 'gail-imitator'], loc='lower right')
    plt.grid(b=True, which='major', color='gray', linestyle='--')
    if stochastic:
        title_name = 'result/{}-normalized-stochastic-scores.png'.format(env_name)
    else:
        title_name = 'result/{}-normalized-deterministic-scores.png'.format(env_name)
    plt.ylim(0, 1.6)
    plt.savefig(title_name)
    plt.close()


def main(args):
    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)
    print('Evaluating {}'.format(args.env))
    bc_log = evaluate_env(args.env, args.seed, args.policy_hidden_size,
                          args.stochastic_policy, False, 'BC')
    print('Evaluation for {}'.format(args.env))
    print(bc_log)
    gail_log = evaluate_env(args.env, args.seed, args.policy_hidden_size,
                            args.stochastic_policy, True, 'gail')
    print('Evaluation for {}'.format(args.env))
    print(gail_log)
    plot(args.env, bc_log, gail_log, args.stochastic_policy)


if __name__ == '__main__':
    args = argsparser()
    main(args)

baselines/gail/mlp_policy.py (new file, 73 lines)
@@ -0,0 +1,73 @@
'''
Adapted from baselines/ppo1/mlp_policy.py with two simple modifications:
(1) add the reuse argument
(2) cache the `stochastic` placeholder
'''
import tensorflow as tf
import gym

import baselines.common.tf_util as U
from baselines.common.mpi_running_mean_std import RunningMeanStd
from baselines.common.distributions import make_pdtype


class MlpPolicy(object):
    recurrent = False

    def __init__(self, name, reuse=False, *args, **kwargs):
        with tf.variable_scope(name):
            if reuse:
                tf.get_variable_scope().reuse_variables()
            self._init(*args, **kwargs)
            self.scope = tf.get_variable_scope().name

    def _init(self, ob_space, ac_space, hid_size, num_hid_layers, gaussian_fixed_var=True):
        assert isinstance(ob_space, gym.spaces.Box)

        self.pdtype = pdtype = make_pdtype(ac_space)
        sequence_length = None

        ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[sequence_length] + list(ob_space.shape))

        with tf.variable_scope("obfilter"):
            self.ob_rms = RunningMeanStd(shape=ob_space.shape)

        obz = tf.clip_by_value((ob - self.ob_rms.mean) / self.ob_rms.std, -5.0, 5.0)
        last_out = obz
        for i in range(num_hid_layers):
            last_out = tf.nn.tanh(U.dense(last_out, hid_size, "vffc%i" % (i+1), weight_init=U.normc_initializer(1.0)))
        self.vpred = U.dense(last_out, 1, "vffinal", weight_init=U.normc_initializer(1.0))[:, 0]

        last_out = obz
        for i in range(num_hid_layers):
            last_out = tf.nn.tanh(U.dense(last_out, hid_size, "polfc%i" % (i+1), weight_init=U.normc_initializer(1.0)))
        if gaussian_fixed_var and isinstance(ac_space, gym.spaces.Box):
            mean = U.dense(last_out, pdtype.param_shape()[0]//2, "polfinal", U.normc_initializer(0.01))
            logstd = tf.get_variable(name="logstd", shape=[1, pdtype.param_shape()[0]//2], initializer=tf.zeros_initializer())
            pdparam = U.concatenate([mean, mean * 0.0 + logstd], axis=1)
        else:
            pdparam = U.dense(last_out, pdtype.param_shape()[0], "polfinal", U.normc_initializer(0.01))

        self.pd = pdtype.pdfromflat(pdparam)

        self.state_in = []
        self.state_out = []

        # change for BC
        stochastic = U.get_placeholder(name="stochastic", dtype=tf.bool, shape=())
        ac = U.switch(stochastic, self.pd.sample(), self.pd.mode())
        self.ac = ac
        self._act = U.function([stochastic, ob], [ac, self.vpred])

    def act(self, stochastic, ob):
        ac1, vpred1 = self._act(stochastic, ob[None])
        return ac1[0], vpred1[0]

    def get_variables(self):
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)

    def get_trainable_variables(self):
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    def get_initial_state(self):
        return []
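A rough sketch of how this policy class is typically constructed and queried (TensorFlow 1.x and gym with Hopper-v1 assumed; hyper-parameters are illustrative):

```python
# Sketch: build an MlpPolicy for an environment and sample one action.
import gym
from baselines.common import tf_util as U
from baselines.gail.mlp_policy import MlpPolicy

U.make_session(num_cpu=1).__enter__()
env = gym.make("Hopper-v1")
pi = MlpPolicy(name="pi", ob_space=env.observation_space, ac_space=env.action_space,
               hid_size=100, num_hid_layers=2)
U.initialize()

ob = env.reset()
ac, vpred = pi.act(stochastic=True, ob=ob)   # sampled action and value estimate
print(ac.shape, float(vpred))
```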

[Binary files added under baselines/gail/result/: the normalized and unnormalized score plots
(deterministic and stochastic) referenced by gail-result.md below.]

baselines/gail/result/gail-result.md (new file, 53 lines)
@@ -0,0 +1,53 @@
# Results of GAIL/BC on MuJoCo

Here are the extensive experimental results of applying GAIL/BC to MuJoCo environments, including
Hopper-v1, Walker2d-v1, HalfCheetah-v1, Humanoid-v1, and HumanoidStandup-v1. Every imitator is evaluated with seed 0.

## Results

### Training through iterations

- Hopper-v1
<img src='hopper-training.png'>

- HalfCheetah-v1
<img src='halfcheetah-training.png'>

- Walker2d-v1
<img src='walker2d-training.png'>

- Humanoid-v1
<img src='humanoid-training.png'>

- HumanoidStandup-v1
<img src='humanoidstandup-training.png'>

For details (e.g., adversarial loss, discriminator accuracy, etc.) about GAIL training, please see [here](https://drive.google.com/drive/folders/1nnU8dqAV9i37-_5_vWIspyFUJFQLCsDD?usp=sharing)

### Deterministic Policy (std = 0)

| | Un-normalized | Normalized |
|---|---|---|
| Hopper-v1 | <img src='Hopper-unnormalized-deterministic-scores.png'> | <img src='Hopper-normalized-deterministic-scores.png'> |
| HalfCheetah-v1 | <img src='HalfCheetah-unnormalized-deterministic-scores.png'> | <img src='HalfCheetah-normalized-deterministic-scores.png'> |
| Walker2d-v1 | <img src='Walker2d-unnormalized-deterministic-scores.png'> | <img src='Walker2d-normalized-deterministic-scores.png'> |
| Humanoid-v1 | <img src='Humanoid-unnormalized-deterministic-scores.png'> | <img src='Humanoid-normalized-deterministic-scores.png'> |
| HumanoidStandup-v1 | <img src='HumanoidStandup-unnormalized-deterministic-scores.png'> | <img src='HumanoidStandup-normalized-deterministic-scores.png'> |

### Stochastic Policy

| | Un-normalized | Normalized |
|---|---|---|
| Hopper-v1 | <img src='Hopper-unnormalized-stochastic-scores.png'> | <img src='Hopper-normalized-stochastic-scores.png'> |
| HalfCheetah-v1 | <img src='HalfCheetah-unnormalized-stochastic-scores.png'> | <img src='HalfCheetah-normalized-stochastic-scores.png'> |
| Walker2d-v1 | <img src='Walker2d-unnormalized-stochastic-scores.png'> | <img src='Walker2d-normalized-stochastic-scores.png'> |
| Humanoid-v1 | <img src='Humanoid-unnormalized-stochastic-scores.png'> | <img src='Humanoid-normalized-stochastic-scores.png'> |
| HumanoidStandup-v1 | <img src='HumanoidStandup-unnormalized-stochastic-scores.png'> | <img src='HumanoidStandup-normalized-stochastic-scores.png'> |

### Details about the GAIL imitators

For all environments, the imitator is trained with 1, 5, 10, and 50 trajectories, where each trajectory contains at most
1024 transitions, with seeds 0, 1, 2, and 3 respectively.

### Details about the BC imitators

All BC imitators are trained with seed 0.

[Binary files added under baselines/gail/result/: training curves halfcheetah-training.png,
hopper-training.png, humanoid-training.png, humanoidstandup-training.png, walker2d-training.png.]

baselines/gail/run_mujoco.py (new file, 239 lines)
@@ -0,0 +1,239 @@
'''
Disclaimer: this code is heavily based on trpo_mpi at @openai/baselines and @openai/imitation
'''

import argparse
import os.path as osp
import logging
from mpi4py import MPI
from tqdm import tqdm

import numpy as np
import gym

from baselines.gail import mlp_policy
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines import bench
from baselines import logger
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
from baselines.gail.adversary import TransitionClassifier


def argsparser():
    parser = argparse.ArgumentParser("Tensorflow Implementation of GAIL")
    parser.add_argument('--env_id', help='environment ID', default='Hopper-v1')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz')
    parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint')
    parser.add_argument('--log_dir', help='the directory to save log file', default='log')
    parser.add_argument('--load_model_path', help='if provided, load the model', type=str, default=None)
    # Task
    parser.add_argument('--task', type=str, choices=['train', 'evaluate', 'sample'], default='train')
    # for evaluation
    boolean_flag(parser, 'stochastic_policy', default=False, help='use stochastic/deterministic policy to evaluate')
    boolean_flag(parser, 'save_sample', default=False, help='save the trajectories or not')
    # Mujoco Dataset Configuration
    parser.add_argument('--traj_limitation', type=int, default=-1)
    # Optimization Configuration
    parser.add_argument('--g_step', help='number of steps to train policy in each epoch', type=int, default=3)
    parser.add_argument('--d_step', help='number of steps to train discriminator in each epoch', type=int, default=1)
    # Network Configuration (Using MLP Policy)
    parser.add_argument('--policy_hidden_size', type=int, default=100)
    parser.add_argument('--adversary_hidden_size', type=int, default=100)
    # Algorithms Configuration
    parser.add_argument('--algo', type=str, choices=['trpo', 'ppo'], default='trpo')
    parser.add_argument('--max_kl', type=float, default=0.01)
    parser.add_argument('--policy_entcoeff', help='entropy coefficient of policy', type=float, default=0)
    parser.add_argument('--adversary_entcoeff', help='entropy coefficient of discriminator', type=float, default=1e-3)
    # Training Configuration
    parser.add_argument('--save_per_iter', help='save model every xx iterations', type=int, default=100)
    parser.add_argument('--num_timesteps', help='total number of timesteps', type=int, default=5e6)
    # Behavior Cloning
    boolean_flag(parser, 'pretrained', default=False, help='Use BC to pretrain')
    parser.add_argument('--BC_max_iter', help='Max iteration for training BC', type=int, default=1e4)
    return parser.parse_args()


def get_task_name(args):
    task_name = args.algo + "_gail."
    if args.pretrained:
        task_name += "with_pretrained."
    if args.traj_limitation != np.inf:
        task_name += "transition_limitation_%d." % args.traj_limitation
    task_name += args.env_id.split("-")[0]
    task_name = task_name + ".g_step_" + str(args.g_step) + ".d_step_" + str(args.d_step) + \
        ".policy_entcoeff_" + str(args.policy_entcoeff) + ".adversary_entcoeff_" + str(args.adversary_entcoeff)
    task_name += ".seed_" + str(args.seed)
    return task_name


def main(args):
    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)
    env = gym.make(args.env_id)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=args.policy_hidden_size, num_hid_layers=2)
    env = bench.Monitor(env, logger.get_dir() and
                        osp.join(logger.get_dir(), "monitor.json"))
    env.seed(args.seed)
    gym.logger.setLevel(logging.WARN)
    task_name = get_task_name(args)
    args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
    args.log_dir = osp.join(args.log_dir, task_name)

    if args.task == 'train':
        dataset = Mujoco_Dset(expert_path=args.expert_path, traj_limitation=args.traj_limitation)
        reward_giver = TransitionClassifier(env, args.adversary_hidden_size, entcoeff=args.adversary_entcoeff)
        train(env,
              args.seed,
              policy_fn,
              reward_giver,
              dataset,
              args.algo,
              args.g_step,
              args.d_step,
              args.policy_entcoeff,
              args.num_timesteps,
              args.save_per_iter,
              args.checkpoint_dir,
              args.log_dir,
              args.pretrained,
              args.BC_max_iter,
              task_name
              )
    elif args.task == 'evaluate':
        runner(env,
               policy_fn,
               args.load_model_path,
               timesteps_per_batch=1024,
               number_trajs=10,
               stochastic_policy=args.stochastic_policy,
               save=args.save_sample
               )
    else:
        raise NotImplementedError
    env.close()


def train(env, seed, policy_fn, reward_giver, dataset, algo,
          g_step, d_step, policy_entcoeff, num_timesteps, save_per_iter,
          checkpoint_dir, log_dir, pretrained, BC_max_iter, task_name=None):

    pretrained_weight = None
    if pretrained and (BC_max_iter > 0):
        # Pretrain with behavior cloning
        from baselines.gail import behavior_clone
        pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
                                                 max_iters=BC_max_iter)

    if algo == 'trpo':
        from baselines.gail import trpo_mpi
        # Set up for MPI seed
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)
        trpo_mpi.learn(env, policy_fn, reward_giver, dataset, rank,
                       pretrained=pretrained, pretrained_weight=pretrained_weight,
                       g_step=g_step, d_step=d_step,
                       entcoeff=policy_entcoeff,
                       max_timesteps=num_timesteps,
                       ckpt_dir=checkpoint_dir, log_dir=log_dir,
                       save_per_iter=save_per_iter,
                       timesteps_per_batch=1024,
                       max_kl=0.01, cg_iters=10, cg_damping=0.1,
                       gamma=0.995, lam=0.97,
                       vf_iters=5, vf_stepsize=1e-3,
                       task_name=task_name)
    else:
        raise NotImplementedError


def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
           stochastic_policy, save=False, reuse=False):

    # Setup network
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=reuse)
    U.initialize()
    # Prepare for rollouts
    # ----------------------------------------
    U.load_state(load_model_path)

    obs_list = []
    acs_list = []
    len_list = []
    ret_list = []
    for _ in tqdm(range(number_trajs)):
        traj = traj_1_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy)
        obs, acs, ep_len, ep_ret = traj['ob'], traj['ac'], traj['ep_len'], traj['ep_ret']
        obs_list.append(obs)
        acs_list.append(acs)
        len_list.append(ep_len)
        ret_list.append(ep_ret)
    if stochastic_policy:
        print('stochastic policy:')
    else:
        print('deterministic policy:')
    if save:
        filename = load_model_path.split('/')[-1] + '.' + env.spec.id
        np.savez(filename, obs=np.array(obs_list), acs=np.array(acs_list),
                 lens=np.array(len_list), rets=np.array(ret_list))
    avg_len = sum(len_list)/len(len_list)
    avg_ret = sum(ret_list)/len(ret_list)
    print("Average length:", avg_len)
    print("Average return:", avg_ret)
    return avg_len, avg_ret


# Sample one trajectory (until trajectory end)
def traj_1_generator(pi, env, horizon, stochastic):

    t = 0
    ac = env.action_space.sample()  # not used, just so we have the datatype
    new = True  # marks if we're on first timestep of an episode

    ob = env.reset()
    cur_ep_ret = 0  # return in current episode
    cur_ep_len = 0  # len of current episode

    # Initialize history arrays
    obs = []
    rews = []
    news = []
    acs = []

    while True:
        ac, vpred = pi.act(stochastic, ob)
        obs.append(ob)
        news.append(new)
        acs.append(ac)

        ob, rew, new, _ = env.step(ac)
        rews.append(rew)

        cur_ep_ret += rew
        cur_ep_len += 1
        if new or t >= horizon:
            break
        t += 1

    obs = np.array(obs)
    rews = np.array(rews)
    news = np.array(news)
    acs = np.array(acs)
    traj = {"ob": obs, "rew": rews, "new": news, "ac": acs,
            "ep_ret": cur_ep_ret, "ep_len": cur_ep_len}
    return traj


if __name__ == '__main__':
    args = argsparser()
    main(args)
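During training (see trpo_mpi.py below), the rollout generator queries `reward_giver.get_reward(ob, ac)` at every step and optimizes that surrogate instead of the environment reward, which is only logged as the "true" reward. The same idea written as a gym wrapper, purely for illustration (the wrapper name is hypothetical and not part of the diff):

```python
# Sketch: replace the environment reward with the discriminator's surrogate reward,
# mirroring what traj_segment_generator in trpo_mpi.py does inline.
import gym
import numpy as np

class DiscriminatorRewardWrapper(gym.Wrapper):
    """Return -log(1 - D(s, a) + 1e-8) as the step reward; keep the true reward in info."""

    def __init__(self, env, reward_giver):
        super(DiscriminatorRewardWrapper, self).__init__(env)
        self.reward_giver = reward_giver
        self._last_ob = None

    def reset(self, **kwargs):
        self._last_ob = self.env.reset(**kwargs)
        return self._last_ob

    def step(self, action):
        surrogate = float(self.reward_giver.get_reward(np.asarray(self._last_ob), np.asarray(action)))
        ob, true_reward, done, info = self.env.step(action)
        info = dict(info, true_reward=true_reward)
        self._last_ob = ob
        return ob, surrogate, done, info
```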

baselines/gail/statistics.py (new file, 45 lines)
@@ -0,0 +1,45 @@
'''
This code is heavily based on https://github.com/carpedm20/deep-rl-tensorflow/blob/master/agents/statistic.py
'''

import tensorflow as tf
import numpy as np

import baselines.common.tf_util as U


class stats():

    def __init__(self, scalar_keys=[], histogram_keys=[]):
        self.scalar_keys = scalar_keys
        self.histogram_keys = histogram_keys
        self.scalar_summaries = []
        self.scalar_summaries_ph = []
        self.histogram_summaries_ph = []
        self.histogram_summaries = []
        with tf.variable_scope('summary'):
            for k in scalar_keys:
                ph = tf.placeholder('float32', None, name=k+'.scalar.summary')
                sm = tf.summary.scalar(k+'.scalar.summary', ph)
                self.scalar_summaries_ph.append(ph)
                self.scalar_summaries.append(sm)
            for k in histogram_keys:
                ph = tf.placeholder('float32', None, name=k+'.histogram.summary')
                sm = tf.summary.histogram(k+'.histogram.summary', ph)
                self.histogram_summaries_ph.append(ph)
                self.histogram_summaries.append(sm)

        self.summaries = tf.summary.merge(self.scalar_summaries+self.histogram_summaries)

    def add_all_summary(self, writer, values, iter):
        # Note that the order of the incoming `values` should be the same as that of the
        # `scalar_keys` given in `__init__`
        if np.sum(np.isnan(values)+0) != 0:
            return
        sess = U.get_session()
        keys = self.scalar_summaries_ph + self.histogram_summaries_ph
        feed_dict = {}
        for k, v in zip(keys, values):
            feed_dict.update({k: v})
        summaries_str = sess.run(self.summaries, feed_dict)
        writer.add_summary(summaries_str, iter)
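A minimal usage sketch of the helper above (TensorFlow 1.x assumed; the keys and log directory are illustrative):

```python
# Sketch: log a few scalar values to TensorBoard with the stats helper.
from baselines.common import tf_util as U
from baselines.gail.statistics import stats

U.make_session(num_cpu=1).__enter__()
writer = U.file_writer("./log/demo")          # directory is created if missing
ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])

for it in range(3):
    # values must be in the same order as the keys passed to stats(...)
    ep_stats.add_all_summary(writer, [100.0 + it, 10.0 + it, 200.0], it)
```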

baselines/gail/trpo_mpi.py (new file, 355 lines)
@@ -0,0 +1,355 @@
'''
Disclaimer: the TRPO part relies heavily on trpo_mpi at @openai/baselines
'''

import time
import os
from contextlib import contextmanager
from mpi4py import MPI
from collections import deque

import tensorflow as tf
import numpy as np

import baselines.common.tf_util as U
from baselines.common import explained_variance, zipsame, dataset, fmt_row
from baselines import logger
from baselines.common import colorize
from baselines.common.mpi_adam import MpiAdam
from baselines.common.cg import cg
from baselines.gail.statistics import stats


def traj_segment_generator(pi, env, reward_giver, horizon, stochastic):

    # Initialize state variables
    t = 0
    ac = env.action_space.sample()
    new = True
    rew = 0.0
    true_rew = 0.0
    ob = env.reset()

    cur_ep_ret = 0
    cur_ep_len = 0
    cur_ep_true_ret = 0
    ep_true_rets = []
    ep_rets = []
    ep_lens = []

    # Initialize history arrays
    obs = np.array([ob for _ in range(horizon)])
    true_rews = np.zeros(horizon, 'float32')
    rews = np.zeros(horizon, 'float32')
    vpreds = np.zeros(horizon, 'float32')
    news = np.zeros(horizon, 'int32')
    acs = np.array([ac for _ in range(horizon)])
    prevacs = acs.copy()

    while True:
        prevac = ac
        ac, vpred = pi.act(stochastic, ob)
        # Slight weirdness here because we need value function at time T
        # before returning segment [0, T-1] so we get the correct
        # terminal value
        if t > 0 and t % horizon == 0:
            yield {"ob": obs, "rew": rews, "vpred": vpreds, "new": news,
                   "ac": acs, "prevac": prevacs, "nextvpred": vpred * (1 - new),
                   "ep_rets": ep_rets, "ep_lens": ep_lens, "ep_true_rets": ep_true_rets}
            _, vpred = pi.act(stochastic, ob)
            # Be careful!!! if you change the downstream algorithm to aggregate
            # several of these batches, then be sure to do a deepcopy
            ep_rets = []
            ep_true_rets = []
            ep_lens = []
        i = t % horizon
        obs[i] = ob
        vpreds[i] = vpred
        news[i] = new
        acs[i] = ac
        prevacs[i] = prevac

        rew = reward_giver.get_reward(ob, ac)
        ob, true_rew, new, _ = env.step(ac)
        rews[i] = rew
        true_rews[i] = true_rew

        cur_ep_ret += rew
        cur_ep_true_ret += true_rew
        cur_ep_len += 1
        if new:
            ep_rets.append(cur_ep_ret)
            ep_true_rets.append(cur_ep_true_ret)
            ep_lens.append(cur_ep_len)
            cur_ep_ret = 0
            cur_ep_true_ret = 0
            cur_ep_len = 0
            ob = env.reset()
        t += 1


def add_vtarg_and_adv(seg, gamma, lam):
    new = np.append(seg["new"], 0)  # last element is only used for last vtarg, but we already zeroed it if last new = 1
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, 'float32')
    rew = seg["rew"]
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1-new[t+1]
        delta = rew[t] + gamma * vpred[t+1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]


def learn(env, policy_func, reward_giver, expert_dataset, rank,
          pretrained, pretrained_weight, *,
          g_step, d_step, entcoeff, save_per_iter,
          ckpt_dir, log_dir, timesteps_per_batch, task_name,
          gamma, lam,
          max_kl, cg_iters, cg_damping=1e-2,
          vf_stepsize=3e-4, d_stepsize=3e-4, vf_iters=3,
          max_timesteps=0, max_episodes=0, max_iters=0,
          callback=None
          ):

    nworkers = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)
    # Setup losses and stuff
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=(pretrained_weight is not None))
    oldpi = policy_func("oldpi", ob_space, ac_space)
    atarg = tf.placeholder(dtype=tf.float32, shape=[None])  # Target advantage function (if applicable)
    ret = tf.placeholder(dtype=tf.float32, shape=[None])  # Empirical return

    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])

    kloldnew = oldpi.pd.kl(pi.pd)
    ent = pi.pd.entropy()
    meankl = U.mean(kloldnew)
    meanent = U.mean(ent)
    entbonus = entcoeff * meanent

    vferr = U.mean(tf.square(pi.vpred - ret))

    ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac))  # advantage * pnew / pold
    surrgain = U.mean(ratio * atarg)

    optimgain = surrgain + entbonus
    losses = [optimgain, meankl, entbonus, surrgain, meanent]
    loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"]

    dist = meankl

    all_var_list = pi.get_trainable_variables()
    var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
    vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
    d_adam = MpiAdam(reward_giver.get_trainable_variables())
    vfadam = MpiAdam(vf_var_list)

    get_flat = U.GetFlat(var_list)
    set_from_flat = U.SetFromFlat(var_list)
    klgrads = tf.gradients(dist, var_list)
    flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
    shapes = [var.get_shape().as_list() for var in var_list]
    start = 0
    tangents = []
    for shape in shapes:
        sz = U.intprod(shape)
        tangents.append(tf.reshape(flat_tangent[start:start+sz], shape))
        start += sz
    gvp = tf.add_n([U.sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)])  # pylint: disable=E1111
    fvp = U.flatgrad(gvp, var_list)

    assign_old_eq_new = U.function([], [], updates=[tf.assign(oldv, newv)
                                                    for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
    compute_losses = U.function([ob, ac, atarg], losses)
    compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)])
    compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp)
    compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list))

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds" % (time.time() - tstart), color='magenta'))
        else:
            yield

    def allmean(x):
        assert isinstance(x, np.ndarray)
        out = np.empty_like(x)
        MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM)
        out /= nworkers
        return out

    writer = U.file_writer(log_dir)
    U.initialize()
    th_init = get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    set_from_flat(th_init)
    d_adam.sync()
    vfadam.sync()
    if rank == 0:
        print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = traj_segment_generator(pi, env, reward_giver, timesteps_per_batch, stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards
    true_rewbuffer = deque(maxlen=40)

    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    g_loss_stats = stats(loss_names)
    d_loss_stats = stats(reward_giver.loss_name)
    ep_stats = stats(["True_rewards", "Rewards", "Episode_length"])
    # if provided pretrained weight
    if pretrained_weight is not None:
        U.load_state(pretrained_weight, var_list=pi.get_variables())

    while True:
        if callback: callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            U.save_state(os.path.join(ckpt_dir, task_name))

        logger.log("********** Iteration %i ************" % iters_so_far)

        def fisher_vector_product(p):
            return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p
        # ------------------ Update G ------------------
        logger.log("Optimizing Policy...")
        for _ in range(g_step):
            with timed("sampling"):
                seg = seg_gen.__next__()
            add_vtarg_and_adv(seg, gamma, lam)
            # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage function estimate

            if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob)  # update running mean/std for policy

            args = seg["ob"], seg["ac"], atarg
            fvpargs = [arr[::5] for arr in args]

            assign_old_eq_new()  # set old parameter values to new parameter values
            with timed("computegrad"):
                *lossbefore, g = compute_lossandgrad(*args)
            lossbefore = allmean(np.array(lossbefore))
            g = allmean(g)
            if np.allclose(g, 0):
                logger.log("Got zero gradient. not updating")
            else:
                with timed("cg"):
                    stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank == 0)
                assert np.isfinite(stepdir).all()
                shs = .5*stepdir.dot(fisher_vector_product(stepdir))
                lm = np.sqrt(shs / max_kl)
                # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
                fullstep = stepdir / lm
                expectedimprove = g.dot(fullstep)
                surrbefore = lossbefore[0]
                stepsize = 1.0
                thbefore = get_flat()
                for _ in range(10):
                    thnew = thbefore + fullstep * stepsize
                    set_from_flat(thnew)
                    meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args)))
                    improve = surr - surrbefore
                    logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
                    if not np.isfinite(meanlosses).all():
                        logger.log("Got non-finite value of losses -- bad!")
                    elif kl > max_kl * 1.5:
                        logger.log("violated KL constraint. shrinking step.")
                    elif improve < 0:
                        logger.log("surrogate didn't improve. shrinking step.")
                    else:
                        logger.log("Stepsize OK!")
                        break
                    stepsize *= .5
                else:
                    logger.log("couldn't compute a good step")
                    set_from_flat(thbefore)
                if nworkers > 1 and iters_so_far % 20 == 0:
                    paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum()))  # list of tuples
                    assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
            with timed("vf"):
                for _ in range(vf_iters):
                    for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
                                                             include_final_partial_batch=False, batch_size=128):
                        if hasattr(pi, "ob_rms"):
                            pi.ob_rms.update(mbob)  # update running mean/std for policy
                        g = allmean(compute_vflossandgrad(mbob, mbret))
                        vfadam.update(g, vf_stepsize)

        g_losses = meanlosses
        for (lossname, lossval) in zip(loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)
        logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
        # ------------------ Update D ------------------
        logger.log("Optimizing Discriminator...")
        logger.log(fmt_row(13, reward_giver.loss_name))
        ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob))
        batch_size = len(ob) // d_step
        d_losses = []  # list of tuples, each of which gives the loss for a minibatch
        for ob_batch, ac_batch in dataset.iterbatches((ob, ac),
                                                      include_final_partial_batch=False,
                                                      batch_size=batch_size):
            ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch))
            # update running mean/std for reward_giver
            if hasattr(reward_giver, "obs_rms"): reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
            *newlosses, g = reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
            d_adam.update(allmean(g), d_stepsize)
            d_losses.append(newlosses)
        logger.log(fmt_row(13, np.mean(d_losses, axis=0)))

        lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs))
        true_rewbuffer.extend(true_rets)
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if rank == 0:
            logger.dump_tabular()
            g_loss_stats.add_all_summary(writer, g_losses, iters_so_far)
            d_loss_stats.add_all_summary(writer, np.mean(d_losses, axis=0), iters_so_far)
            ep_stats.add_all_summary(writer, [np.mean(true_rewbuffer), np.mean(rewbuffer),
                                              np.mean(lenbuffer)], iters_so_far)


def flatten_lists(listoflists):
    return [el for list_ in listoflists for el in list_]