Merge branch 'master' of github.com:openai/baselines into internal
This commit is contained in:
17
baselines/ddpg/test_smoke.py
Normal file
17
baselines/ddpg/test_smoke.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from baselines.run import main as M
|
||||||
|
|
||||||
|
def _run(argstr):
    """Call the baselines entry point ``M`` with a DDPG smoke-test command line.

    The fixed prefix selects ``--alg=ddpg`` on ``--env=Pendulum-v0`` with
    ``--num_timesteps=0``; ``argstr`` carries any extra space-separated
    ``--key=value`` flags to append.

    Args:
        argstr: Additional command-line flags as a single string. May be
            empty or contain irregular spacing.
    """
    command_line = '--alg=ddpg --env=Pendulum-v0 --num_timesteps=0 ' + argstr
    # Split on arbitrary whitespace runs: split(' ') would yield empty-string
    # tokens for an empty argstr or for doubled/trailing spaces, and those
    # would be handed to the argument parser as bogus arguments.
    M(command_line.split())
|
||||||
|
|
||||||
|
def test_popart():
    """Smoke-run DDPG with the return-normalization and popart flags set."""
    extra_flags = '--normalize_returns=True --popart=True'
    _run(extra_flags)
|
||||||
|
|
||||||
|
def test_noise_normal():
    """Smoke-run DDPG with ``--noise_type=normal_0.1``."""
    extra_flags = '--noise_type=normal_0.1'
    _run(extra_flags)
|
||||||
|
|
||||||
|
def test_noise_ou():
    """Smoke-run DDPG with ``--noise_type=ou_0.1``."""
    extra_flags = '--noise_type=ou_0.1'
    _run(extra_flags)
|
||||||
|
|
||||||
|
def test_noise_adaptive():
    """Smoke-run DDPG with a combined ``adaptive-param`` + ``normal`` noise spec."""
    extra_flags = '--noise_type=adaptive-param_0.2,normal_0.1'
    _run(extra_flags)
|
||||||
|
|
@@ -367,8 +367,6 @@ class DDPG(object):
|
|||||||
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
||||||
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
||||||
|
|
||||||
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
|
||||||
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
|
||||||
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
|
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
|
||||||
pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
|
pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
|
||||||
assert len(self._vars('main/Q')) == len(Q_grads_tf)
|
assert len(self._vars('main/Q')) == len(Q_grads_tf)
|
||||||
|
@@ -181,11 +181,11 @@ def parse_cmdline_kwargs(args):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(args):
|
||||||
# configure logger, disable logging in child MPI processes (with rank > 0)
|
# configure logger, disable logging in child MPI processes (with rank > 0)
|
||||||
|
|
||||||
arg_parser = common_arg_parser()
|
arg_parser = common_arg_parser()
|
||||||
args, unknown_args = arg_parser.parse_known_args()
|
args, unknown_args = arg_parser.parse_known_args(args)
|
||||||
extra_args = parse_cmdline_kwargs(unknown_args)
|
extra_args = parse_cmdline_kwargs(unknown_args)
|
||||||
|
|
||||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||||
@@ -220,5 +220,7 @@ def main():
|
|||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main(sys.argv)
|
||||||
|
Reference in New Issue
Block a user