Files
baselines/baselines/ddpg/models.py
Matthias Plappert 882251878f Parameter space noise for DQN and DDPG (#75)
* Export param noise

* Update documentation

* Final finishing touches
2017-07-27 08:10:59 -07:00

78 lines
2.4 KiB
Python

import tensorflow as tf
import tensorflow.contrib as tc
class Model(object):
def __init__(self, name):
self.name = name
@property
def vars(self):
return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
@property
def trainable_vars(self):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
@property
def perturbable_vars(self):
return [var for var in self.trainable_vars if 'LayerNorm' not in var.name]
class Actor(Model):
def __init__(self, nb_actions, name='actor', layer_norm=True):
super(Actor, self).__init__(name=name)
self.nb_actions = nb_actions
self.layer_norm = layer_norm
def __call__(self, obs, reuse=False):
with tf.variable_scope(self.name) as scope:
if reuse:
scope.reuse_variables()
x = obs
x = tf.layers.dense(x, 64)
if self.layer_norm:
x = tc.layers.layer_norm(x, center=True, scale=True)
x = tf.nn.relu(x)
x = tf.layers.dense(x, 64)
if self.layer_norm:
x = tc.layers.layer_norm(x, center=True, scale=True)
x = tf.nn.relu(x)
x = tf.layers.dense(x, self.nb_actions, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
x = tf.nn.tanh(x)
return x
class Critic(Model):
def __init__(self, name='critic', layer_norm=True):
super(Critic, self).__init__(name=name)
self.layer_norm = layer_norm
def __call__(self, obs, action, reuse=False):
with tf.variable_scope(self.name) as scope:
if reuse:
scope.reuse_variables()
x = obs
x = tf.layers.dense(x, 64)
if self.layer_norm:
x = tc.layers.layer_norm(x, center=True, scale=True)
x = tf.nn.relu(x)
x = tf.concat([x, action], axis=-1)
x = tf.layers.dense(x, 64)
if self.layer_norm:
x = tc.layers.layer_norm(x, center=True, scale=True)
x = tf.nn.relu(x)
x = tf.layers.dense(x, 1, kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3))
return x
@property
def output_vars(self):
output_vars = [var for var in self.trainable_vars if 'output' in var.name]
return output_vars