Files
baselines/baselines/trpo_mpi/run_mujoco.py
20chase 4a77855529 using mujoco_arg_parser as args
remove origin parser
2018-01-29 16:52:01 +08:00

37 lines
1.2 KiB
Python

#!/usr/bin/env python3
# noinspection PyUnresolvedReferences
from mpi4py import MPI
from baselines.common.cmd_util import make_mujoco_env, mujoco_arg_parser
from baselines import logger
from baselines.ppo1.mlp_policy import MlpPolicy
from baselines.trpo_mpi import trpo_mpi
def train(env_id, num_timesteps, seed):
    """Train a TRPO agent with a small MLP policy on one MuJoCo environment.

    Args:
        env_id: Environment id forwarded to ``make_mujoco_env``.
        num_timesteps: Forwarded to ``trpo_mpi.learn`` as ``max_timesteps``.
        seed: Base random seed; each MPI rank uses ``seed + 10000 * rank``.

    Only MPI rank 0 gets the default logger configuration; every other rank
    is configured with no output formats and logging disabled.
    """
    # Kept as a function-local import, as in the original code.
    import baselines.common.tf_util as U

    # Use the session as a context manager so it is exited even when
    # training fails, instead of a bare ``__enter__()`` with no ``__exit__``.
    with U.single_threaded_session():
        rank = MPI.COMM_WORLD.Get_rank()
        if rank == 0:
            logger.configure()
        else:
            # NOTE(review): indentation was lost in the reviewed source;
            # both calls are assumed to belong to the non-zero-rank branch.
            logger.configure(format_strs=[])
            logger.set_level(logger.DISABLED)

        # Offset the seed per rank so each worker gets a distinct seed.
        workerseed = seed + 10000 * rank

        def policy_fn(name, ob_space, ac_space):
            """Build the MLP policy: 2 hidden layers of 32 units each."""
            return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                             hid_size=32, num_hid_layers=2)

        env = make_mujoco_env(env_id, workerseed)
        try:
            trpo_mpi.learn(env, policy_fn, timesteps_per_batch=1024,
                           max_kl=0.01, cg_iters=10, cg_damping=0.1,
                           max_timesteps=num_timesteps, gamma=0.99, lam=0.98,
                           vf_iters=5, vf_stepsize=1e-3)
        finally:
            # Always release the environment, even if learning raised.
            env.close()
def main():
    """Parse the MuJoCo command-line arguments and start a training run."""
    cli = mujoco_arg_parser().parse_args()
    train(env_id=cli.env, num_timesteps=cli.num_timesteps, seed=cli.seed)
# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()