fix relative import
This commit is contained in:
@@ -41,4 +41,4 @@ Thanks to the open source:
|
||||
- @openai/imitation
|
||||
- @carpedm20/deep-rl-tensorflow
|
||||
|
||||
Also, thanks [Ryan Jilian](https://github.com/ryanjulian) for reviewing the code
|
||||
Also, thanks [Ryan Julian](https://github.com/ryanjulian) for reviewing the code
|
||||
|
@@ -11,14 +11,14 @@ from tqdm import tqdm
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import mlp_policy
|
||||
from baselines.gail import mlp_policy
|
||||
from baselines import bench
|
||||
from baselines import logger
|
||||
from baselines.common import set_global_seeds, tf_util as U
|
||||
from baselines.common.misc_util import boolean_flag
|
||||
from baselines.common.mpi_adam import MpiAdam
|
||||
from run_mujoco import runner
|
||||
from dataset.mujoco_dset import Mujoco_Dset
|
||||
from baselines.gail.run_mujoco import runner
|
||||
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
|
||||
|
||||
|
||||
def argsparser():
|
||||
|
@@ -12,11 +12,11 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import run_mujoco
|
||||
import mlp_policy
|
||||
from baselines.gail import run_mujoco
|
||||
from baselines.gail import mlp_policy
|
||||
from baselines.common import set_global_seeds, tf_util as U
|
||||
from baselines.common.misc_util import boolean_flag
|
||||
from dataset.mujoco_dset import Mujoco_Dset
|
||||
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
|
||||
|
||||
|
||||
plt.style.use('ggplot')
|
||||
|
@@ -11,13 +11,13 @@ from tqdm import tqdm
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
import mlp_policy
|
||||
from baselines.gail import mlp_policy
|
||||
from baselines.common import set_global_seeds, tf_util as U
|
||||
from baselines.common.misc_util import boolean_flag
|
||||
from baselines import bench
|
||||
from baselines import logger
|
||||
from dataset.mujoco_dset import Mujoco_Dset
|
||||
from adversary import TransitionClassifier
|
||||
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
|
||||
from baselines.gail.adversary import TransitionClassifier
|
||||
|
||||
|
||||
def argsparser():
|
||||
@@ -125,12 +125,12 @@ def train(env, seed, policy_fn, reward_giver, dataset, algo,
|
||||
pretrained_weight = None
|
||||
if pretrained and (BC_max_iter > 0):
|
||||
# Pretrain with behavior cloning
|
||||
import behavior_clone
|
||||
from baselines.gail import behavior_clone
|
||||
pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
|
||||
max_iters=BC_max_iter)
|
||||
|
||||
if algo == 'trpo':
|
||||
import trpo_mpi
|
||||
from baselines.gail import trpo_mpi
|
||||
# Set up for MPI seed
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
if rank != 0:
|
||||
|
@@ -17,7 +17,7 @@ from baselines import logger
|
||||
from baselines.common import colorize
|
||||
from baselines.common.mpi_adam import MpiAdam
|
||||
from baselines.common.cg import cg
|
||||
from statistics import stats
|
||||
from baselines.gail.statistics import stats
|
||||
|
||||
|
||||
def traj_segment_generator(pi, env, reward_giver, horizon, stochastic):
|
||||
|
Reference in New Issue
Block a user