fix relative import

This commit is contained in:
andrew
2018-01-12 15:12:45 -08:00
parent f22bee085d
commit e5a714b070
5 changed files with 13 additions and 13 deletions

View File

@@ -41,4 +41,4 @@ Thanks to the open source:
- @openai/imitation
- @carpedm20/deep-rl-tensorflow
Also, thanks [Ryan Jilian](https://github.com/ryanjulian) for reviewing the code
Also, thanks [Ryan Julian](https://github.com/ryanjulian) for reviewing the code

View File

@@ -11,14 +11,14 @@ from tqdm import tqdm
import tensorflow as tf
import mlp_policy
from baselines.gail import mlp_policy
from baselines import bench
from baselines import logger
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines.common.mpi_adam import MpiAdam
from run_mujoco import runner
from dataset.mujoco_dset import Mujoco_Dset
from baselines.gail.run_mujoco import runner
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
def argsparser():

View File

@@ -12,11 +12,11 @@ import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import run_mujoco
import mlp_policy
from baselines.gail import run_mujoco
from baselines.gail import mlp_policy
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from dataset.mujoco_dset import Mujoco_Dset
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
plt.style.use('ggplot')

View File

@@ -11,13 +11,13 @@ from tqdm import tqdm
import numpy as np
import gym
import mlp_policy
from baselines.gail import mlp_policy
from baselines.common import set_global_seeds, tf_util as U
from baselines.common.misc_util import boolean_flag
from baselines import bench
from baselines import logger
from dataset.mujoco_dset import Mujoco_Dset
from adversary import TransitionClassifier
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
from baselines.gail.adversary import TransitionClassifier
def argsparser():
@@ -125,12 +125,12 @@ def train(env, seed, policy_fn, reward_giver, dataset, algo,
pretrained_weight = None
if pretrained and (BC_max_iter > 0):
# Pretrain with behavior cloning
import behavior_clone
from baselines.gail import behavior_clone
pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
max_iters=BC_max_iter)
if algo == 'trpo':
import trpo_mpi
from baselines.gail import trpo_mpi
# Set up for MPI seed
rank = MPI.COMM_WORLD.Get_rank()
if rank != 0:

View File

@@ -17,7 +17,7 @@ from baselines import logger
from baselines.common import colorize
from baselines.common.mpi_adam import MpiAdam
from baselines.common.cg import cg
from statistics import stats
from baselines.gail.statistics import stats
def traj_segment_generator(pi, env, reward_giver, horizon, stochastic):