Files
baselines/baselines/acktr/run_mujoco.py
John Schulman bb40378118 change atari preprocessing to use faster opencv
some logger changes
2017-10-25 09:21:29 -04:00

44 lines
1.6 KiB
Python

#!/usr/bin/env python
import argparse
import logging
import os
import tensorflow as tf
import gym
from baselines import logger
from baselines.common import set_global_seeds
from baselines import bench
from baselines.acktr.acktr_cont import learn
from baselines.acktr.policies import GaussianMlpPolicy
from baselines.acktr.value_functions import NeuralNetValueFunction
def train(env_id, num_timesteps, seed, rank=0):
    """Train an ACKTR policy on a Gym environment with 1-D box spaces.

    Args:
        env_id: Environment ID handed to ``gym.make``.
        num_timesteps: Timestep budget forwarded to ``learn``.
        seed: Seed applied through ``set_global_seeds`` and ``env.seed``.
        rank: Index used to name the monitor output inside the logger
            directory. Defaults to 0 because this script starts a single
            process; previously the name was referenced without ever
            being defined.
    """
    env = gym.make(env_id)

    # When the logger has no directory, the falsy value is handed to
    # bench.Monitor unchanged instead of building a path from it.
    log_dir = logger.get_dir()
    monitor_path = log_dir and os.path.join(log_dir, str(rank))
    env = bench.Monitor(env, monitor_path)

    try:
        set_global_seeds(seed)
        env.seed(seed)
        gym.logger.setLevel(logging.WARN)

        with tf.Session(config=tf.ConfigProto()):
            # Only the first dimension of each space is used.
            ob_dim = env.observation_space.shape[0]
            ac_dim = env.action_space.shape[0]

            # Value function and policy are built under separate variable
            # scopes ("vf" and "pi").
            with tf.variable_scope("vf"):
                vf = NeuralNetValueFunction(ob_dim, ac_dim)
            with tf.variable_scope("pi"):
                policy = GaussianMlpPolicy(ob_dim, ac_dim)

            learn(env, policy=policy, vf=vf,
                  gamma=0.99, lam=0.97, timesteps_per_batch=2500,
                  desired_kl=0.002,
                  num_timesteps=num_timesteps, animate=False)
    finally:
        # Release the environment even if seeding or training raises.
        env.close()
if __name__ == "__main__":
    # Command-line entry point: parse the options, configure the baselines
    # logger, then start training with the parsed values.
    cli = argparse.ArgumentParser(description='Run Mujoco benchmark.')
    cli.add_argument('--seed', type=int, default=0, help='RNG seed')
    cli.add_argument('--env', type=str, default="Reacher-v1",
                     help='environment ID')
    cli.add_argument('--num-timesteps', type=int, default=int(1e6))
    options = cli.parse_args()

    logger.configure()
    train(
        options.env,
        num_timesteps=options.num_timesteps,
        seed=options.seed,
    )