* update per-algorithm READMEs to reflect new way of running algorithms
* adding a link to repo-wide README
* updated README files and deepq.train_cartpole example
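The "new way of running algorithms" mentioned above refers to the repo-wide entry point described in the top-level README. As a rough, hedged sketch (the `baselines.run` module name and the `--alg`/`--env`/`--num_timesteps` flags are assumptions inferred from that README, not confirmed by this change), training the same agent from the command line would look something like `python -m baselines.run --alg=deepq --env=CartPole-v0 --num_timesteps=1e5`. The updated deepq.train_cartpole example itself follows.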
import gym

from baselines import deepq


def callback(lcl, _glb):
    # stop training once the mean reward over the last 100 episodes reaches 199
    is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
    return is_solved


def main():
    env = gym.make("CartPole-v0")
    # train a DQN agent with a multilayer-perceptron Q-network
    act = deepq.learn(
        env,
        network='mlp',
        lr=1e-3,
        total_timesteps=100000,
        buffer_size=50000,
        exploration_fraction=0.1,
        exploration_final_eps=0.02,
        print_freq=10,
        callback=callback
    )
    print("Saving model to cartpole_model.pkl")
    act.save("cartpole_model.pkl")


if __name__ == '__main__':
    main()
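For completeness, here is a minimal sketch of loading the saved model back and running it in the environment, in the style of the companion enjoy_cartpole example. It assumes that calling deepq.learn with total_timesteps=0 and a load_path argument restores the parameters written by act.save, and that the returned act callable maps a batch of observations to actions; check the deepq README for the exact API.

import gym

from baselines import deepq


def enjoy():
    env = gym.make("CartPole-v0")
    # assumption: total_timesteps=0 skips training and load_path restores the saved parameters
    act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl")

    while True:
        obs, done = env.reset(), False
        episode_rew = 0
        while not done:
            env.render()
            # act expects a batch of observations, hence obs[None]
            obs, rew, done, _ = env.step(act(obs[None])[0])
            episode_rew += rew
        print("Episode reward", episode_rew)


if __name__ == '__main__':
    enjoy()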