add behavior cloning learn/eval code

andrew
2017-12-03 13:55:44 -08:00
parent 8495890534
commit 7954327c5f


@@ -1,34 +1,38 @@
import argparse
import tempfile
import os.path as osp
import gym
import logging
from tqdm import tqdm
import tensorflow as tf
import mlp_policy
from baselines import bench
from baselines import logger
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines.common.mpi_adam import MpiAdam
from run_mujoco import runner
from dataset.mujoco_dset import Mujoco_Dset

def evaluate(env, policy_func, load_model_path, stochastic_policy=False, number_trajs=10):
from algo.trpo_mpi import traj_episode_generator
ob_space = env.observation_space
ac_space = env.action_space
pi = policy_func("pi", ob_space, ac_space) # Construct network for new policy
    # generator that yields one complete episode per call (rollout chunks of 1024 steps)
    ep_gen = traj_episode_generator(pi, env, 1024, stochastic=stochastic_policy)
U.load_state(load_model_path)
len_list = []
ret_list = []
for _ in tqdm(range(number_trajs)):
        traj = next(ep_gen)
ep_len, ep_ret = traj['ep_len'], traj['ep_ret']
len_list.append(ep_len)
ret_list.append(ep_ret)
if stochastic_policy:
print('stochastic policy:')
else:
print('deterministic policy:')
print("Average length:", sum(len_list)/len(len_list))
print("Average return:", sum(ret_list)/len(ret_list))
def argsparser():
parser = argparse.ArgumentParser("Tensorflow Implementation of Behavior Cloning")
parser.add_argument('--env_id', help='environment ID', default='Hopper-v1')
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz')
parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint')
parser.add_argument('--log_dir', help='the directory to save log file', default='log')
# Mujoco Dataset Configuration
parser.add_argument('--traj_limitation', type=int, default=-1)
# Network Configuration (Using MLP Policy)
parser.add_argument('--policy_hidden_size', type=int, default=100)
    # for evaluation
boolean_flag(parser, 'stochastic_policy', default=False, help='use stochastic/deterministic policy to evaluate')
boolean_flag(parser, 'save_sample', default=False, help='save the trajectories or not')
    parser.add_argument('--BC_max_iter', help='Max iteration for training BC', type=int, default=int(1e5))
return parser.parse_args()

def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
@@ -54,14 +58,63 @@ def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
logger.log("Pretraining with Behavior Cloning...")
for iter_so_far in tqdm(range(int(max_iters))):
ob_expert, ac_expert = dataset.get_next_batch(optim_batch_size, 'train')
        train_loss, g = lossandgrad(ob_expert, ac_expert, True)
adam.update(g, optim_stepsize)
if verbose and iter_so_far % val_per_iter == 0:
ob_expert, ac_expert = dataset.get_next_batch(-1, 'val')
            val_loss, _ = lossandgrad(ob_expert, ac_expert, True)
            logger.log("Training loss: {}, Validation loss: {}".format(train_loss, val_loss))
    if ckpt_dir is None:
        # mkdtemp() persists, unlike TemporaryDirectory().name, whose directory is
        # deleted as soon as the TemporaryDirectory object is garbage-collected
        savedir_fname = osp.join(tempfile.mkdtemp(), 'model')
    else:
        savedir_fname = osp.join(ckpt_dir, task_name)
U.save_state(savedir_fname, var_list=pi.get_variables())
return savedir_fname
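The hunk above elides the head of learn(), where pi, lossandgrad, adam, and val_per_iter are constructed (the signature is also truncated at the hunk boundary). A sketch of that wiring under the usual baselines idioms, as an assumption since the actual construction sits outside this diff: behavior cloning reduces to supervised regression of the policy's action onto the expert's action.

    # Sketch (assumption): the elided head of learn().
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)                # policy network
    ob = U.get_placeholder_cached(name="ob")                  # observation placeholder
    ac = pi.pdtype.sample_placeholder([None])                 # expert action placeholder
    stochastic = U.get_placeholder_cached(name="stochastic")
    loss = tf.reduce_mean(tf.square(ac - pi.ac))              # MSE between policy and expert actions
    var_list = pi.get_trainable_variables()
    adam = MpiAdam(var_list, epsilon=1e-5)
    lossandgrad = U.function([ob, ac, stochastic], [loss] + [U.flatgrad(loss, var_list)])
    val_per_iter = int(max_iters / 10)                        # validate every 10% of training
    U.initialize()
    adam.sync()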
def get_task_name(args):
task_name = 'BC'
task_name += '.{}'.format(args.env_id.split("-")[0])
task_name += '.traj_limitation_{}'.format(args.traj_limitation)
task_name += ".seed_{}".format(args.seed)
return task_name
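With the parser defaults above (env_id='Hopper-v1', traj_limitation=-1, seed=0) this yields:

    get_task_name(argsparser())  # -> 'BC.Hopper.traj_limitation_-1.seed_0'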
def main(args):
U.make_session(num_cpu=1).__enter__()
set_global_seeds(args.seed)
env = gym.make(args.env_id)
def policy_fn(name, ob_space, ac_space, reuse=False):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
reuse=reuse, hid_size=args.policy_hidden_size, num_hid_layers=2)
env = bench.Monitor(env, logger.get_dir() and
osp.join(logger.get_dir(), "monitor.json"))
env.seed(args.seed)
gym.logger.setLevel(logging.WARN)
task_name = get_task_name(args)
args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
args.log_dir = osp.join(args.log_dir, task_name)
dataset = Mujoco_Dset(expert_path=args.expert_path, traj_limitation=args.traj_limitation)
savedir_fname = learn(env,
policy_fn,
dataset,
max_iters=args.BC_max_iter,
ckpt_dir=args.checkpoint_dir,
log_dir=args.log_dir,
task_name=task_name,
verbose=True)
avg_len, avg_ret = runner(env,
policy_fn,
savedir_fname,
timesteps_per_batch=1024,
number_trajs=10,
stochastic_policy=args.stochastic_policy,
save=args.save_sample,
reuse=True)

if __name__ == '__main__':
    args = argsparser()
    main(args)
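Assuming the module is saved as behavior_clone.py (the filename is not shown in this view) and the expert data exists at the default path, a typical training-plus-evaluation run is:

    python behavior_clone.py --env_id Hopper-v1 --seed 0 \
        --expert_path data/deterministic.trpo.Hopper.0.00.npz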