From e5a714b070ac1ba40194f3d1ddfed29ebd0549ea Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 12 Jan 2018 15:12:45 -0800 Subject: [PATCH] fix relative import --- baselines/gail/README.md | 2 +- baselines/gail/behavior_clone.py | 6 +++--- baselines/gail/gail-eval.py | 6 +++--- baselines/gail/run_mujoco.py | 10 +++++----- baselines/gail/trpo_mpi.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/baselines/gail/README.md b/baselines/gail/README.md index 911e447..112cfec 100644 --- a/baselines/gail/README.md +++ b/baselines/gail/README.md @@ -41,4 +41,4 @@ Thanks to the open source: - @openai/imitation - @carpedm20/deep-rl-tensorflow -Also, thanks [Ryan Jilian](https://github.com/ryanjulian) for reviewing the code +Also, thanks [Ryan Julian](https://github.com/ryanjulian) for reviewing the code diff --git a/baselines/gail/behavior_clone.py b/baselines/gail/behavior_clone.py index 5164d85..82f65ec 100644 --- a/baselines/gail/behavior_clone.py +++ b/baselines/gail/behavior_clone.py @@ -11,14 +11,14 @@ from tqdm import tqdm import tensorflow as tf -import mlp_policy +from baselines.gail import mlp_policy from baselines import bench from baselines import logger from baselines.common import set_global_seeds, tf_util as U from baselines.common.misc_util import boolean_flag from baselines.common.mpi_adam import MpiAdam -from run_mujoco import runner -from dataset.mujoco_dset import Mujoco_Dset +from baselines.gail.run_mujoco import runner +from baselines.gail.dataset.mujoco_dset import Mujoco_Dset def argsparser(): diff --git a/baselines/gail/gail-eval.py b/baselines/gail/gail-eval.py index ff4487f..1148cb3 100644 --- a/baselines/gail/gail-eval.py +++ b/baselines/gail/gail-eval.py @@ -12,11 +12,11 @@ import matplotlib.pyplot as plt import numpy as np import tensorflow as tf -import run_mujoco -import mlp_policy +from baselines.gail import run_mujoco +from baselines.gail import mlp_policy from baselines.common import set_global_seeds, tf_util as U from baselines.common.misc_util import boolean_flag -from dataset.mujoco_dset import Mujoco_Dset +from baselines.gail.dataset.mujoco_dset import Mujoco_Dset plt.style.use('ggplot') diff --git a/baselines/gail/run_mujoco.py b/baselines/gail/run_mujoco.py index f3cc213..600f847 100644 --- a/baselines/gail/run_mujoco.py +++ b/baselines/gail/run_mujoco.py @@ -11,13 +11,13 @@ from tqdm import tqdm import numpy as np import gym -import mlp_policy +from baselines.gail import mlp_policy from baselines.common import set_global_seeds, tf_util as U from baselines.common.misc_util import boolean_flag from baselines import bench from baselines import logger -from dataset.mujoco_dset import Mujoco_Dset -from adversary import TransitionClassifier +from baselines.gail.dataset.mujoco_dset import Mujoco_Dset +from baselines.gail.adversary import TransitionClassifier def argsparser(): @@ -125,12 +125,12 @@ def train(env, seed, policy_fn, reward_giver, dataset, algo, pretrained_weight = None if pretrained and (BC_max_iter > 0): # Pretrain with behavior cloning - import behavior_clone + from baselines.gail import behavior_clone pretrained_weight = behavior_clone.learn(env, policy_fn, dataset, max_iters=BC_max_iter) if algo == 'trpo': - import trpo_mpi + from baselines.gail import trpo_mpi # Set up for MPI seed rank = MPI.COMM_WORLD.Get_rank() if rank != 0: diff --git a/baselines/gail/trpo_mpi.py b/baselines/gail/trpo_mpi.py index f40aae4..98a0caa 100644 --- a/baselines/gail/trpo_mpi.py +++ b/baselines/gail/trpo_mpi.py @@ -17,7 +17,7 @@ from baselines import logger from baselines.common import colorize from baselines.common.mpi_adam import MpiAdam from baselines.common.cg import cg -from statistics import stats +from baselines.gail.statistics import stats def traj_segment_generator(pi, env, reward_giver, horizon, stochastic):