Final finishing touches

This commit is contained in:
Matthias Plappert
2017-07-26 19:13:45 -07:00
parent 6964729ed0
commit 45a1297ec0
2 changed files with 18 additions and 10 deletions

View File

@@ -43,6 +43,7 @@ def parse_args():
parser.add_argument("--target-update-freq", type=int, default=40000, help="number of iterations between every target network update")
parser.add_argument("--param-noise-update-freq", type=int, default=50, help="number of iterations between every re-scaling of the parameter noise")
parser.add_argument("--param-noise-reset-freq", type=int, default=10000, help="maximum number of steps to take per episode before re-perturbing the exploration policy")
parser.add_argument("--param-noise-threshold", type=float, default=0.05, help="the desired KL divergence between perturbed and non-perturbed policy. set to < 0 to use a KL divergence relative to the eps-greedy exploration")
# Bells and whistles
boolean_flag(parser, "double-q", default=True, help="whether or not to use double q learning")
boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
@@ -201,11 +202,14 @@ if __name__ == '__main__':
reset = True
update_eps = 0.01 # ensures that we cannot get stuck completely
# Compute the threshold such that the KL divergence between perturbed and non-perturbed
# policy is comparable to eps-greedy exploration with eps = exploration.value(t).
# See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
# for detailed explanation.
update_param_noise_threshold = -np.log(1. - exploration.value(num_iters) + exploration.value(num_iters) / float(env.action_space.n))
if args.param_noise_threshold >= 0.:
update_param_noise_threshold = args.param_noise_threshold
else:
# Compute the threshold such that the KL divergence between perturbed and non-perturbed
# policy is comparable to eps-greedy exploration with eps = exploration.value(t).
# See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
# for detailed explanation.
update_param_noise_threshold = -np.log(1. - exploration.value(num_iters) + exploration.value(num_iters) / float(env.action_space.n))
kwargs['reset'] = reset
kwargs['update_param_noise_threshold'] = update_param_noise_threshold
kwargs['update_param_noise_scale'] = (num_iters % args.param_noise_update_freq == 0)

View File

@@ -95,6 +95,7 @@ def learn(env,
prioritized_replay_eps=1e-6,
num_cpu=16,
param_noise=False,
param_noise_threshold=0.05,
callback=None):
"""Train a deepq model.
@@ -224,11 +225,14 @@ def learn(env,
update_param_noise_threshold = 0.
else:
update_eps = 0.
# Compute the threshold such that the KL divergence between perturbed and non-perturbed
# policy is comparable to eps-greedy exploration with eps = exploration.value(t).
# See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
# for detailed explanation.
update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(env.action_space.n))
if param_noise_threshold >= 0.:
update_param_noise_threshold = param_noise_threshold
else:
# Compute the threshold such that the KL divergence between perturbed and non-perturbed
# policy is comparable to eps-greedy exploration with eps = exploration.value(t).
# See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
# for detailed explanation.
update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(env.action_space.n))
kwargs['reset'] = reset
kwargs['update_param_noise_threshold'] = update_param_noise_threshold
kwargs['update_param_noise_scale'] = True