mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
51 lines
1.2 KiB
Python
51 lines
1.2 KiB
Python
"""Test lambda reward wrapper."""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.experimental.wrappers import LambdaRewardV0
|
|
from tests.experimental.wrappers.utils import DISCRETE_ACTION, ENV_ID, NUM_ENVS, SEED
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("reward_fn", "expected_reward"),
    [(lambda r: 2 * r + 1, 3)],
)
def test_lambda_reward(reward_fn, expected_reward):
    """Test lambda reward.

    Checks that ``reward_fn`` is applied to the reward returned by ``step``.
    """
    env = LambdaRewardV0(gym.make(ENV_ID), reward_fn)
    env.reset(seed=SEED)

    # NOTE(review): expected_reward=3 with 2*r+1 assumes ENV_ID yields a raw
    # reward of 1 on this step — confirm against tests.experimental.wrappers.utils.
    _, rew, _, _, _ = env.step(DISCRETE_ACTION)
    # Release the environment's resources before asserting so a failure
    # does not leave it open for the rest of the session.
    env.close()

    assert rew == expected_reward
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("reward_fn", "expected_reward"),
    [(lambda r: 2 * r + 1, 3)],
)
def test_lambda_reward_within_vector(reward_fn, expected_reward):
    """Test lambda reward in vectorized environment.

    Checks that ``reward_fn`` is applied element-wise to the batch of rewards
    returned by ``step`` in a vectorized environment.
    """
    actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)]
    env = LambdaRewardV0(gym.vector.make(ENV_ID, num_envs=NUM_ENVS), reward_fn)
    env.reset(seed=SEED)

    # NOTE(review): expected_reward=3 with 2*r+1 assumes every sub-env of
    # ENV_ID yields a raw reward of 1 on this step — confirm.
    _, rew, _, _, _ = env.step(actions)
    # Shut down the vector env (and any worker sub-environments) before
    # asserting so a failure does not leak them.
    env.close()

    # np.alltrue is deprecated and removed in NumPy 2.0; np.all is the
    # supported equivalent.
    assert np.all(rew == expected_reward)
|