49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
import pytest
|
|
# Optional dependency probe: record whether mujoco_py can be imported so the
# tests below can be skipped instead of erroring on machines without it.
# BaseException is caught deliberately broadly so that any failure raised
# while importing mujoco_py is treated as "not available".
try:
    import mujoco_py
except BaseException:
    mujoco_py = None
    _mujoco_present = False
else:
    _mujoco_present = True
|
|
|
|
|
|
@pytest.mark.skipif(
    not _mujoco_present,
    reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library'
)
def test_lstm_example():
    """Run one episode of Reacher-v2 with an LSTM policy and check its length.

    Builds a 128-unit LSTM policy over a single-environment ``DummyVecEnv``,
    steps it until the environment reports ``done`` and asserts that the
    episode lasted more than 5 steps. Skipped when ``mujoco_py`` could not be
    imported.
    """
    # Heavy imports are kept local so merely collecting this module does not
    # require tensorflow / baselines to be importable.
    import tensorflow as tf
    from baselines.common import policies, models, cmd_util
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

    # create vectorized environment
    venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)])

    try:
        with tf.Session() as sess:
            # build policy based on lstm network with 128 units
            policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1)

            # initialize tensorflow variables
            sess.run(tf.global_variables_initializer())

            # prepare environment variables
            ob = venv.reset()
            state = policy.initial_state
            done = [False]
            step_counter = 0

            # run a single episode until the end (i.e. until done);
            # venv wraps exactly one environment, so done[0] is its flag
            while not done[0]:
                action, _, state, _ = policy.step(ob, S=state, M=done)
                ob, _, done, _ = venv.step(action)
                step_counter += 1

            assert step_counter > 5
    finally:
        # release the environment even if the policy build or rollout fails
        venv.close()
|
|
|
|
|
|
|
|
|