replace ppo2.step()
with original interface.
This commit is contained in:
@@ -140,11 +140,13 @@ class Model(object):
|
||||
def step_as_dict(self, **kwargs):
|
||||
return self.act_model.step(**kwargs)
|
||||
|
||||
def step(self, observation, done, **kwargs):
|
||||
kwargs.update({'observations': observation})
|
||||
kwargs.update({'dones': done})
|
||||
def step(self, obs, M=None, S=None, **kwargs):
|
||||
kwargs.update({'observations': obs})
|
||||
if M is not None and S is not None:
|
||||
kwargs.update({'dones': M})
|
||||
kwargs.update({'states': S})
|
||||
transition = self.act_model.step(**kwargs)
|
||||
states = transition['states'] if 'states' in transition else None
|
||||
states = transition['next_states'] if 'next_states' in transition else None
|
||||
return transition['actions'], transition['values'], states, transition['neglogpacs']
|
||||
|
||||
def train(self,
|
||||
|
@@ -39,7 +39,7 @@ class Runner(AbstractEnvRunner):
|
||||
"neglogpacs": np.float32,
|
||||
}
|
||||
|
||||
prev_transition = {'next_states': self.model.initial_state}
|
||||
prev_transition = {'next_states': self.model.initial_state} if self.model.initial_state is not None else {}
|
||||
epinfos = []
|
||||
|
||||
# For n in range number of steps
|
||||
|
Reference in New Issue
Block a user