Files
baselines/baselines/acktr/run_atari.py
Ryan Julian df889caf11 Force shebang lines to Python 3
This is a Python 3-only library. A shebang with `#!/usr/bin/env python`
will launch python2 on many systems which do not have python3
installed. Setting the shebang to `#!/usr/bin/env python3` will show a
useful error on systems without Python 3.
2017-11-05 15:22:16 -08:00

39 lines
1.5 KiB
Python

#!/usr/bin/env python3
import os, logging, gym
from baselines import logger
from baselines.common import set_global_seeds
from baselines import bench
from baselines.acktr.acktr_disc import learn
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.acktr.policies import CnnPolicy
def train(env_id, num_timesteps, seed, num_cpu):
    """Train an ACKTR agent with a CNN policy on an Atari environment.

    Builds ``num_cpu`` environment factories, hands them to ``SubprocVecEnv``,
    and runs ``learn`` on the resulting vectorized environment. The
    environment is closed when ``learn`` returns or raises.

    Args:
        env_id: Environment ID passed to ``make_atari``.
        num_timesteps: Nominal timestep budget; ``learn`` receives
            ``int(num_timesteps * 1.1)`` as ``total_timesteps``.
        seed: Base seed, passed to ``set_global_seeds`` and ``learn``. The
            environment with index ``rank`` is seeded with ``seed + rank``.
        num_cpu: Number of environments to create; also forwarded to
            ``learn`` as ``nprocs``.
    """
    def make_env(rank):
        """Return a zero-argument factory for the environment at ``rank``."""
        def _thunk():
            env = make_atari(env_id)
            # Offset the seed by rank so each environment gets a distinct seed.
            env.seed(seed + rank)
            # Monitor output goes to <log dir>/<rank>; when no log dir is
            # configured, the falsy value is passed through unchanged.
            log_dir = logger.get_dir()
            monitor_path = os.path.join(log_dir, str(rank)) if log_dir else log_dir
            env = bench.Monitor(env, monitor_path)
            gym.logger.setLevel(logging.WARNING)
            return wrap_deepmind(env)
        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    policy_fn = CnnPolicy
    try:
        # NOTE(review): the budget is scaled by 1.1 before being passed to
        # learn; the rationale is not documented in this file.
        learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu)
    finally:
        # Always shut the vectorized environment down, even if learn raises.
        env.close()
def main():
    """Parse command-line arguments, configure the logger, and run training.

    Flags:
        --env: Environment ID (default ``BreakoutNoFrameskip-v4``).
        --seed: RNG seed (default ``0``).
        --num-timesteps: Timestep budget forwarded to ``train``
            (default ``int(10e6)``).
        --num-cpu: Number of environments forwarded to ``train`` as
            ``num_cpu`` (default ``32``, the previously hard-coded value).
    """
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--num-timesteps', help='number of timesteps to train for',
                        type=int, default=int(10e6))
    # Previously fixed at 32; exposed so the worker count can be matched to
    # the machine without editing the script.
    parser.add_argument('--num-cpu', help='number of parallel environments',
                        type=int, default=32)
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=args.num_cpu)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()