Fix atari wrapper (affecting a2c perf) and pposgd mujoco performance
- removed vf clipping in pposgd - that was severely degrading performance on mujoco because it didn't account for scale of returns
- switched adam epsilon in pposgd_simple
- brought back no-ops in atari wrapper (oops)
- added readmes
- revamped run_X_benchmark scripts to have standard form
- cleaned up DDPG a little, removed deprecated SimpleMonitor and non-idiomatic usage of logger
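For context on the first bullet, here is a small numpy sketch (illustrative numbers only, not code from the repo) of why the clipped value loss stalls when returns are on an unnormalized Mujoco scale: once the prediction has moved about clip_param toward a target that is hundreds of units away, the pessimistic max term stops improving and the value function gets almost no further gradient signal, whereas a plain squared error keeps decreasing.

    import numpy as np

    clip_param = 0.2     # same constant PPO uses for the policy ratio clip
    ret = 300.0          # assumed empirical return on a Mujoco task (unnormalized)
    vpred_old = 100.0    # value prediction before the update epoch

    def clipped_vf_loss(vpred):
        # mirrors the removed lines: max of plain error and clipped-prediction error
        vf1 = (vpred - ret) ** 2
        vpred_clipped = vpred_old + np.clip(vpred - vpred_old, -clip_param, clip_param)
        vf2 = (vpred_clipped - ret) ** 2
        return 0.5 * max(vf1, vf2)

    def plain_vf_loss(vpred):
        return 0.5 * (vpred - ret) ** 2

    for vpred in (100.0, 100.2, 110.0, 150.0):
        print(vpred, clipped_vf_loss(vpred), plain_vf_loss(vpred))
    # the clipped loss plateaus near 19960 as soon as vpred passes 100.2,
    # while the plain MSE keeps rewarding progress toward the return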
baselines/a2c/README.md (new file)
@@ -0,0 +1,5 @@
+# A2C
+
+- Original paper: https://arxiv.org/abs/1602.01783
+- Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
+- `python -m baselines.a2c.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options.
@@ -183,26 +183,5 @@ def learn(policy, env, seed, nsteps=5, nstack=4, total_timesteps=int(80e6), vf_c
             logger.dump_tabular()
     env.close()
-
-def main():
-    env_id = 'SpaceInvaders'
-    seed = 42
-    nenvs = 4
-
-    def make_env(rank):
-        def env_fn():
-            env = gym.make('{}NoFrameskip-v4'.format(env_id))
-            env.seed(seed + rank)
-            if logger.get_dir():
-                from baselines import bench
-                env = bench.Monitor(env, osp.join(logger.get_dir(), "{}.monitor.json".format(rank)))
-            gym.logger.setLevel(logging.WARN)
-            return wrap_deepmind(env)
-        return env_fn
-
-    set_global_seeds(seed)
-    env = SubprocVecEnv([make_env(i) for i in range(nenvs)])
-    policy = CnnPolicy
-    learn(policy, env, seed)
-
 if __name__ == '__main__':
     main()
@@ -121,87 +121,3 @@ class CnnPolicy(object):
         self.vf = vf
         self.step = step
         self.value = value
-
-class AcerCnnPolicy(object):
-
-    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False):
-        nbatch = nenv * nsteps
-        nh, nw, nc = ob_space.shape
-        ob_shape = (nbatch, nh, nw, nc * nstack)
-        nact = ac_space.n
-        X = tf.placeholder(tf.uint8, ob_shape) # obs
-        with tf.variable_scope("model", reuse=reuse):
-            h = conv(tf.cast(X, tf.float32) / 255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))
-            h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))
-            h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
-            h3 = conv_to_fc(h3)
-            h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))
-            pi_logits = fc(h4, 'pi', nact, act=lambda x: x, init_scale=0.01)
-            pi = tf.nn.softmax(pi_logits)
-            q = fc(h4, 'q', nact, act=lambda x: x)
-
-        a = sample(pi_logits) # could change this to use self.pi instead
-        self.initial_state = [] # not stateful
-        self.X = X
-        self.pi = pi # actual policy params now
-        self.q = q
-
-        def step(ob, *args, **kwargs):
-            # returns actions, mus, states
-            a0, pi0 = sess.run([a, pi], {X: ob})
-            return a0, pi0, [] # dummy state
-
-        def out(ob, *args, **kwargs):
-            pi0, q0 = sess.run([pi, q], {X: ob})
-            return pi0, q0
-
-        def act(ob, *args, **kwargs):
-            return sess.run(a, {X: ob})
-
-        self.step = step
-        self.out = out
-        self.act = act
-
-class AcerLstmPolicy(object):
-
-    def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False, nlstm=256):
-        nbatch = nenv * nsteps
-        nh, nw, nc = ob_space.shape
-        ob_shape = (nbatch, nh, nw, nc * nstack)
-        nact = ac_space.n
-        X = tf.placeholder(tf.uint8, ob_shape) # obs
-        M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
-        S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
-        with tf.variable_scope("model", reuse=reuse):
-            h = conv(tf.cast(X, tf.float32) / 255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))
-            h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))
-            h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))
-            h3 = conv_to_fc(h3)
-            h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))
-
-            # lstm
-            xs = batch_to_seq(h4, nenv, nsteps)
-            ms = batch_to_seq(M, nenv, nsteps)
-            h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm)
-            h5 = seq_to_batch(h5)
-
-            pi_logits = fc(h5, 'pi', nact, act=lambda x: x, init_scale=0.01)
-            pi = tf.nn.softmax(pi_logits)
-            q = fc(h5, 'q', nact, act=lambda x: x)
-
-        a = sample(pi_logits) # could change this to use self.pi instead
-        self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
-        self.X = X
-        self.M = M
-        self.S = S
-        self.pi = pi # actual policy params now
-        self.q = q
-
-        def step(ob, state, mask, *args, **kwargs):
-            # returns actions, mus, states
-            a0, pi0, s = sess.run([a, pi, snew], {X: ob, S: state, M: mask})
-            return a0, pi0, s
-
-        self.step = step
-
-# For Mujoco. Taken from PPOSGD
@@ -8,21 +8,20 @@ from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.atari_wrappers import wrap_deepmind
 from baselines.a2c.policies import CnnPolicy, LstmPolicy, LnLstmPolicy
 
-def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
-    num_timesteps //= 4
+def train(env_id, num_frames, seed, policy, lrschedule, num_cpu):
+    num_timesteps = int(num_frames / 4 * 1.1)
+    # divide by 4 due to frameskip, then do a little extras so episodes end
     def make_env(rank):
         def _thunk():
             env = gym.make(env_id)
             env.seed(seed + rank)
-            env = bench.Monitor(env, os.path.join(logger.get_dir(), "{}.monitor.json".format(rank)))
+            env = bench.Monitor(env, logger.get_dir() and
+                os.path.join(logger.get_dir(), "{}.monitor.json".format(rank)))
             gym.logger.setLevel(logging.WARN)
             return wrap_deepmind(env)
         return _thunk
 
     set_global_seeds(seed)
     env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
 
     if policy == 'cnn':
         policy_fn = CnnPolicy
     elif policy == 'lstm':
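The `logger.get_dir() and os.path.join(...)` pattern above (repeated in the other run scripts below) passes None to bench.Monitor when no log directory is configured, so the wrapper is still applied but writes no monitor file. A tiny sketch of the short-circuit (paths assumed):

    import os.path as osp

    def monitor_path(logdir, rank):
        # `and` short-circuits: None stays None, a real directory becomes a per-worker file
        return logdir and osp.join(logdir, "{}.monitor.json".format(rank))

    print(monitor_path(None, 0))        # None -> Monitor records episodes but writes no file
    print(monitor_path("/tmp/run", 0))  # /tmp/run/0.monitor.json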
@@ -32,10 +31,18 @@ def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
     learn(policy_fn, env, seed, total_timesteps=num_timesteps, lrschedule=lrschedule)
     env.close()
 
 
 def main():
-    train('BreakoutNoFrameskip-v4', num_timesteps=int(40e6), seed=0, policy='cnn', lrschedule='linear', num_cpu=16)
+    import argparse
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn')
+    parser.add_argument('--lrschedule', help='Learning rate schedule', choices=['constant', 'linear'], default='constant')
+    parser.add_argument('--million_frames', help='How many frames to train (/ 1e6). '
+        'This number gets divided by 4 due to frameskip', type=int, default=40)
+    args = parser.parse_args()
+    train(args.env, num_frames=1e6 * args.million_frames, seed=args.seed,
+        policy=args.policy, lrschedule=args.lrschedule, num_cpu=16)
 
 if __name__ == '__main__':
     main()
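With the argparse interface above, a run such as `python -m baselines.a2c.run_atari --env BreakoutNoFrameskip-v4 --lrschedule linear --million_frames 40` should roughly reproduce the old hard-coded call: 40M frames, i.e. 10M timesteps after frameskip, plus the extra 10% so trailing episodes can finish.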
baselines/acktr/README.md (new file)
@@ -0,0 +1,5 @@
+# ACKTR
+
+- Original paper: https://arxiv.org/abs/1708.05144
+- Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/
+- `python -m baselines.acktr.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options.
@@ -8,9 +8,8 @@ from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.atari_wrappers import wrap_deepmind
 from baselines.acktr.policies import CnnPolicy
 
-def train(env_id, num_timesteps, seed, num_cpu):
-    num_timesteps //= 4
-
+def train(env_id, num_frames, seed, num_cpu):
+    num_timesteps = int(num_frames / 4 * 1.1)
     def make_env(rank):
         def _thunk():
             env = gym.make(env_id)
@@ -20,16 +19,21 @@ def train(env_id, num_timesteps, seed, num_cpu):
             gym.logger.setLevel(logging.WARN)
             return wrap_deepmind(env)
         return _thunk
 
     set_global_seeds(seed)
     env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
 
     policy_fn = CnnPolicy
     learn(policy_fn, env, seed, total_timesteps=num_timesteps, nprocs=num_cpu)
     env.close()
 
 def main():
-    train('BreakoutNoFrameskip-v4', num_timesteps=int(40e6), seed=0, num_cpu=32)
+    import argparse
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    parser.add_argument('--million_frames', help='How many frames to train (/ 1e6). '
+        'This number gets divided by 4 due to frameskip', type=int, default=40)
+    args = parser.parse_args()
+    train(args.env, num_frames=1e6 * args.million_frames, seed=args.seed, num_cpu=32)
 
 
 if __name__ == '__main__':
@@ -35,8 +35,8 @@ def train(env_id, num_timesteps, seed):
     env.close()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Run Mujoco benchmark.')
-    parser.add_argument('--env_id', type=str, default="Reacher-v1")
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    parser.add_argument('--env', help='environment ID', type=str, default="Reacher-v1")
     args = parser.parse_args()
-    train(args.env_id, num_timesteps=1e6, seed=1)
+    train(args.env_id, num_timesteps=1e6, seed=args.seed)
@@ -3,8 +3,8 @@ import numpy as np
 from baselines import common
 from baselines.common import tf_util as U
 import tensorflow as tf
-import kfac
-from utils import dense
+from baselines.acktr import kfac
+from baselines.acktr.utils import dense
 
 class NeuralNetValueFunction(object):
     def __init__(self, ob_dim, ac_dim): #pylint: disable=W0613
@@ -162,7 +162,7 @@ def wrap_deepmind(env, episode_life=True, clip_rewards=True):
     assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
     if episode_life:
         env = EpisodicLifeEnv(env)
-    # env = NoopResetEnv(env, noop_max=30)
+    env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
     if 'FIRE' in env.unwrapped.get_action_meanings():
         env = FireResetEnv(env)
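The uncommented NoopResetEnv line is the "brought back no-ops" fix from the commit message. For reference, a minimal sketch of what a no-op reset wrapper typically does (an illustration, not the repo's exact class):

    import gym
    import numpy as np

    class NoopResetSketch(gym.Wrapper):
        """Illustration only: take a random number of NOOP actions after reset so
        Atari episodes do not all start from the exact same emulator state."""
        def __init__(self, env, noop_max=30):
            super(NoopResetSketch, self).__init__(env)
            self.noop_max = noop_max
            assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

        def reset(self):
            obs = self.env.reset()
            noops = np.random.randint(1, self.noop_max + 1)
            for _ in range(noops):
                obs, _, done, _ = self.env.step(0)  # action 0 is NOOP
                if done:
                    obs = self.env.reset()
            return obs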
baselines/ddpg/README.md (new file)
@@ -0,0 +1,5 @@
+# DDPG
+
+- Original paper: https://arxiv.org/abs/1509.02971
+- Baselines post: https://blog.openai.com/better-exploration-with-parameter-noise/
+- `python -m baselines.ddpg.main` runs the algorithm for 1M frames = 10M timesteps on a Mujoco environment. See help (`-h`) for more options.
@@ -1,19 +1,11 @@
 import argparse
 import time
 import os
-from tempfile import mkdtemp
-import sys
-import subprocess
-import threading
-import json
-
-from baselines.common.mpi_fork import mpi_fork
-from baselines import logger
-from baselines.logger import Logger
+import logging
+from baselines import logger, bench
 from baselines.common.misc_util import (
     set_global_seeds,
     boolean_flag,
-    SimpleMonitor
 )
 import baselines.ddpg.training as training
 from baselines.ddpg.models import Actor, Critic
@@ -24,42 +16,22 @@ import gym
 import tensorflow as tf
 from mpi4py import MPI
 
-def run(env_id, seed, noise_type, num_cpu, layer_norm, logdir, gym_monitor, evaluation, bind_to_core, **kwargs):
-    kwargs['logdir'] = logdir
-    whoami = mpi_fork(num_cpu, bind_to_core=bind_to_core)
-    if whoami == 'parent':
-        sys.exit(0)
-
+def run(env_id, seed, noise_type, layer_norm, evaluation, **kwargs):
     # Configure things.
     rank = MPI.COMM_WORLD.Get_rank()
-    if rank != 0:
-        # Write to temp directory for all non-master workers.
-        actual_dir = None
-        Logger.CURRENT.close()
-        Logger.CURRENT = Logger(dir=mkdtemp(), output_formats=[])
-        logger.set_level(logger.DISABLED)
-
-    # Create envs.
-    if rank == 0:
-        env = gym.make(env_id)
-        if gym_monitor and logdir:
-            env = gym.wrappers.Monitor(env, os.path.join(logdir, 'gym_train'), force=True)
-        env = SimpleMonitor(env)
-
-        if evaluation:
-            eval_env = gym.make(env_id)
-            if gym_monitor and logdir:
-                eval_env = gym.wrappers.Monitor(eval_env, os.path.join(logdir, 'gym_eval'), force=True)
-            eval_env = SimpleMonitor(eval_env)
-        else:
-            eval_env = None
-    else:
-        env = gym.make(env_id)
-        if evaluation:
-            eval_env = gym.make(env_id)
-        else:
-            eval_env = None
+    if rank != 0: logger.set_level(logger.DISABLED)
+
+    # Create envs.
+    env = gym.make(env_id)
+    env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), "%i.monitor.json"%rank))
+    gym.logger.setLevel(logging.WARN)
+
+    if evaluation and rank==0:
+        eval_env = gym.make(env_id)
+        eval_env = bench.Monitor(eval_env, os.path.join(logger.get_dir(), 'gym_eval'))
+        env = bench.Monitor(env, None)
+    else:
+        eval_env = None
 
     # Parse noise_type
     action_noise = None
@@ -103,22 +75,20 @@ def run(env_id, seed, noise_type, num_cpu, layer_norm, logdir, gym_monitor, eval
     env.close()
     if eval_env is not None:
         eval_env.close()
-    Logger.CURRENT.close()
     if rank == 0:
         logger.info('total runtime: {}s'.format(time.time() - start_time))
 
 
 def parse_args():
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     parser.add_argument('--env-id', type=str, default='HalfCheetah-v1')
     boolean_flag(parser, 'render-eval', default=False)
     boolean_flag(parser, 'layer-norm', default=True)
     boolean_flag(parser, 'render', default=False)
-    parser.add_argument('--num-cpu', type=int, default=1)
     boolean_flag(parser, 'normalize-returns', default=False)
     boolean_flag(parser, 'normalize-observations', default=True)
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
     parser.add_argument('--critic-l2-reg', type=float, default=1e-2)
     parser.add_argument('--batch-size', type=int, default=64) # per MPI worker
     parser.add_argument('--actor-lr', type=float, default=1e-4)
@@ -133,29 +103,11 @@ def parse_args():
     parser.add_argument('--nb-eval-steps', type=int, default=100) # per epoch cycle and MPI worker
     parser.add_argument('--nb-rollout-steps', type=int, default=100) # per epoch cycle and MPI worker
     parser.add_argument('--noise-type', type=str, default='adaptive-param_0.2') # choices are adaptive-param_xx, ou_xx, normal_xx, none
-    parser.add_argument('--logdir', type=str, default=None)
-    boolean_flag(parser, 'gym-monitor', default=False)
-    boolean_flag(parser, 'evaluation', default=True)
-    boolean_flag(parser, 'bind-to-core', default=False)
-
+    boolean_flag(parser, 'evaluation', default=False)
     return vars(parser.parse_args())
 
 
 if __name__ == '__main__':
     args = parse_args()
-
-    # Figure out what logdir to use.
-    if args['logdir'] is None:
-        args['logdir'] = os.getenv('OPENAI_LOGDIR')
-
-    # Print and save arguments.
-    logger.info('Arguments:')
-    for key in sorted(args.keys()):
-        logger.info('{}: {}'.format(key, args[key]))
-    logger.info('')
-    if args['logdir']:
-        with open(os.path.join(args['logdir'], 'args.json'), 'w') as f:
-            json.dump(args, f)
-
     # Run actual script.
     run(**args)
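With the slimmed-down entry point, a typical invocation is presumably just `python -m baselines.ddpg.main --env-id HalfCheetah-v1 --evaluation`: the log location now comes from the OPENAI_LOGDIR environment variable (see the logger change below) rather than a --logdir flag, and MPI workers are launched externally instead of via mpi_fork.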
@@ -14,7 +14,7 @@ from mpi4py import MPI


 def train(env, nb_epochs, nb_epoch_cycles, render_eval, reward_scale, render, param_noise, actor, critic,
-    normalize_returns, normalize_observations, critic_l2_reg, actor_lr, critic_lr, action_noise, logdir,
+    normalize_returns, normalize_observations, critic_l2_reg, actor_lr, critic_lr, action_noise,
     popart, gamma, clip_norm, nb_train_steps, nb_rollout_steps, nb_eval_steps, batch_size, memory,
     tau=0.01, eval_env=None, param_noise_adaption_interval=50):
     rank = MPI.COMM_WORLD.Get_rank()
@@ -178,7 +178,7 @@ def train(env, nb_epochs, nb_epoch_cycles, render_eval, reward_scale, render, pa
            logger.record_tabular(key, combined_stats[key])
        logger.dump_tabular()
        logger.info('')
+        logdir = logger.get_dir()
        if rank == 0 and logdir:
            if hasattr(env, 'get_state'):
                with open(os.path.join(logdir, 'env_state.pkl'), 'wb') as f:
@@ -40,7 +40,7 @@ class HumanOutputFormat(OutputFormat):
     def writekvs(self, kvs):
         # Create strings for printing
         key2str = {}
-        for (key, val) in kvs.items():
+        for (key, val) in sorted(kvs.items()):
             if isinstance(val, float):
                 valstr = '%-8.3g' % (val,)
             else:
@@ -81,7 +81,7 @@ class JSONOutputFormat(OutputFormat):
         self.file = file

     def writekvs(self, kvs):
-        for k, v in kvs.items():
+        for k, v in sorted(kvs.items()):
             if hasattr(v, 'dtype'):
                 v = v.tolist()
                 kvs[k] = float(v)
@@ -274,11 +274,16 @@ def configure(dir=None, format_strs=None):
     Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
     log('Logging to %s'%dir)

+if os.getenv('OPENAI_LOGDIR'):
+    # if OPENAI_LOGDIR is set, configure the logger on import
+    # this kind of nasty (unexpected to user), but I don't know how else to inject the logger
+    # to a script that's getting run in a subprocess
+    configure(dir=os.getenv('OPENAI_LOGDIR'))
+
 def reset():
     Logger.CURRENT = Logger.DEFAULT
     log('Reset logger')


 # ================================================================

 def _demo():
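A hedged sketch of how the configure-on-import hook is meant to be used: the parent process exports the variable before workers are spawned, and every subprocess that imports baselines.logger then logs to the same directory (the path below is made up):

    import os

    # assumed example path; must be set before baselines.logger is first imported
    os.environ['OPENAI_LOGDIR'] = '/tmp/ddpg_halfcheetah'

    from baselines import logger  # configure() now runs at import time
    print(logger.get_dir())       # -> /tmp/ddpg_halfcheetah
    logger.info('this record goes to the shared log directory')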
baselines/ppo1/README.md (new file)
@@ -0,0 +1,7 @@
+# PPOSGD
+
+- Original paper: https://arxiv.org/abs/1707.06347
+- Baselines blog post: https://blog.openai.com/openai-baselines-ppo/
+- `python -m baselines.ppo.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options.
+- `python -m baselines.ppo.run_mujoco` runs the algorithm for 1M frames on a Mujoco environment.
+
@@ -84,6 +84,7 @@ def learn(env, policy_func, *,
         gamma, lam, # advantage estimation
         max_timesteps=0, max_episodes=0, max_iters=0, max_seconds=0, # time constraint
         callback=None, # you can do anything in the callback, since it takes locals(), globals()
+        adam_epsilon=1e-5,
         schedule='constant' # annealing for stepsize parameters (epsilon and adam)
         ):
     # Setup losses and stuff
@@ -111,17 +112,14 @@ def learn(env, policy_func, *,
     surr1 = ratio * atarg # surrogate from conservative policy iteration
     surr2 = U.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * atarg #
     pol_surr = - U.mean(tf.minimum(surr1, surr2)) # PPO's pessimistic surrogate (L^CLIP)
-    vfloss1 = tf.square(pi.vpred - ret)
-    vpredclipped = oldpi.vpred + tf.clip_by_value(pi.vpred - oldpi.vpred, -clip_param, clip_param)
-    vfloss2 = tf.square(vpredclipped - ret)
-    vf_loss = .5 * U.mean(tf.maximum(vfloss1, vfloss2)) # we do the same clipping-based trust region for the value function
+    vf_loss = U.mean(tf.square(pi.vpred - ret))
     total_loss = pol_surr + pol_entpen + vf_loss
     losses = [pol_surr, pol_entpen, vf_loss, meankl, meanent]
     loss_names = ["pol_surr", "pol_entpen", "vf_loss", "kl", "ent"]

     var_list = pi.get_trainable_variables()
     lossandgrad = U.function([ob, ac, atarg, ret, lrmult], losses + [U.flatgrad(total_loss, var_list)])
-    adam = MpiAdam(var_list)
+    adam = MpiAdam(var_list, epsilon=adam_epsilon)

     assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv)
         for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())])
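On the adam_epsilon=1e-5 default: epsilon sits in the Adam update denominator, so a larger value mainly damps the step for parameters whose gradient magnitudes are tiny, which keeps noise-driven directions from taking near-maximal steps. A rough illustration in plain Python (made-up magnitudes, not the MpiAdam code):

    def adam_step_size(grad_rms, lr=3e-4, epsilon=1e-8):
        # at steady state the Adam step is roughly lr * m_hat / (sqrt(v_hat) + epsilon);
        # for a consistently-signed gradient, m_hat and sqrt(v_hat) are both ~grad_rms
        return lr * grad_rms / (grad_rms + epsilon)

    for eps in (1e-8, 1e-5):
        #            "normal" gradient              near-zero gradient
        print(eps, adam_step_size(1e-2, epsilon=eps), adam_step_size(1e-6, epsilon=eps))
    # with epsilon=1e-8 both steps are ~3e-4; with epsilon=1e-5 the near-zero-gradient
    # parameter's step shrinks by roughly 10x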
@@ -1,12 +1,11 @@
 #!/usr/bin/env python

 from mpi4py import MPI
 from baselines.common import set_global_seeds
 from baselines import bench
-from baselines.common.mpi_fork import mpi_fork
 import os.path as osp
 import gym, logging
 from baselines import logger
-import sys

 def wrap_train(env):
     from baselines.common.atari_wrappers import (wrap_deepmind, FrameStack)
@@ -14,11 +13,9 @@ def wrap_train(env):
     env = FrameStack(env, 4)
     return env

-def train(env_id, num_timesteps, seed, num_cpu):
-    from baselines.pposgd import pposgd_simple, cnn_policy
+def train(env_id, num_frames, seed):
+    from baselines.ppo1 import pposgd_simple, cnn_policy
     import baselines.common.tf_util as U
-    whoami = mpi_fork(num_cpu)
-    if whoami == "parent": return
     rank = MPI.COMM_WORLD.Get_rank()
     sess = U.single_threaded_session()
     sess.__enter__()
@@ -28,12 +25,13 @@ def train(env_id, num_timesteps, seed, num_cpu):
     env = gym.make(env_id)
     def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
         return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)
-    env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json" % rank))
+    env = bench.Monitor(env, logger.get_dir() and
+        osp.join(logger.get_dir(), "%i.monitor.json" % rank))
     env.seed(workerseed)
     gym.logger.setLevel(logging.WARN)

     env = wrap_train(env)
-    num_timesteps /= 4 # because we're wrapping the envs to do frame skip
+    num_timesteps = int(num_frames / 4 * 1.1)
     env.seed(workerseed)

     pposgd_simple.learn(env, policy_fn,
@@ -47,7 +45,12 @@ def train(env_id, num_timesteps, seed, num_cpu):
     env.close()

 def main():
-    train('PongNoFrameskip-v4', num_timesteps=40e6, seed=0, num_cpu=8)
+    import argparse
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--env', help='environment ID', default='PongNoFrameskip-v4')
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    args = parser.parse_args()
+    train(args.env, num_frames=40e6, seed=args.seed)

 if __name__ == '__main__':
     main()
@@ -7,14 +7,15 @@ from baselines import logger
 import sys

 def train(env_id, num_timesteps, seed):
-    from baselines.pposgd import mlp_policy, pposgd_simple
+    from baselines.ppo1 import mlp_policy, pposgd_simple
     U.make_session(num_cpu=1).__enter__()
     set_global_seeds(seed)
     env = gym.make(env_id)
     def policy_fn(name, ob_space, ac_space):
         return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
             hid_size=64, num_hid_layers=2)
-    env = bench.Monitor(env, osp.join(logger.get_dir(), "monitor.json"))
+    env = bench.Monitor(env, logger.get_dir() and
+        osp.join(logger.get_dir(), "monitor.json"))
     env.seed(seed)
     gym.logger.setLevel(logging.WARN)
     pposgd_simple.learn(env, policy_fn,
@@ -22,12 +23,17 @@ def train(env_id, num_timesteps, seed):
             timesteps_per_batch=2048,
             clip_param=0.2, entcoeff=0.0,
             optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
-            gamma=0.99, lam=0.95,
+            gamma=0.99, lam=0.95, schedule='linear',
         )
     env.close()

 def main():
-    train('Hopper-v1', num_timesteps=1e6, seed=0)
+    import argparse
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--env', help='environment ID', default='Hopper-v1')
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    args = parser.parse_args()
+    train(args.env, num_timesteps=1e6, seed=args.seed)
+

 if __name__ == '__main__':
baselines/trpo_mpi/README.md (new file)
@@ -0,0 +1,6 @@
+# trpo_mpi
+
+- Original paper: https://arxiv.org/abs/1502.05477
+- Baselines blog post https://blog.openai.com/openai-baselines-ppo/
+- `python -m baselines.ppo1.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options.
+- `python -m baselines.ppo1.run_mujoco` runs the algorithm for 1M timesteps on a Mujoco environment.
@@ -5,7 +5,6 @@ import os.path as osp
 import gym, logging
 from baselines import logger
 from baselines import bench
-from baselines.common.mpi_fork import mpi_fork
 import sys

 def wrap_train(env):
@@ -14,13 +13,10 @@ def wrap_train(env):
     env = FrameStack(env, 3)
     return env

-def train(env_id, num_timesteps, seed, num_cpu):
+def train(env_id, num_frames, seed):
     from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
     from baselines.trpo_mpi import trpo_mpi
     import baselines.common.tf_util as U
-    whoami = mpi_fork(num_cpu)
-    if whoami == "parent":
-        return
     rank = MPI.COMM_WORLD.Get_rank()
     sess = U.single_threaded_session()
     sess.__enter__()
@@ -33,12 +29,13 @@ def train(env_id, num_timesteps, seed, num_cpu):
     env = gym.make(env_id)
     def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
         return CnnPolicy(name=name, ob_space=env.observation_space, ac_space=env.action_space)
-    env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json"%rank))
+    env = bench.Monitor(env, logger.get_dir() and
+        osp.join(logger.get_dir(), "%i.monitor.json"%rank))
     env.seed(workerseed)
     gym.logger.setLevel(logging.WARN)

     env = wrap_train(env)
-    num_timesteps /= 4 # because we're wrapping the envs to do frame skip
+    num_timesteps = int(num_frames / 4 * 1.1)
     env.seed(workerseed)

     trpo_mpi.learn(env, policy_fn, timesteps_per_batch=512, max_kl=0.001, cg_iters=10, cg_damping=1e-3,
@@ -46,7 +43,13 @@ def train(env_id, num_timesteps, seed, num_cpu):
     env.close()

 def main():
-    train('PongNoFrameskip-v4', num_timesteps=40e6, seed=0, num_cpu=8)
+    import argparse
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--env', help='environment ID', default='PongNoFrameskip-v4')
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    args = parser.parse_args()
+    train(args.env, num_frames=40e6, seed=args.seed)
+

 if __name__ == "__main__":
     main()
@@ -7,17 +7,13 @@ import os.path as osp
 import gym
 import logging
 from baselines import logger
-from baselines.pposgd.mlp_policy import MlpPolicy
+from baselines.ppo1.mlp_policy import MlpPolicy
 from baselines.common.mpi_fork import mpi_fork
 from baselines import bench
 from baselines.trpo_mpi import trpo_mpi
 import sys
-num_cpu=1

 def train(env_id, num_timesteps, seed):
-    whoami = mpi_fork(num_cpu)
-    if whoami == "parent":
-        return
     import baselines.common.tf_util as U
     sess = U.single_threaded_session()
     sess.__enter__()
@@ -31,7 +27,8 @@ def train(env_id, num_timesteps, seed):
     def policy_fn(name, ob_space, ac_space):
         return MlpPolicy(name=name, ob_space=env.observation_space, ac_space=env.action_space,
             hid_size=32, num_hid_layers=2)
-    env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json" % rank))
+    env = bench.Monitor(env, logger.get_dir() and
+        osp.join(logger.get_dir(), "%i.monitor.json" % rank))
     env.seed(workerseed)
     gym.logger.setLevel(logging.WARN)

@@ -40,7 +37,13 @@ def train(env_id, num_timesteps, seed):
     env.close()

 def main():
-    train('Hopper-v1', num_timesteps=1e6, seed=0)
+    import argparse
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--env', help='environment ID', default='Hopper-v1')
+    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
+    args = parser.parse_args()
+    train(args.env, num_timesteps=1e6, seed=args.seed)
+

 if __name__ == '__main__':
     main()