#!/usr/bin/env python3
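"""Run PPO2 (OpenAI Baselines) on an Atari environment.

The environment, seed, and number of timesteps come from the flags defined by
atari_arg_parser; --policy selects the network architecture. Example invocation
(the module path is an assumption about where this script lives in the
baselines package):

    python -m baselines.ppo2.run_atari --policy cnn
"""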
import sys
from baselines import logger
from baselines.common.cmd_util import make_atari_env, atari_arg_parser
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.ppo2 import ppo2
from baselines.ppo2.policies import CnnPolicy, LstmPolicy, LnLstmPolicy, MlpPolicy
import multiprocessing
import tensorflow as tf


def train(env_id, num_timesteps, seed, policy):
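    """Build a TF session and a vectorized Atari env, then train PPO2 on it."""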
    # Use all available CPU cores for TensorFlow's thread pools (halved on macOS).
    ncpu = multiprocessing.cpu_count()
    if sys.platform == 'darwin':
        ncpu //= 2
    config = tf.ConfigProto(allow_soft_placement=True,
                            intra_op_parallelism_threads=ncpu,
                            inter_op_parallelism_threads=ncpu)
    config.gpu_options.allow_growth = True  # pylint: disable=E1101
    # Enter the session so it becomes the default session used by ppo2.learn.
    tf.Session(config=config).__enter__()
    # 8 parallel Atari environments, each observation stacked over the last 4 frames.
    env = VecFrameStack(make_atari_env(env_id, 8, seed), 4)
    policy = {'cnn': CnnPolicy, 'lstm': LstmPolicy, 'lnlstm': LnLstmPolicy, 'mlp': MlpPolicy}[policy]
    # The learning rate and clip range are annealed linearly: `f` is the fraction
    # of training remaining, so both decay from their initial values toward zero.
    ppo2.learn(policy=policy, env=env, nsteps=128, nminibatches=4,
               lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
               ent_coef=.01,
               lr=lambda f: f * 2.5e-4,
               cliprange=lambda f: f * 0.1,
               total_timesteps=int(num_timesteps * 1.1))


def main():
    parser = atari_arg_parser()
    parser.add_argument('--policy', help='Policy architecture',
                        choices=['cnn', 'lstm', 'lnlstm', 'mlp'], default='cnn')
    args = parser.parse_args()
    logger.configure()
    train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,
          policy=args.policy)


if __name__ == '__main__':
    main()