Files
baselines/baselines/deepq/experiments/train_pong.py
2017-05-25 14:53:29 -07:00

35 lines
800 B
Python

import gym
from baselines import deepq
from baselines.common.atari_wrappers_deprecated import wrap_dqn, ScaledFloatFrame
def main():
    """Train a dueling DQN agent on Pong and save it to ``pong_model.pkl``.

    Creates the ``PongNoFrameskip-v4`` environment, wraps it with
    ``wrap_dqn`` and ``ScaledFloatFrame``, trains with ``deepq.learn``
    (prioritized replay enabled) and saves the returned ``act`` object.

    The environment is closed even when training or saving raises (for
    example on Ctrl-C during a long run), so it is not leaked.
    """
    env = gym.make("PongNoFrameskip-v4")
    try:
        env = ScaledFloatFrame(wrap_dqn(env))
        # NOTE(review): each conv tuple is presumably
        # (num_filters, kernel_size, stride) -- confirm against cnn_to_mlp.
        model = deepq.models.cnn_to_mlp(
            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            hiddens=[256],
            dueling=True,
        )
        act = deepq.learn(
            env,
            q_func=model,
            lr=1e-4,
            max_timesteps=2000000,
            buffer_size=10000,
            exploration_fraction=0.1,
            exploration_final_eps=0.01,
            train_freq=4,
            learning_starts=10000,
            target_network_update_freq=1000,
            gamma=0.99,
            prioritized_replay=True,
        )
        act.save("pong_model.pkl")
    finally:
        # Release the environment on both the success and the failure path.
        env.close()
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()