Revert baselines/run.py.
This commit is contained in:
@@ -6,7 +6,6 @@ import gym
|
||||
from collections import defaultdict
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import inspect
|
||||
|
||||
from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
|
||||
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
|
||||
@@ -219,20 +218,15 @@ def main(args):
|
||||
obs = env.reset()
|
||||
|
||||
state = model.initial_state if hasattr(model, 'initial_state') else None
|
||||
use_external_memory_management = True if hasattr(model, 'initial_state') else None
|
||||
step_func_args = inspect.getfullargspec(model.step).args
|
||||
done = np.zeros((1,))
|
||||
dones = np.zeros((1,))
|
||||
|
||||
episode_rew = 0
|
||||
while True:
|
||||
kwargs = {}
|
||||
if use_external_memory_management:
|
||||
kwargs['S'] = state
|
||||
kwargs['M'] = done
|
||||
elif 'done' in step_func_args:
|
||||
kwargs['done'] = done
|
||||
if state is not None:
|
||||
actions, _, state, _ = model.step(obs,S=state, M=dones)
|
||||
else:
|
||||
actions, _, _, _ = model.step(obs)
|
||||
|
||||
actions, _, state, _ = model.step(obs, **kwargs)
|
||||
obs, rew, done, _ = env.step(actions)
|
||||
episode_rew += rew[0] if isinstance(env, VecEnv) else rew
|
||||
env.render()
|
||||
|
Reference in New Issue
Block a user