Improvements for the 'play' option.
This commit is contained in:
@@ -129,8 +129,13 @@ class Model(object):
|
|||||||
if MPI is not None:
|
if MPI is not None:
|
||||||
sync_from_root(sess, global_variables) # pylint: disable=E1101
|
sync_from_root(sess, global_variables) # pylint: disable=E1101
|
||||||
|
|
||||||
def step(self, *args, **kwargs):
|
def step_as_dict(self, **kwargs):
|
||||||
return self.act_model.step(*args, **kwargs)
|
return self.act_model.step(**kwargs)
|
||||||
|
|
||||||
|
def step(self, observations, **kwargs):
|
||||||
|
kwargs.update({'observations': observations})
|
||||||
|
transition = self.act_model.step(**kwargs)
|
||||||
|
return transition['actions'], transition['values'], transition['states'], transition['neglogpacs']
|
||||||
|
|
||||||
def train(self,
|
def train(self,
|
||||||
lr,
|
lr,
|
||||||
|
@@ -54,7 +54,7 @@ class Runner(AbstractEnvRunner):
|
|||||||
transition = {}
|
transition = {}
|
||||||
transition['obs'] = self.obs.copy()
|
transition['obs'] = self.obs.copy()
|
||||||
transition['dones'] = dones
|
transition['dones'] = dones
|
||||||
transition.update(self.model.step(observations=self.obs, **prev_state))
|
transition.update(self.model.step_as_dict(observations=self.obs, **prev_state))
|
||||||
transition['values'] = transition['values']
|
transition['values'] = transition['values']
|
||||||
|
|
||||||
# Take actions in env and look at the results
|
# Take actions in env and look at the results
|
||||||
@@ -84,7 +84,7 @@ class Runner(AbstractEnvRunner):
|
|||||||
dtype = data_type[key] if key in data_type else np.float
|
dtype = data_type[key] if key in data_type else np.float
|
||||||
minibatch[key] = np.asarray(minibatch[key], dtype=dtype)
|
minibatch[key] = np.asarray(minibatch[key], dtype=dtype)
|
||||||
|
|
||||||
last_values = self.model.step(observations=self.obs, **self.states)['values']
|
last_values = self.model.step_as_dict(observations=self.obs, **self.states)['values']
|
||||||
|
|
||||||
# Calculate returns and advantages.
|
# Calculate returns and advantages.
|
||||||
minibatch['advs'], minibatch['returns'] = \
|
minibatch['advs'], minibatch['returns'] = \
|
||||||
|
Reference in New Issue
Block a user