This is a Python 3-only library. On many systems, a shebang of `#!/usr/bin/env python` launches Python 2, so on a system without Python 3 installed the script fails with a confusing error. Setting the shebang to `#!/usr/bin/env python3` instead produces a clear, useful error on systems without Python 3.
39 lines
1.5 KiB
Python
39 lines
1.5 KiB
Python
#!/usr/bin/env python3
|
|
import os, logging, gym
|
|
from baselines import logger
|
|
from baselines.common import set_global_seeds
|
|
from baselines import bench
|
|
from baselines.acktr.acktr_disc import learn
|
|
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
|
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
|
from baselines.acktr.policies import CnnPolicy
|
|
|
|
def train(env_id, num_timesteps, seed, num_cpu):
    """Train an ACKTR agent with a CNN policy on an Atari environment.

    Args:
        env_id: Gym environment ID passed to ``make_atari``
            (e.g. ``'BreakoutNoFrameskip-v4'``).
        num_timesteps: Training budget. ``learn`` receives
            ``int(num_timesteps * 1.1)`` as ``total_timesteps``.
        seed: Base RNG seed. It is passed to ``set_global_seeds`` and ``learn``,
            and the environment with index ``rank`` is seeded with ``seed + rank``.
        num_cpu: Number of environment copies handed to ``SubprocVecEnv``;
            also passed to ``learn`` as ``nprocs``.
    """
    def make_env(rank):
        # Return a zero-argument factory; SubprocVecEnv receives a list of
        # these and builds each environment from it.
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            log_dir = logger.get_dir()
            # If the logger has no output directory, the falsy value is passed
            # through unchanged instead of a per-rank path.
            env = bench.Monitor(env, log_dir and os.path.join(log_dir, str(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)
        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        # NOTE(review): learn() is asked for 10% more timesteps than requested;
        # presumably headroom so the full budget is reached -- confirm.
        learn(CnnPolicy, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu)
    finally:
        # Close the vectorized env even if learn() raises, so its resources
        # (presumably the worker subprocesses) are released.
        env.close()
|
|
|
|
def main():
    """Parse command-line arguments, configure the logger, and run ``train``.

    Command-line flags:
        --env: Gym environment ID (default ``BreakoutNoFrameskip-v4``).
        --seed: RNG seed (default 0).
        --num-timesteps: Training budget in timesteps (default 10,000,000).
        --num-cpu: Number of parallel environments (default 32, the value
            previously hard-coded here).
    """
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--num-timesteps', help='number of training timesteps', type=int, default=int(10e6))
    parser.add_argument('--num-cpu', help='number of parallel environments', type=int, default=32)
    args = parser.parse_args()
    if args.num_cpu < 1:
        # Reject before any logging/training setup; parser.error exits with usage.
        parser.error('--num-cpu must be at least 1')
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=args.num_cpu)
|
|
|
|
|
|
# Entry point when executed as a script (not when imported as a module).
if __name__ == '__main__':
    main()
|