adjust input shape.
This commit is contained in:
@@ -47,7 +47,6 @@ class MicrobatchedModel(Model):
|
||||
|
||||
# Initialize empty list for per-microbatch stats like pg_loss, vf_loss, entropy, approxkl (whatever is in self.stats_list)
|
||||
stats_vs = []
|
||||
optional_td_map = self.train_model.feed_dict(**_kwargs)
|
||||
|
||||
for microbatch_idx in range(self.nmicrobatches):
|
||||
_sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx + 1) * self.microbatch_size)
|
||||
@@ -63,11 +62,8 @@ class MicrobatchedModel(Model):
|
||||
self.VALUE_PREV: values[_sli],
|
||||
}
|
||||
|
||||
for key in optional_td_map:
|
||||
if key.get_shape().ndims == 1:
|
||||
td_map.update({key: optional_td_map[key][_sli]})
|
||||
else:
|
||||
raise NotImplementedError
|
||||
sliced_kwargs = {key: _kwargs[key][_sli] for key in _kwargs}
|
||||
td_map.update(self.train_model.feed_dict(**sliced_kwargs))
|
||||
|
||||
# Compute gradient on a microbatch (note that variables do not change here) ...
|
||||
grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)
|
||||
|
@@ -8,6 +8,7 @@ from baselines.common import tf_util
|
||||
from baselines.common.distributions import make_pdtype
|
||||
from baselines.common.input import observation_placeholder, encode_observation
|
||||
from baselines.common.models import get_network_builder
|
||||
from baselines.common.tf_util import adjust_shape
|
||||
|
||||
|
||||
class PolicyWithValue(object):
|
||||
@@ -83,7 +84,7 @@ class PolicyWithValue(object):
|
||||
feed_dict = {}
|
||||
for key in kwargs:
|
||||
if key in self.step_input:
|
||||
feed_dict[self.step_input[key]] = kwargs[key]
|
||||
feed_dict[self.step_input[key]] = adjust_shape(self.step_input[key], kwargs[key])
|
||||
return feed_dict
|
||||
|
||||
def step(self, **kwargs):
|
||||
|
Reference in New Issue
Block a user