baselines issue #564 (#574)

* fixes to enjoy_cartpole, enjoy_mountaincar.py

* fixed {train,enjoy}_pong, removed enjoy_retro

* set number of timesteps to 1e7 in train_pong

* flake8 complaints

* use synchronous version fo acktr in test_env_after_learn

* flake8
This commit is contained in:
pzhokhov
2018-09-10 11:50:59 -07:00
committed by GitHub
parent 8614c4ddbf
commit adaa8aefa8
10 changed files with 70 additions and 153 deletions

View File

@@ -322,7 +322,9 @@ def save_state(fname, sess=None):
from baselines import logger
logger.warn('save_state method is deprecated, please use save_variables instead')
sess = sess or get_session()
os.makedirs(os.path.dirname(fname), exist_ok=True)
dirname = os.path.dirname(fname)
if any(dirname):
os.makedirs(dirname, exist_ok=True)
saver = tf.train.Saver()
saver.save(tf.get_default_session(), fname)
@@ -335,7 +337,9 @@ def save_variables(save_path, variables=None, sess=None):
ps = sess.run(variables)
save_dict = {v.name: value for v, value in zip(variables, ps)}
os.makedirs(os.path.dirname(save_path), exist_ok=True)
dirname = os.path.dirname(save_path)
if any(dirname):
os.makedirs(dirname, exist_ok=True)
joblib.dump(save_dict, save_path)
def load_variables(load_path, variables=None, sess=None):

View File

@@ -7,7 +7,7 @@ import cloudpickle
import numpy as np
import baselines.common.tf_util as U
from baselines.common.tf_util import load_state, save_state
from baselines.common.tf_util import load_variables, save_variables
from baselines import logger
from baselines.common.schedules import LinearSchedule
from baselines.common import set_global_seeds
@@ -39,7 +39,7 @@ class ActWrapper(object):
f.write(model_data)
zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
load_state(os.path.join(td, "model"))
load_variables(os.path.join(td, "model"))
return ActWrapper(act, act_params)
@@ -55,7 +55,7 @@ class ActWrapper(object):
path = os.path.join(logger.get_dir(), "model.pkl")
with tempfile.TemporaryDirectory() as td:
save_state(os.path.join(td, "model"))
save_variables(os.path.join(td, "model"))
arc_name = os.path.join(td, "packed.zip")
with zipfile.ZipFile(arc_name, 'w') as zipf:
for root, dirs, files in os.walk(td):
@@ -69,8 +69,7 @@ class ActWrapper(object):
cloudpickle.dump((model_data, self._act_params), f)
def save(self, path):
save_state(path)
self.save_act(path+".pickle")
save_variables(path)
def load_act(path):
@@ -249,11 +248,11 @@ def learn(env,
model_saved = False
if tf.train.latest_checkpoint(td) is not None:
load_state(model_file)
load_variables(model_file)
logger.log('Loaded model from {}'.format(model_file))
model_saved = True
elif load_path is not None:
load_state(load_path)
load_variables(load_path)
logger.log('Loaded model from {}'.format(load_path))
@@ -322,12 +321,12 @@ def learn(env,
if print_freq is not None:
logger.log("Saving model due to mean reward increase: {} -> {}".format(
saved_mean_reward, mean_100ep_reward))
save_state(model_file)
save_variables(model_file)
model_saved = True
saved_mean_reward = mean_100ep_reward
if model_saved:
if print_freq is not None:
logger.log("Restored model with mean reward: {}".format(saved_mean_reward))
load_state(model_file)
load_variables(model_file)
return act

View File

@@ -5,7 +5,7 @@ from baselines import deepq
def main():
env = gym.make("CartPole-v0")
act = deepq.load("cartpole_model.pkl")
act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl")
while True:
obs, done = env.reset(), False

View File

@@ -1,11 +1,17 @@
import gym
from baselines import deepq
from baselines.common import models
def main():
env = gym.make("MountainCar-v0")
act = deepq.load("mountaincar_model.pkl")
act = deepq.learn(
env,
network=models.mlp(num_layers=1, num_hidden=64),
total_timesteps=0,
load_path='mountaincar_model.pkl'
)
while True:
obs, done = env.reset(), False

View File

@@ -5,14 +5,21 @@ from baselines import deepq
def main():
env = gym.make("PongNoFrameskip-v4")
env = deepq.wrap_atari_dqn(env)
act = deepq.load("pong_model.pkl")
model = deepq.learn(
env,
"conv_only",
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256],
dueling=True,
total_timesteps=0
)
while True:
obs, done = env.reset(), False
episode_rew = 0
while not done:
env.render()
obs, rew, done, _ = env.step(act(obs[None])[0])
obs, rew, done, _ = env.step(model(obs[None])[0])
episode_rew += rew
print("Episode reward", episode_rew)

View File

@@ -1,34 +0,0 @@
import argparse
import numpy as np
from baselines import deepq
from baselines.common import retro_wrappers
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--env', help='environment ID', default='SuperMarioBros-Nes')
parser.add_argument('--gamestate', help='game state to load', default='Level1-1')
parser.add_argument('--model', help='model pickle file from ActWrapper.save', default='model.pkl')
args = parser.parse_args()
env = retro_wrappers.make_retro(game=args.env, state=args.gamestate, max_episode_steps=None)
env = retro_wrappers.wrap_deepmind_retro(env)
act = deepq.load(args.model)
while True:
obs, done = env.reset(), False
episode_rew = 0
while not done:
env.render()
action = act(obs[None])[0]
env_action = np.zeros(env.action_space.n)
env_action[action] = 1
obs, rew, done, _ = env.step(env_action)
episode_rew += rew
print('Episode reward', episode_rew)
if __name__ == '__main__':
main()

View File

@@ -1,52 +0,0 @@
from baselines import deepq
from baselines.common import set_global_seeds
from baselines import bench
import argparse
from baselines import logger
from baselines.common.atari_wrappers import make_atari
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--prioritized', type=int, default=1)
parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
parser.add_argument('--dueling', type=int, default=1)
parser.add_argument('--num-timesteps', type=int, default=int(10e6))
parser.add_argument('--checkpoint-freq', type=int, default=10000)
parser.add_argument('--checkpoint-path', type=str, default=None)
args = parser.parse_args()
logger.configure()
set_global_seeds(args.seed)
env = make_atari(args.env)
env = bench.Monitor(env, logger.get_dir())
env = deepq.wrap_atari_dqn(env)
deepq.learn(
env,
"conv_only",
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256],
dueling=bool(args.dueling),
lr=1e-4,
total_timesteps=args.num_timesteps,
buffer_size=10000,
exploration_fraction=0.1,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
prioritized_replay=bool(args.prioritized),
prioritized_replay_alpha=args.prioritized_replay_alpha,
checkpoint_freq=args.checkpoint_freq,
checkpoint_path=args.checkpoint_path,
)
env.close()
if __name__ == '__main__':
main()

