diff --git a/baselines/deepq/experiments/atari/train.py b/baselines/deepq/experiments/atari/train.py
index 223a6dc..c052ef0 100644
--- a/baselines/deepq/experiments/atari/train.py
+++ b/baselines/deepq/experiments/atari/train.py
@@ -43,6 +43,7 @@ def parse_args():
     parser.add_argument("--target-update-freq", type=int, default=40000, help="number of iterations between every target network update")
     parser.add_argument("--param-noise-update-freq", type=int, default=50, help="number of iterations between every re-scaling of the parameter noise")
     parser.add_argument("--param-noise-reset-freq", type=int, default=10000, help="maximum number of steps to take per episode before re-perturbing the exploration policy")
+    parser.add_argument("--param-noise-threshold", type=float, default=0.05, help="the desired KL divergence between perturbed and non-perturbed policy. set to < 0 to use a KL divergence relative to the eps-greedy exploration")
     # Bells and whistles
     boolean_flag(parser, "double-q", default=True, help="whether or not to use double q learning")
     boolean_flag(parser, "dueling", default=False, help="whether or not to use dueling model")
@@ -201,11 +202,14 @@ if __name__ == '__main__':
                     reset = True
 
                 update_eps = 0.01  # ensures that we cannot get stuck completely
-                # Compute the threshold such that the KL divergence between perturbed and non-perturbed
-                # policy is comparable to eps-greedy exploration with eps = exploration.value(t).
-                # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
-                # for detailed explanation.
-                update_param_noise_threshold = -np.log(1. - exploration.value(num_iters) + exploration.value(num_iters) / float(env.action_space.n))
+                if args.param_noise_threshold >= 0.:
+                    update_param_noise_threshold = args.param_noise_threshold
+                else:
+                    # Compute the threshold such that the KL divergence between perturbed and non-perturbed
+                    # policy is comparable to eps-greedy exploration with eps = exploration.value(t).
+                    # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
+                    # for detailed explanation.
+                    update_param_noise_threshold = -np.log(1. - exploration.value(num_iters) + exploration.value(num_iters) / float(env.action_space.n))
                 kwargs['reset'] = reset
                 kwargs['update_param_noise_threshold'] = update_param_noise_threshold
                 kwargs['update_param_noise_scale'] = (num_iters % args.param_noise_update_freq == 0)
diff --git a/baselines/deepq/simple.py b/baselines/deepq/simple.py
index 052d95b..4a9f710 100644
--- a/baselines/deepq/simple.py
+++ b/baselines/deepq/simple.py
@@ -95,6 +95,7 @@ def learn(env,
           prioritized_replay_eps=1e-6,
           num_cpu=16,
           param_noise=False,
+          param_noise_threshold=0.05,
           callback=None):
     """Train a deepq model.
 
@@ -224,11 +225,14 @@
             update_param_noise_threshold = 0.
         else:
             update_eps = 0.
-            # Compute the threshold such that the KL divergence between perturbed and non-perturbed
-            # policy is comparable to eps-greedy exploration with eps = exploration.value(t).
-            # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
-            # for detailed explanation.
-            update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(env.action_space.n))
+            if param_noise_threshold >= 0.:
+                update_param_noise_threshold = param_noise_threshold
+            else:
+                # Compute the threshold such that the KL divergence between perturbed and non-perturbed
+                # policy is comparable to eps-greedy exploration with eps = exploration.value(t).
+                # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017
+                # for detailed explanation.
+                update_param_noise_threshold = -np.log(1. - exploration.value(t) + exploration.value(t) / float(env.action_space.n))
             kwargs['reset'] = reset
             kwargs['update_param_noise_threshold'] = update_param_noise_threshold
             kwargs['update_param_noise_scale'] = True
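
Note on the fallback path: when the threshold is set to a negative value, both scripts keep the adaptive threshold from Appendix C.1 of Parameter Space Noise for Exploration (Plappert et al., 2017), which picks the KL divergence that makes parameter-space noise comparable to eps-greedy exploration with the current eps. A minimal standalone sketch of that formula follows; the helper name and example values are illustrative only, not part of the patch.

import numpy as np

def eps_greedy_equivalent_threshold(eps, n_actions):
    # KL threshold that makes the perturbed-vs-unperturbed policy divergence
    # comparable to eps-greedy exploration with the given eps
    # (Plappert et al., 2017, Appendix C.1).
    return -np.log(1. - eps + eps / float(n_actions))

# Example: eps = 0.1 with 6 discrete actions gives a threshold of roughly 0.087,
# the same order of magnitude as the new fixed default of 0.05.
print(eps_greedy_equivalent_threshold(0.1, 6))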