diff --git a/baselines/common/models.py b/baselines/common/models.py index 0763095..8088e7d 100644 --- a/baselines/common/models.py +++ b/baselines/common/models.py @@ -138,7 +138,7 @@ def conv_only(convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], **conv_kwargs): ''' def network_fn(X): - out = X + X = tf.cast(X, tf.float32) / 255. with tf.variable_scope("convnet"): for num_outputs, kernel_size, stride in convs: out = layers.convolution2d(out,