replace ppo2.step() with original interface.

This commit is contained in:
gyunt
2019-04-08 22:40:08 +09:00
parent 93232a24e1
commit 354b1bda41
2 changed files with 7 additions and 5 deletions

View File

@@ -140,11 +140,13 @@ class Model(object):
def step_as_dict(self, **kwargs):
return self.act_model.step(**kwargs)
def step(self, observation, done, **kwargs):
kwargs.update({'observations': observation})
kwargs.update({'dones': done})
def step(self, obs, M=None, S=None, **kwargs):
kwargs.update({'observations': obs})
if M is not None and S is not None:
kwargs.update({'dones': M})
kwargs.update({'states': S})
transition = self.act_model.step(**kwargs)
states = transition['states'] if 'states' in transition else None
states = transition['next_states'] if 'next_states' in transition else None
return transition['actions'], transition['values'], states, transition['neglogpacs']
def train(self,

View File

@@ -39,7 +39,7 @@ class Runner(AbstractEnvRunner):
"neglogpacs": np.float32,
}
prev_transition = {'next_states': self.model.initial_state}
prev_transition = {'next_states': self.model.initial_state} if self.model.initial_state is not None else {}
epinfos = []
# For n in range number of steps