Replace ppo2.step() with the original interface.
This commit is contained in:
@@ -140,11 +140,13 @@ class Model(object):
|
|||||||
def step_as_dict(self, **kwargs):
|
def step_as_dict(self, **kwargs):
|
||||||
return self.act_model.step(**kwargs)
|
return self.act_model.step(**kwargs)
|
||||||
|
|
||||||
def step(self, observation, done, **kwargs):
|
def step(self, obs, M=None, S=None, **kwargs):
|
||||||
kwargs.update({'observations': observation})
|
kwargs.update({'observations': obs})
|
||||||
kwargs.update({'dones': done})
|
if M is not None and S is not None:
|
||||||
|
kwargs.update({'dones': M})
|
||||||
|
kwargs.update({'states': S})
|
||||||
transition = self.act_model.step(**kwargs)
|
transition = self.act_model.step(**kwargs)
|
||||||
states = transition['states'] if 'states' in transition else None
|
states = transition['next_states'] if 'next_states' in transition else None
|
||||||
return transition['actions'], transition['values'], states, transition['neglogpacs']
|
return transition['actions'], transition['values'], states, transition['neglogpacs']
|
||||||
|
|
||||||
def train(self,
|
def train(self,
|
||||||
|
@@ -39,7 +39,7 @@ class Runner(AbstractEnvRunner):
|
|||||||
"neglogpacs": np.float32,
|
"neglogpacs": np.float32,
|
||||||
}
|
}
|
||||||
|
|
||||||
prev_transition = {'next_states': self.model.initial_state}
|
prev_transition = {'next_states': self.model.initial_state} if self.model.initial_state is not None else {}
|
||||||
epinfos = []
|
epinfos = []
|
||||||
|
|
||||||
# For n in range number of steps
|
# For n in range number of steps
|
||||||
|
Reference in New Issue
Block a user