revert baselines/run.py.

This commit is contained in:
gyunt
2019-04-08 21:46:44 +09:00
parent 36aadd6a4b
commit 93232a24e1

View File

@@ -6,7 +6,6 @@ import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np
import inspect
from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
@@ -219,20 +218,15 @@ def main(args):
obs = env.reset()
state = model.initial_state if hasattr(model, 'initial_state') else None
use_external_memory_management = True if hasattr(model, 'initial_state') else None
step_func_args = inspect.getfullargspec(model.step).args
done = np.zeros((1,))
dones = np.zeros((1,))
episode_rew = 0
while True:
kwargs = {}
if use_external_memory_management:
kwargs['S'] = state
kwargs['M'] = done
elif 'done' in step_func_args:
kwargs['done'] = done
if state is not None:
actions, _, state, _ = model.step(obs,S=state, M=dones)
else:
actions, _, _, _ = model.step(obs)
actions, _, state, _ = model.step(obs, **kwargs)
obs, rew, done, _ = env.step(actions)
episode_rew += rew[0] if isinstance(env, VecEnv) else rew
env.render()