diff --git a/baselines/ppo2/microbatched_model.py b/baselines/ppo2/microbatched_model.py
index 28c05aa..eb71b63 100644
--- a/baselines/ppo2/microbatched_model.py
+++ b/baselines/ppo2/microbatched_model.py
@@ -47,7 +47,6 @@ class MicrobatchedModel(Model):
 
         # Initialize empty list for per-microbatch stats like pg_loss, vf_loss, entropy, approxkl (whatever is in self.stats_list)
         stats_vs = []
 
-        optional_td_map = self.train_model.feed_dict(**_kwargs)
         for microbatch_idx in range(self.nmicrobatches):
             _sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx + 1) * self.microbatch_size)
@@ -63,11 +62,8 @@ class MicrobatchedModel(Model):
                 self.VALUE_PREV: values[_sli],
             }
 
-            for key in optional_td_map:
-                if key.get_shape().ndims == 1:
-                    td_map.update({key: optional_td_map[key][_sli]})
-                else:
-                    raise NotImplementedError
+            sliced_kwargs = {key: _kwargs[key][_sli] for key in _kwargs}
+            td_map.update(self.train_model.feed_dict(**sliced_kwargs))
 
             # Compute gradient on a microbatch (note that variables do not change here) ...
             grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)
diff --git a/baselines/ppo2/policies.py b/baselines/ppo2/policies.py
index b610dc0..fb685d1 100644
--- a/baselines/ppo2/policies.py
+++ b/baselines/ppo2/policies.py
@@ -8,6 +8,7 @@ from baselines.common import tf_util
 from baselines.common.distributions import make_pdtype
 from baselines.common.input import observation_placeholder, encode_observation
 from baselines.common.models import get_network_builder
+from baselines.common.tf_util import adjust_shape
 
 
 class PolicyWithValue(object):
@@ -83,7 +84,7 @@ class PolicyWithValue(object):
         feed_dict = {}
         for key in kwargs:
             if key in self.step_input:
-                feed_dict[self.step_input[key]] = kwargs[key]
+                feed_dict[self.step_input[key]] = adjust_shape(self.step_input[key], kwargs[key])
         return feed_dict
 
     def step(self, **kwargs):
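
Notes on the change (not part of the patch): the old code built a single feed dict from the full batch and could only slice rank-1 placeholders, raising NotImplementedError for anything else; the new code slices every extra kwarg per microbatch and lets train_model.feed_dict map the sliced values to placeholders, so extra inputs of any rank pass through. Below is a minimal, self-contained sketch of that slicing pattern using numpy stand-ins; nmicrobatches, microbatch_size, and extra_kwargs are illustrative names, not identifiers from the repo.

    import numpy as np

    nmicrobatches, microbatch_size = 4, 8

    # Stand-ins for the extra training inputs passed through **_kwargs;
    # note the rank-2 entry, which the old per-key slicing could not handle.
    extra_kwargs = {
        "goal": np.zeros((nmicrobatches * microbatch_size, 3)),
        "mask": np.zeros(nmicrobatches * microbatch_size),
    }

    for microbatch_idx in range(nmicrobatches):
        _sli = range(microbatch_idx * microbatch_size, (microbatch_idx + 1) * microbatch_size)
        # Slice every extra kwarg exactly like the core tensors, then hand the
        # sliced dict to the policy's feed_dict helper (omitted in this sketch).
        sliced_kwargs = {key: extra_kwargs[key][_sli] for key in extra_kwargs}
        assert all(v.shape[0] == microbatch_size for v in sliced_kwargs.values())

On the policies.py side, adjust_shape from baselines.common.tf_util reshapes each fed value to the placeholder's shape where possible, which keeps the per-microbatch slices compatible with placeholders declared with a flexible leading batch dimension.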