adjust input shape.

This commit is contained in:
gyunt
2019-03-22 23:00:11 +09:00
parent 93c3f32a76
commit b766b6413e
2 changed files with 4 additions and 7 deletions

View File

@@ -47,7 +47,6 @@ class MicrobatchedModel(Model):
# Initialize empty list for per-microbatch stats like pg_loss, vf_loss, entropy, approxkl (whatever is in self.stats_list)
stats_vs = []
optional_td_map = self.train_model.feed_dict(**_kwargs)
for microbatch_idx in range(self.nmicrobatches):
_sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx + 1) * self.microbatch_size)
@@ -63,11 +62,8 @@ class MicrobatchedModel(Model):
self.VALUE_PREV: values[_sli],
}
for key in optional_td_map:
if key.get_shape().ndims == 1:
td_map.update({key: optional_td_map[key][_sli]})
else:
raise NotImplementedError
sliced_kwargs = {key: _kwargs[key][_sli] for key in _kwargs}
td_map.update(self.train_model.feed_dict(**sliced_kwargs))
# Compute gradient on a microbatch (note that variables do not change here) ...
grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)

View File

@@ -8,6 +8,7 @@ from baselines.common import tf_util
from baselines.common.distributions import make_pdtype
from baselines.common.input import observation_placeholder, encode_observation
from baselines.common.models import get_network_builder
from baselines.common.tf_util import adjust_shape
class PolicyWithValue(object):
@@ -83,7 +84,7 @@ class PolicyWithValue(object):
feed_dict = {}
for key in kwargs:
if key in self.step_input:
feed_dict[self.step_input[key]] = kwargs[key]
feed_dict[self.step_input[key]] = adjust_shape(self.step_input[key], kwargs[key])
return feed_dict
def step(self, **kwargs):