diff --git a/baselines/acer/acer.py b/baselines/acer/acer.py index 0ae0330..df4e0bf 100644 --- a/baselines/acer/acer.py +++ b/baselines/acer/acer.py @@ -75,8 +75,8 @@ class Model(object): train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape) with tf.variable_scope('acer_model', reuse=tf.AUTO_REUSE): - step_model = policy(observ_placeholder=step_ob_placeholder, sess=sess) - train_model = policy(observ_placeholder=train_ob_placeholder, sess=sess) + step_model = policy(nbatch=nenvs, nsteps=1, observ_placeholder=step_ob_placeholder, sess=sess) + train_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess) params = find_trainable_variables("acer_model") @@ -94,7 +94,7 @@ class Model(object): return v with tf.variable_scope("acer_model", custom_getter=custom_getter, reuse=True): - polyak_model = policy(observ_placeholder=train_ob_placeholder, sess=sess) + polyak_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess) # Notation: (var) = batch variable, (var)s = seqeuence variable, (var)_i = variable index by action at step i