diff --git a/baselines/run.py b/baselines/run.py index 5dee154..8ab71ac 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -121,9 +121,11 @@ def build_env(args): env = retro_wrappers.wrap_deepmind_retro(env) else: - get_session(tf.ConfigProto(allow_soft_placement=True, - intra_op_parallelism_threads=1, - inter_op_parallelism_threads=1)) + config = tf.ConfigProto(allow_soft_placement=True, + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=1) + config.gpu_options.allow_growth = True + get_session(config=config) env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)