Merge branch 'master' of github.com:openai/baselines into peterz_update_READMEs

This commit is contained in:
Peter Zhokhov
2018-08-16 12:26:51 -07:00
3 changed files with 11 additions and 5 deletions

View File

@@ -309,7 +309,7 @@ def build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope="deepq",
outputs=output_actions, outputs=output_actions,
givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False}, givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False},
updates=updates) updates=updates)
def act(ob, reset, update_param_noise_threshold, update_param_noise_scale, stochastic=True, update_eps=-1): def act(ob, reset=False, update_param_noise_threshold=False, update_param_noise_scale=False, stochastic=True, update_eps=-1):
return _act(ob, stochastic, update_eps, reset, update_param_noise_threshold, update_param_noise_scale) return _act(ob, stochastic, update_eps, reset, update_param_noise_threshold, update_param_noise_scale)
return act return act

View File

@@ -70,6 +70,7 @@ class ActWrapper(object):
def save(self, path): def save(self, path):
save_state(path) save_state(path)
self.save_act(path+".pickle")
def load_act(path): def load_act(path):
@@ -194,8 +195,9 @@ def learn(env,
# capture the shape outside the closure so that the env object is not serialized # capture the shape outside the closure so that the env object is not serialized
# by cloudpickle when serializing make_obs_ph # by cloudpickle when serializing make_obs_ph
observation_space = env.observation_space
def make_obs_ph(name): def make_obs_ph(name):
return ObservationInput(env.observation_space, name=name) return ObservationInput(observation_space, name=name)
act, train, update_target, debug = deepq.build_train( act, train, update_target, debug = deepq.build_train(
make_obs_ph=make_obs_ph, make_obs_ph=make_obs_ph,

View File

@@ -123,14 +123,18 @@ def build_env(args, render=False):
env = bench.Monitor(env, logger.get_dir()) env = bench.Monitor(env, logger.get_dir())
env = retro_wrappers.wrap_deepmind_retro(env) env = retro_wrappers.wrap_deepmind_retro(env)
elif env_type == 'classic': elif env_type == 'classic_control':
def make_env(): def make_env():
e = gym.make(env_id) e = gym.make(env_id)
e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
e.seed(seed) e.seed(seed)
return e return e
env = DummyVecEnv([make_env]) env = DummyVecEnv([make_env])
else:
raise ValueError('Unknown env_type {}'.format(env_type))
return env return env
@@ -149,7 +153,7 @@ def get_env_type(env_id):
return env_type, env_id return env_type, env_id
def get_default_network(env_type): def get_default_network(env_type):
if env_type == 'mujoco' or env_type=='classic': if env_type == 'mujoco' or env_type == 'classic_control':
return 'mlp' return 'mlp'
if env_type == 'atari': if env_type == 'atari':
return 'cnn' return 'cnn'