adjust input shape.

This commit is contained in:
gyunt
2019-03-22 23:00:11 +09:00
parent 93c3f32a76
commit b766b6413e
2 changed files with 4 additions and 7 deletions

View File

@@ -47,7 +47,6 @@ class MicrobatchedModel(Model):
            # Initialize empty list for per-microbatch stats like pg_loss, vf_loss, entropy, approxkl (whatever is in self.stats_list)
            stats_vs = []
-           optional_td_map = self.train_model.feed_dict(**_kwargs)
            for microbatch_idx in range(self.nmicrobatches):
                _sli = range(microbatch_idx * self.microbatch_size, (microbatch_idx + 1) * self.microbatch_size)
@@ -63,11 +62,8 @@ class MicrobatchedModel(Model):
                    self.VALUE_PREV: values[_sli],
                }
-               for key in optional_td_map:
-                   if key.get_shape().ndims == 1:
-                       td_map.update({key: optional_td_map[key][_sli]})
-                   else:
-                       raise NotImplementedError
+               sliced_kwargs = {key: _kwargs[key][_sli] for key in _kwargs}
+               td_map.update(self.train_model.feed_dict(**sliced_kwargs))
                # Compute gradient on a microbatch (note that variables do not change here) ...
                grad_v, stats_v = self.sess.run([self.grads, self.stats_list], td_map)

View File

@@ -8,6 +8,7 @@ from baselines.common import tf_util
 from baselines.common.distributions import make_pdtype
 from baselines.common.input import observation_placeholder, encode_observation
 from baselines.common.models import get_network_builder
+from baselines.common.tf_util import adjust_shape

 class PolicyWithValue(object):
@@ -83,7 +84,7 @@ class PolicyWithValue(object):
        feed_dict = {}
        for key in kwargs:
            if key in self.step_input:
-               feed_dict[self.step_input[key]] = kwargs[key]
+               feed_dict[self.step_input[key]] = adjust_shape(self.step_input[key], kwargs[key])
        return feed_dict

    def step(self, **kwargs):