Files
Gymnasium/tests/experimental/wrappers/test_lambda_rewards.py
2022-12-10 22:04:14 +00:00

51 lines
1.2 KiB
Python

"""Test lambda reward wrapper."""
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaRewardV0
from tests.experimental.wrappers.utils import DISCRETE_ACTION, ENV_ID, NUM_ENVS, SEED
@pytest.mark.parametrize(
    ("reward_fn", "expected_reward"),
    [(lambda r: 2 * r + 1, 3)],
)
def test_lambda_reward(reward_fn, expected_reward):
    """Test lambda reward.

    Tests if ``reward_fn`` is correctly applied to the reward returned by
    a single (non-vectorized) environment step.

    NOTE(review): ``expected_reward == 3`` presumes the first step of
    ``ENV_ID`` yields a reward of 1 (``2 * 1 + 1``); confirm against
    ``ENV_ID`` in ``tests.experimental.wrappers.utils``.

    Args:
        reward_fn: Function passed to ``LambdaRewardV0`` to transform rewards.
        expected_reward: Reward expected after applying ``reward_fn``.
    """
    env = gym.make(ENV_ID)
    env = LambdaRewardV0(env, reward_fn)
    env.reset(seed=SEED)
    _, rew, _, _, _ = env.step(DISCRETE_ACTION)
    # Close before asserting so a failing assertion does not leak the env.
    env.close()

    assert rew == expected_reward
@pytest.mark.parametrize(
    ("reward_fn", "expected_reward"),
    [(lambda r: 2 * r + 1, 3)],
)
def test_lambda_reward_within_vector(reward_fn, expected_reward):
    """Test lambda reward in vectorized environment.

    Tests if ``reward_fn`` is correctly applied to the batch of rewards
    returned by a vectorized environment step, i.e. every sub-environment's
    reward is transformed.

    NOTE(review): ``expected_reward == 3`` presumes each sub-environment's
    first-step reward is 1 (``2 * 1 + 1``); confirm against ``ENV_ID``.

    Args:
        reward_fn: Function passed to ``LambdaRewardV0`` to transform rewards.
        expected_reward: Reward expected in every sub-environment after
            applying ``reward_fn``.
    """
    actions = [DISCRETE_ACTION] * NUM_ENVS
    env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
    env = LambdaRewardV0(env, reward_fn)
    env.reset(seed=SEED)
    _, rew, _, _, _ = env.step(actions)
    # The vector env may run sub-environments in worker processes; close it
    # before asserting so they are shut down even if the assertion fails.
    env.close()

    # ``np.alltrue`` was deprecated in NumPy 1.25 and removed in NumPy 2.0;
    # ``np.all`` is the supported equivalent.
    assert np.all(rew == expected_reward)