Merge pull request #52 from farbeiza/patch-1
Effectively apply weights from the replay buffer
@@ -89,6 +89,7 @@ def learn(env,
           gamma=1.0,
           target_network_update_freq=500,
           prioritized_replay=False,
+          prioritized_importance_sampling=False,
           prioritized_replay_alpha=0.6,
           prioritized_replay_beta0=0.4,
           prioritized_replay_beta_iters=None,
@@ -232,7 +233,10 @@ def learn(env,
             else:
                 obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
                 weights, batch_idxes = np.ones_like(rewards), None
-            td_errors = train(obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards))
+            if prioritized_importance_sampling:
+                td_errors = train(obses_t, actions, rewards, obses_tp1, dones, weights)
+            else:
+                td_errors = train(obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards))
             if prioritized_replay:
                 new_priorities = np.abs(td_errors) + prioritized_replay_eps
                 replay_buffer.update_priorities(batch_idxes, new_priorities)
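For context, the change routes the importance-sampling weights returned by the prioritized replay buffer into the training step instead of always passing all-ones weights. Below is a minimal NumPy sketch of the effect; the weighted_td_loss helper and the sample numbers are hypothetical illustrations (in the actual code the train function applies the weights inside the TensorFlow graph), shown only to make the weighting visible:

import numpy as np

def weighted_td_loss(td_errors, weights):
    # Hypothetical helper: importance-sampling weights scale each transition's
    # Huber-smoothed TD error before averaging over the batch.
    abs_err = np.abs(td_errors)
    huber = np.where(abs_err <= 1.0, 0.5 * np.square(td_errors), abs_err - 0.5)
    return np.mean(weights * huber)

# With prioritized_importance_sampling=True, the weights come from the
# prioritized replay buffer's sample(); otherwise np.ones_like(rewards)
# leaves the loss unweighted, matching the previous behaviour.
td_errors = np.array([0.2, -1.5, 0.7])
weights = np.array([0.9, 0.4, 1.0])
print(weighted_td_loss(td_errors, weights))                  # weighted loss
print(weighted_td_loss(td_errors, np.ones_like(td_errors)))  # unweighted baseline

With the new flag, a caller would set both prioritized_replay=True and prioritized_importance_sampling=True to get the bias correction; leaving the flag at its default of False keeps the old behaviour of training with all-ones weights.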