2017-05-17 14:41:46 -07:00
|
|
|
import gym
|
|
|
|
from baselines import deepq
|
|
|
|
|
|
|
|
|
|
|
|
def main(load_path=None):
    """Play PongNoFrameskip-v4 forever with a DQN policy, rendering each frame.

    The Q-network is a dueling ``"conv_only"`` model with three conv layers
    ``[(32, 8, 4), (64, 4, 2), (64, 3, 1)]`` and one 256-unit hidden layer.
    It is built by ``deepq.learn`` with ``total_timesteps=0``, which presumably
    runs no training steps. The returned model is then only used to pick
    actions. After every episode the summed reward is printed.

    Args:
        load_path: Optional path to previously saved model parameters.
            Forwarded to ``deepq.learn`` as ``load_path`` only when not
            ``None``. With the default ``None`` nothing is loaded, so the
            policy is whatever ``deepq.learn`` produces without training
            (presumably randomly initialised — NOTE(review): confirm).

    The function never returns normally. It loops until interrupted, and the
    environment is closed on the way out.
    """
    env = gym.make("PongNoFrameskip-v4")
    env = deepq.wrap_atari_dqn(env)

    learn_kwargs = {}
    if load_path is not None:
        # Only forward when supplied, so the default call matches the old behavior.
        learn_kwargs["load_path"] = load_path

    model = deepq.learn(
        env,
        "conv_only",
        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
        hiddens=[256],
        dueling=True,
        total_timesteps=0,  # build (and optionally load) the model; no training loop
        **learn_kwargs
    )

    try:
        while True:
            obs, done = env.reset(), False
            episode_rew = 0
            while not done:
                env.render()
                # obs[None] presumably adds a leading batch axis; [0] takes the
                # action for that single observation from the model's output.
                obs, rew, done, _ = env.step(model(obs[None])[0])
                episode_rew += rew
            print("Episode reward", episode_rew)
    finally:
        # The loop has no break, so we only get here via an exception (e.g.
        # Ctrl-C). Release the environment and its render window.
        env.close()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the Pong playback loop only when executed as a script, not on import.
    main()
|