View File

@@ -1,49 +0,0 @@
import argparse
from baselines import deepq
from baselines.common import set_global_seeds
from baselines import bench
from baselines import logger
from baselines.common import retro_wrappers
import retro
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--env', help='environment ID', default='SuperMarioBros-Nes')
parser.add_argument('--gamestate', help='game state to load', default='Level1-1')
parser.add_argument('--seed', help='seed', type=int, default=0)
parser.add_argument('--num-timesteps', type=int, default=int(10e6))
args = parser.parse_args()
logger.configure()
set_global_seeds(args.seed)
env = retro_wrappers.make_retro(game=args.env, state=args.gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
env.seed(args.seed)
env = bench.Monitor(env, logger.get_dir())
env = retro_wrappers.wrap_deepmind_retro(env)
model = deepq.models.cnn_to_mlp(
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256],
dueling=True
)
act = deepq.learn(
env,
q_func=model,
lr=1e-4,
max_timesteps=args.num_timesteps,
buffer_size=10000,
exploration_fraction=0.1,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
prioritized_replay=True
)
act.save()
env.close()
if __name__ == '__main__':
main()

View File

@@ -1,17 +1,17 @@
import gym
from baselines import deepq
from baselines.common import models
def main():
env = gym.make("MountainCar-v0")
# Enabling layer_norm here is import for parameter space noise!
model = deepq.models.mlp([64], layer_norm=True)
act = deepq.learn(
env,
q_func=model,
network=models.mlp(num_hidden=64, num_layers=1),
lr=1e-3,
max_timesteps=100000,
total_timesteps=100000,
buffer_size=50000,
exploration_fraction=0.1,
exploration_final_eps=0.1,

View File

@@ -0,0 +1,36 @@
from baselines import deepq
from baselines import bench
from baselines import logger
from baselines.common.atari_wrappers import make_atari
def main():
logger.configure()
env = make_atari('PongNoFrameskip-v4')
env = bench.Monitor(env, logger.get_dir())
env = deepq.wrap_atari_dqn(env)
model = deepq.learn(
env,
"conv_only",
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256],
dueling=True,
lr=1e-4,
total_timesteps=int(1e7),
buffer_size=10000,
exploration_fraction=0.1,
exploration_final_eps=0.01,
train_freq=4,
learning_starts=10000,
target_network_update_freq=1000,
gamma=0.99,
)
model.save('pong_model.pkl')
env.close()
if __name__ == '__main__':
main()