# File: baselines/baselines/common/tests/test_doc_examples.py
# (49 lines, 1.3 KiB, Python)

import pytest
# Optional dependency probe: record whether ``mujoco_py`` can be imported so
# that tests needing it can be skipped instead of failing at collection time.
try:
    import mujoco_py
    _mujoco_present = True
except BaseException:
    # Deliberately broad: any failure while importing mujoco_py is treated as
    # "mujoco not available" (the skip reason used in this file lists a
    # missing library, a missing key and a wrong LD_LIBRARY_PATH as causes).
    # NOTE(review): this also swallows KeyboardInterrupt/SystemExit raised
    # during the import - confirm that is still wanted before narrowing it.
    mujoco_py = None
    _mujoco_present = False
@pytest.mark.skipif(
    not _mujoco_present,
    reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library'
)
def test_lstm_example():
    """Run one 'Reacher-v2' episode with an LSTM policy and check its length.

    Builds a single-environment ``DummyVecEnv``, creates a policy from
    ``models.lstm(128)`` with ``nbatch=1, nsteps=1``, steps the environment
    until it reports done, and asserts that the episode took more than
    5 steps. Skipped when ``mujoco_py`` could not be imported.
    """
    # Imported lazily so that merely collecting this module does not require
    # tensorflow / baselines to be importable.
    import tensorflow as tf
    from baselines.common import policies, models, cmd_util
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

    # create vectorized environment
    venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)])

    try:
        with tf.Session() as sess:
            # build policy based on lstm network with 128 units
            policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1)

            # initialize tensorflow variables
            sess.run(tf.global_variables_initializer())

            # prepare environment variables
            ob = venv.reset()
            state = policy.initial_state
            done = [False]  # one flag per environment; a single env here
            step_counter = 0

            # run a single episode until the end (i.e. until done)
            while True:
                # S carries the recurrent state, M the done mask of the
                # previous step, so the policy sees episode boundaries.
                action, _, state, _ = policy.step(ob, S=state, M=done)
                ob, _, done, _ = venv.step(action)
                step_counter += 1
                # `done` is a sequence of flags; make the check explicit
                # instead of relying on the truthiness of the container.
                if any(done):
                    break

            assert step_counter > 5
    finally:
        # Release the environment even when the episode loop or the
        # assertion fails.
        venv.close()