Files
baselines/baselines/deepq/experiments/train_cartpole.py
2018-08-16 13:15:51 -07:00

31 lines
646 B
Python

import gym
from baselines import deepq
def callback(lcl, _glb, reward_threshold=199, window=100):
    """Decide whether training should stop early.

    Training counts as solved once the mean reward over the last ``window``
    completed episodes is at least ``reward_threshold``.

    Args:
        lcl: mapping of the training loop's local variables (presumably
            ``locals()`` as passed by ``deepq.learn`` -- confirm); must
            contain ``'episode_rewards'``, a sequence of per-episode reward
            totals.
        _glb: the training loop's globals; unused.
        reward_threshold: minimum mean episode reward that counts as solved.
        window: number of most recent completed episodes to average over.

    Returns:
        True to stop training, False to keep going.
    """
    episode_rewards = lcl['episode_rewards']
    # The last entry is skipped because it is presumably the reward of the
    # episode still in progress. Require ``window`` completed episodes
    # (hence ``window + 1`` entries) before averaging. A timestep count says
    # nothing about how many episodes are available. Dividing a shorter slice
    # by ``window`` would misreport the mean.
    if len(episode_rewards) <= window:
        return False
    recent = episode_rewards[-(window + 1):-1]
    return sum(recent) / window >= reward_threshold
def main():
    """Train a deepq agent on CartPole-v0 and save it to cartpole_model.pkl."""
    # Training configuration forwarded verbatim to deepq.learn; `callback`
    # lets training stop early once the task is considered solved.
    hyperparams = dict(
        network='mlp',
        lr=1e-3,
        total_timesteps=100000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        print_freq=10,
        callback=callback,
    )
    cartpole = gym.make("CartPole-v0")
    model = deepq.learn(cartpole, **hyperparams)

    print("Saving model to cartpole_model.pkl")
    model.save("cartpole_model.pkl")
if __name__ == "__main__":
    # Train only when run as a script, not when this module is imported.
    main()