* Several bugfixes * Fixed ActWrapper.step bug
This commit is contained in:
@@ -309,7 +309,7 @@ def build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope="deepq",
|
|||||||
outputs=output_actions,
|
outputs=output_actions,
|
||||||
givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False},
|
givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False},
|
||||||
updates=updates)
|
updates=updates)
|
||||||
def act(ob, reset, update_param_noise_threshold, update_param_noise_scale, stochastic=True, update_eps=-1):
|
def act(ob, reset=False, update_param_noise_threshold=False, update_param_noise_scale=False, stochastic=True, update_eps=-1):
|
||||||
return _act(ob, stochastic, update_eps, reset, update_param_noise_threshold, update_param_noise_scale)
|
return _act(ob, stochastic, update_eps, reset, update_param_noise_threshold, update_param_noise_scale)
|
||||||
return act
|
return act
|
||||||
|
|
||||||
|
@@ -27,7 +27,7 @@ class ActWrapper(object):
|
|||||||
self.initial_state = None
|
self.initial_state = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_act(self, path):
|
def load_act(path):
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
model_data, act_params = cloudpickle.load(f)
|
model_data, act_params = cloudpickle.load(f)
|
||||||
act = deepq.build_act(**act_params)
|
act = deepq.build_act(**act_params)
|
||||||
@@ -70,6 +70,7 @@ class ActWrapper(object):
|
|||||||
|
|
||||||
def save(self, path):
|
def save(self, path):
|
||||||
save_state(path)
|
save_state(path)
|
||||||
|
self.save_act(path+".pickle")
|
||||||
|
|
||||||
|
|
||||||
def load_act(path):
|
def load_act(path):
|
||||||
@@ -194,8 +195,9 @@ def learn(env,
|
|||||||
# capture the shape outside the closure so that the env object is not serialized
|
# capture the shape outside the closure so that the env object is not serialized
|
||||||
# by cloudpickle when serializing make_obs_ph
|
# by cloudpickle when serializing make_obs_ph
|
||||||
|
|
||||||
|
observation_space = env.observation_space
|
||||||
def make_obs_ph(name):
|
def make_obs_ph(name):
|
||||||
return ObservationInput(env.observation_space, name=name)
|
return ObservationInput(observation_space, name=name)
|
||||||
|
|
||||||
act, train, update_target, debug = deepq.build_train(
|
act, train, update_target, debug = deepq.build_train(
|
||||||
make_obs_ph=make_obs_ph,
|
make_obs_ph=make_obs_ph,
|
||||||
|
@@ -123,14 +123,18 @@ def build_env(args, render=False):
|
|||||||
env = bench.Monitor(env, logger.get_dir())
|
env = bench.Monitor(env, logger.get_dir())
|
||||||
env = retro_wrappers.wrap_deepmind_retro(env)
|
env = retro_wrappers.wrap_deepmind_retro(env)
|
||||||
|
|
||||||
elif env_type == 'classic':
|
elif env_type == 'classic_control':
|
||||||
def make_env():
|
def make_env():
|
||||||
e = gym.make(env_id)
|
e = gym.make(env_id)
|
||||||
|
e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
|
||||||
e.seed(seed)
|
e.seed(seed)
|
||||||
return e
|
return e
|
||||||
|
|
||||||
env = DummyVecEnv([make_env])
|
env = DummyVecEnv([make_env])
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown env_type {}'.format(env_type))
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
@@ -149,7 +153,7 @@ def get_env_type(env_id):
|
|||||||
return env_type, env_id
|
return env_type, env_id
|
||||||
|
|
||||||
def get_default_network(env_type):
|
def get_default_network(env_type):
|
||||||
if env_type == 'mujoco' or env_type=='classic':
|
if env_type == 'mujoco' or env_type == 'classic_control':
|
||||||
return 'mlp'
|
return 'mlp'
|
||||||
if env_type == 'atari':
|
if env_type == 'atari':
|
||||||
return 'cnn'
|
return 'cnn'
|
||||||
|
Reference in New Issue
Block a user