diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index b8b2a87..0d1df89 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -52,7 +52,7 @@ def argmax(x, axis=None): def switch(condition, then_expression, else_expression): - '''Switches between two operations depending on a scalar value (int or bool). + """Switches between two operations depending on a scalar value (int or bool). Note that both `then_expression` and `else_expression` should be symbolic tensors of the *same shape*. @@ -60,7 +60,7 @@ def switch(condition, then_expression, else_expression): condition: scalar tensor. then_expression: TensorFlow operation. else_expression: TensorFlow operation. - ''' + """ x_shape = copy.copy(then_expression.get_shape()) x = tf.cond(tf.cast(condition, 'bool'), lambda: then_expression, @@ -362,9 +362,9 @@ def dropout(x, pkeep, phase=None, mask=None): def function(inputs, outputs, updates=None, givens=None): - """Just like Theano function. Take a bunch of tensorflow placeholders and expersions + """Just like Theano function. Take a bunch of tensorflow placeholders and expressions computed based on those placeholders and produces f(inputs) -> outputs. Function f takes - values to be feed to the inputs placeholders and produces the values of the experessions + values to be fed to the input's placeholders and produces the values of the expressions in outputs. Input values can be passed in the same order as inputs or can be provided as kwargs based