diff --git a/baselines/deepq/simple.py b/baselines/deepq/simple.py index 533070c..e349f41 100644 --- a/baselines/deepq/simple.py +++ b/baselines/deepq/simple.py @@ -89,6 +89,7 @@ def learn(env, gamma=1.0, target_network_update_freq=500, prioritized_replay=False, + prioritized_importance_sampling=False, prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None, @@ -232,7 +233,10 @@ def learn(env, else: obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size) weights, batch_idxes = np.ones_like(rewards), None - td_errors = train(obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards)) + if prioritized_importance_sampling: + td_errors = train(obses_t, actions, rewards, obses_tp1, dones, weights) + else: + td_errors = train(obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards)) if prioritized_replay: new_priorities = np.abs(td_errors) + prioritized_replay_eps replay_buffer.update_priorities(batch_idxes, new_priorities)