support the play.

This commit is contained in:
gyunt
2019-03-23 03:30:44 +09:00
parent a9c2b79730
commit f996ffb52d

View File

@@ -5,6 +5,7 @@ import gym
from collections import defaultdict from collections import defaultdict
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import inspect
from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv from baselines.common.vec_env import VecFrameStack, VecNormalize, VecEnv
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
@@ -218,15 +219,20 @@ def main(args):
obs = env.reset() obs = env.reset()
state = model.initial_state if hasattr(model, 'initial_state') else None state = model.initial_state if hasattr(model, 'initial_state') else None
dones = np.zeros((1,)) use_external_memory_management = True if hasattr(model, 'initial_state') else None
step_func_args = inspect.getfullargspec(model.step).args
done = np.zeros((1,))
episode_rew = 0 episode_rew = 0
while True: while True:
if state is not None: kwargs = {}
actions, _, state, _ = model.step(obs,S=state, M=dones) if use_external_memory_management:
else: kwargs['S'] = state
actions, _, _, _ = model.step(obs) kwargs['M'] = done
elif 'done' in step_func_args:
kwargs['done'] = done
actions, _, state, _ = model.step(obs, **kwargs)
obs, rew, done, _ = env.step(actions) obs, rew, done, _ = env.step(actions)
episode_rew += rew[0] if isinstance(env, VecEnv) else rew episode_rew += rew[0] if isinstance(env, VecEnv) else rew
env.render() env.render()