diff --git a/baselines/acer/policies.py b/baselines/acer/policies.py index 01ace6d..627c400 100644 --- a/baselines/acer/policies.py +++ b/baselines/acer/policies.py @@ -59,9 +59,9 @@ class AcerLstmPolicy(object): h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) h5 = seq_to_batch(h5) - pi_logits = fc(h5, 'pi', nact, act=lambda x: x, init_scale=0.01) + pi_logits = fc(h5, 'pi', nact, init_scale=0.01) pi = tf.nn.softmax(pi_logits) - q = fc(h5, 'q', nact, act=lambda x: x) + q = fc(h5, 'q', nact) a = sample(pi_logits) # could change this to use self.pi instead self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)