From e2da7cd42f04d8952b4faa2906a29a9e29a09a0b Mon Sep 17 00:00:00 2001
From: Pim de Haan
Date: Thu, 16 Aug 2018 12:08:53 -0700
Subject: [PATCH] Several bugfixes for #504, #505, #506 related to Classic
 Control and deepq (#507)

* Several bugfixes

* Fixed ActWrapper.step bug
---
 baselines/deepq/build_graph.py |  2 +-
 baselines/deepq/deepq.py       |  6 ++++--
 baselines/run.py               | 10 +++++++---
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/baselines/deepq/build_graph.py b/baselines/deepq/build_graph.py
index e9ff1a4..dd96f0e 100644
--- a/baselines/deepq/build_graph.py
+++ b/baselines/deepq/build_graph.py
@@ -309,7 +309,7 @@ def build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope="deepq",
                          outputs=output_actions,
                          givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False},
                          updates=updates)
-    def act(ob, reset, update_param_noise_threshold, update_param_noise_scale, stochastic=True, update_eps=-1):
+    def act(ob, reset=False, update_param_noise_threshold=False, update_param_noise_scale=False, stochastic=True, update_eps=-1):
         return _act(ob, stochastic, update_eps, reset, update_param_noise_threshold, update_param_noise_scale)
     return act
 
diff --git a/baselines/deepq/deepq.py b/baselines/deepq/deepq.py
index 7d44acf..01921bb 100644
--- a/baselines/deepq/deepq.py
+++ b/baselines/deepq/deepq.py
@@ -27,7 +27,7 @@ class ActWrapper(object):
         self.initial_state = None
 
     @staticmethod
-    def load_act(self, path):
+    def load_act(path):
         with open(path, "rb") as f:
             model_data, act_params = cloudpickle.load(f)
         act = deepq.build_act(**act_params)
@@ -70,6 +70,7 @@ class ActWrapper(object):
 
     def save(self, path):
         save_state(path)
+        self.save_act(path+".pickle")
 
 
 def load_act(path):
@@ -194,8 +195,9 @@ def learn(env,
 
     # capture the shape outside the closure so that the env object is not serialized
     # by cloudpickle when serializing make_obs_ph
+    observation_space = env.observation_space
     def make_obs_ph(name):
-        return ObservationInput(env.observation_space, name=name)
+        return ObservationInput(observation_space, name=name)
 
     act, train, update_target, debug = deepq.build_train(
         make_obs_ph=make_obs_ph,
diff --git a/baselines/run.py b/baselines/run.py
index cba8515..1491a5e 100644
--- a/baselines/run.py
+++ b/baselines/run.py
@@ -123,14 +123,18 @@ def build_env(args, render=False):
             env = bench.Monitor(env, logger.get_dir())
             env = retro_wrappers.wrap_deepmind_retro(env)
 
-    elif env_type == 'classic':
+    elif env_type == 'classic_control':
         def make_env():
             e = gym.make(env_id)
+            e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
             e.seed(seed)
             return e
         env = DummyVecEnv([make_env])
-
+
+    else:
+        raise ValueError('Unknown env_type {}'.format(env_type))
+
     return env
 
 
@@ -149,7 +153,7 @@ def get_env_type(env_id):
     return env_type, env_id
 
 def get_default_network(env_type):
-    if env_type == 'mujoco' or env_type=='classic':
+    if env_type == 'mujoco' or env_type == 'classic_control':
         return 'mlp'
     if env_type == 'atari':
         return 'cnn'