Merge branch 'master' of github.com:openai/baselines
@@ -322,7 +322,9 @@ def save_state(fname, sess=None):
     from baselines import logger
     logger.warn('save_state method is deprecated, please use save_variables instead')
     sess = sess or get_session()
-    os.makedirs(os.path.dirname(fname), exist_ok=True)
+    dirname = os.path.dirname(fname)
+    if any(dirname):
+        os.makedirs(dirname, exist_ok=True)
     saver = tf.train.Saver()
     saver.save(tf.get_default_session(), fname)

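The new dirname check matters when the checkpoint path has no directory component: os.path.dirname('model') returns an empty string, and os.makedirs('') raises FileNotFoundError, so the directory is only created when one is actually present. A minimal sketch of the behaviour (the bare 'model' filename is illustrative):

    import os

    fname = "model"                      # bare filename, no directory part
    dirname = os.path.dirname(fname)     # '' here, 'ckpts' for 'ckpts/model'
    if any(dirname):                     # any('') is False, so makedirs is skipped
        os.makedirs(dirname, exist_ok=True)
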
@@ -335,7 +337,9 @@ def save_variables(save_path, variables=None, sess=None):

     ps = sess.run(variables)
     save_dict = {v.name: value for v, value in zip(variables, ps)}
-    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    dirname = os.path.dirname(save_path)
+    if any(dirname):
+        os.makedirs(dirname, exist_ok=True)
     joblib.dump(save_dict, save_path)

 def load_variables(load_path, variables=None, sess=None):
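save_variables and load_variables are the helpers the rest of this commit migrates to. A minimal usage sketch, assuming a default TensorFlow session and an already-built graph (the 'ckpts/model' path is hypothetical):

    from baselines.common.tf_util import save_variables, load_variables

    # save_variables dumps {variable.name: value} via joblib for the given
    # variables (a default set when variables is None); load_variables reads
    # that dict back and assigns the values into the current graph.
    save_variables("ckpts/model")
    load_variables("ckpts/model")
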
@@ -7,7 +7,7 @@ import cloudpickle
 import numpy as np

 import baselines.common.tf_util as U
-from baselines.common.tf_util import load_state, save_state
+from baselines.common.tf_util import load_variables, save_variables
 from baselines import logger
 from baselines.common.schedules import LinearSchedule
 from baselines.common import set_global_seeds
@@ -39,7 +39,7 @@ class ActWrapper(object):
                 f.write(model_data)

             zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
-            load_state(os.path.join(td, "model"))
+            load_variables(os.path.join(td, "model"))

         return ActWrapper(act, act_params)

@@ -55,7 +55,7 @@ class ActWrapper(object):
             path = os.path.join(logger.get_dir(), "model.pkl")

         with tempfile.TemporaryDirectory() as td:
-            save_state(os.path.join(td, "model"))
+            save_variables(os.path.join(td, "model"))
             arc_name = os.path.join(td, "packed.zip")
             with zipfile.ZipFile(arc_name, 'w') as zipf:
                 for root, dirs, files in os.walk(td):
@@ -69,8 +69,7 @@ class ActWrapper(object):
             cloudpickle.dump((model_data, self._act_params), f)

     def save(self, path):
-        save_state(path)
-        self.save_act(path+".pickle")
+        save_variables(path)


 def load_act(path):
@@ -249,11 +248,11 @@ def learn(env,
         model_saved = False

         if tf.train.latest_checkpoint(td) is not None:
-            load_state(model_file)
+            load_variables(model_file)
             logger.log('Loaded model from {}'.format(model_file))
             model_saved = True
         elif load_path is not None:
-            load_state(load_path)
+            load_variables(load_path)
             logger.log('Loaded model from {}'.format(load_path))


@@ -322,12 +321,12 @@ def learn(env,
                     if print_freq is not None:
                         logger.log("Saving model due to mean reward increase: {} -> {}".format(
                                    saved_mean_reward, mean_100ep_reward))
-                    save_state(model_file)
+                    save_variables(model_file)
                     model_saved = True
                     saved_mean_reward = mean_100ep_reward
         if model_saved:
             if print_freq is not None:
                 logger.log("Restored model with mean reward: {}".format(saved_mean_reward))
-            load_state(model_file)
+            load_variables(model_file)

     return act
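Because save_variables serialises the checkpoint with joblib (see the tf_util hunk above), the model_file written during learn() is a plain dict of variable names to numpy arrays and can be inspected without TensorFlow. A hedged sketch, with 'model' standing in for the actual checkpoint path:

    import joblib

    params = joblib.load("model")        # hypothetical path to a saved checkpoint
    for name, value in params.items():
        print(name, value.shape)         # variable name -> numpy array
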
@@ -5,7 +5,7 @@ from baselines import deepq

 def main():
     env = gym.make("CartPole-v0")
-    act = deepq.load("cartpole_model.pkl")
+    act = deepq.learn(env, network='mlp', total_timesteps=0, load_path="cartpole_model.pkl")

     while True:
         obs, done = env.reset(), False
@@ -1,11 +1,17 @@
 import gym

 from baselines import deepq
+from baselines.common import models


 def main():
     env = gym.make("MountainCar-v0")
-    act = deepq.load("mountaincar_model.pkl")
+    act = deepq.learn(
+        env,
+        network=models.mlp(num_layers=1, num_hidden=64),
+        total_timesteps=0,
+        load_path='mountaincar_model.pkl'
+    )

     while True:
         obs, done = env.reset(), False
@@ -5,14 +5,21 @@ from baselines import deepq
 def main():
     env = gym.make("PongNoFrameskip-v4")
     env = deepq.wrap_atari_dqn(env)
-    act = deepq.load("pong_model.pkl")
+    model = deepq.learn(
+        env,
+        "conv_only",
+        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
+        hiddens=[256],
+        dueling=True,
+        total_timesteps=0
+    )

     while True:
         obs, done = env.reset(), False
         episode_rew = 0
         while not done:
             env.render()
-            obs, rew, done, _ = env.step(act(obs[None])[0])
+            obs, rew, done, _ = env.step(model(obs[None])[0])
             episode_rew += rew
         print("Episode reward", episode_rew)

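In these enjoy scripts the object returned by deepq.learn is called directly on observations; it maps a batch of observations to a batch of actions, so obs[None] adds the batch dimension and [0] unpacks the single action. A minimal sketch of that calling convention, assuming env and model set up as above:

    obs = env.reset()
    action = model(obs[None])[0]         # batch of one observation in, one action out
    obs, rew, done, _ = env.step(action)
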
@@ -1,34 +0,0 @@
-import argparse
-
-import numpy as np
-
-from baselines import deepq
-from baselines.common import retro_wrappers
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--env', help='environment ID', default='SuperMarioBros-Nes')
-    parser.add_argument('--gamestate', help='game state to load', default='Level1-1')
-    parser.add_argument('--model', help='model pickle file from ActWrapper.save', default='model.pkl')
-    args = parser.parse_args()
-
-    env = retro_wrappers.make_retro(game=args.env, state=args.gamestate, max_episode_steps=None)
-    env = retro_wrappers.wrap_deepmind_retro(env)
-    act = deepq.load(args.model)
-
-    while True:
-        obs, done = env.reset(), False
-        episode_rew = 0
-        while not done:
-            env.render()
-            action = act(obs[None])[0]
-            env_action = np.zeros(env.action_space.n)
-            env_action[action] = 1
-            obs, rew, done, _ = env.step(env_action)
-            episode_rew += rew
-        print('Episode reward', episode_rew)
-
-
-if __name__ == '__main__':
-    main()
@@ -1,52 +0,0 @@
-from baselines import deepq
-from baselines.common import set_global_seeds
-from baselines import bench
-import argparse
-from baselines import logger
-from baselines.common.atari_wrappers import make_atari
-
-
-def main():
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
-    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
-    parser.add_argument('--prioritized', type=int, default=1)
-    parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
-    parser.add_argument('--dueling', type=int, default=1)
-    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
-    parser.add_argument('--checkpoint-freq', type=int, default=10000)
-    parser.add_argument('--checkpoint-path', type=str, default=None)
-
-    args = parser.parse_args()
-    logger.configure()
-    set_global_seeds(args.seed)
-    env = make_atari(args.env)
-    env = bench.Monitor(env, logger.get_dir())
-    env = deepq.wrap_atari_dqn(env)
-
-    deepq.learn(
-        env,
-        "conv_only",
-        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
-        hiddens=[256],
-        dueling=bool(args.dueling),
-        lr=1e-4,
-        total_timesteps=args.num_timesteps,
-        buffer_size=10000,
-        exploration_fraction=0.1,
-        exploration_final_eps=0.01,
-        train_freq=4,
-        learning_starts=10000,
-        target_network_update_freq=1000,
-        gamma=0.99,
-        prioritized_replay=bool(args.prioritized),
-        prioritized_replay_alpha=args.prioritized_replay_alpha,
-        checkpoint_freq=args.checkpoint_freq,
-        checkpoint_path=args.checkpoint_path,
-    )
-
-    env.close()
-
-
-if __name__ == '__main__':
-    main()
@@ -1,49 +0,0 @@
-import argparse
-
-from baselines import deepq
-from baselines.common import set_global_seeds
-from baselines import bench
-from baselines import logger
-from baselines.common import retro_wrappers
-import retro
-
-
-def main():
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--env', help='environment ID', default='SuperMarioBros-Nes')
-    parser.add_argument('--gamestate', help='game state to load', default='Level1-1')
-    parser.add_argument('--seed', help='seed', type=int, default=0)
-    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
-    args = parser.parse_args()
-    logger.configure()
-    set_global_seeds(args.seed)
-    env = retro_wrappers.make_retro(game=args.env, state=args.gamestate, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE)
-    env.seed(args.seed)
-    env = bench.Monitor(env, logger.get_dir())
-    env = retro_wrappers.wrap_deepmind_retro(env)
-
-    model = deepq.models.cnn_to_mlp(
-        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
-        hiddens=[256],
-        dueling=True
-    )
-    act = deepq.learn(
-        env,
-        q_func=model,
-        lr=1e-4,
-        max_timesteps=args.num_timesteps,
-        buffer_size=10000,
-        exploration_fraction=0.1,
-        exploration_final_eps=0.01,
-        train_freq=4,
-        learning_starts=10000,
-        target_network_update_freq=1000,
-        gamma=0.99,
-        prioritized_replay=True
-    )
-    act.save()
-    env.close()
-
-
-if __name__ == '__main__':
-    main()
@@ -1,17 +1,17 @@
 import gym

 from baselines import deepq
+from baselines.common import models


 def main():
     env = gym.make("MountainCar-v0")
     # Enabling layer_norm here is import for parameter space noise!
-    model = deepq.models.mlp([64], layer_norm=True)
     act = deepq.learn(
         env,
-        q_func=model,
+        network=models.mlp(num_hidden=64, num_layers=1),
         lr=1e-3,
-        max_timesteps=100000,
+        total_timesteps=100000,
         buffer_size=50000,
         exploration_fraction=0.1,
         exploration_final_eps=0.1,
baselines/deepq/experiments/train_pong.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from baselines import deepq
+from baselines import bench
+from baselines import logger
+from baselines.common.atari_wrappers import make_atari
+
+
+def main():
+    logger.configure()
+    env = make_atari('PongNoFrameskip-v4')
+    env = bench.Monitor(env, logger.get_dir())
+    env = deepq.wrap_atari_dqn(env)
+
+    model = deepq.learn(
+        env,
+        "conv_only",
+        convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
+        hiddens=[256],
+        dueling=True,
+        lr=1e-4,
+        total_timesteps=int(1e7),
+        buffer_size=10000,
+        exploration_fraction=0.1,
+        exploration_final_eps=0.01,
+        train_freq=4,
+        learning_starts=10000,
+        target_network_update_freq=1000,
+        gamma=0.99,
+    )
+
+    model.save('pong_model.pkl')
+    env.close()
+
+
+
+if __name__ == '__main__':
+    main()