mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 14:26:33 +00:00
[FrozenLake] Add seed in random map generation (#139)
This commit is contained in:
@@ -9,6 +9,7 @@ import gymnasium as gym
|
||||
from gymnasium import Env, spaces, utils
|
||||
from gymnasium.envs.toy_text.utils import categorical_sample
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.utils import seeding
|
||||
|
||||
LEFT = 0
|
||||
DOWN = 1
|
||||
@@ -51,12 +52,15 @@ def is_valid(board: List[List[str]], max_size: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
|
||||
def generate_random_map(
|
||||
size: int = 8, p: float = 0.8, seed: Optional[int] = None
|
||||
) -> List[str]:
|
||||
"""Generates a random valid map (one that has a path from start to goal)
|
||||
|
||||
Args:
|
||||
size: size of each side of the grid
|
||||
p: probability that a tile is frozen
|
||||
seed: optional seed to ensure the generation of reproducible maps
|
||||
|
||||
Returns:
|
||||
A random valid map
|
||||
@@ -64,9 +68,11 @@ def generate_random_map(size: int = 8, p: float = 0.8) -> List[str]:
|
||||
valid = False
|
||||
board = [] # initialize to make pyright happy
|
||||
|
||||
np_random, _ = seeding.np_random(seed)
|
||||
|
||||
while not valid:
|
||||
p = min(1, p)
|
||||
board = np.random.choice(["F", "H"], (size, size), p=[p, 1 - p])
|
||||
board = np_random.choice(["F", "H"], (size, size), p=[p, 1 - p])
|
||||
board[0][0] = "S"
|
||||
board[-1][-1] = "G"
|
||||
valid = is_valid(board, size)
|
||||
|
@@ -126,6 +126,16 @@ def test_frozenlake_dfs_map_generation(map_size: int):
|
||||
raise AssertionError("No path through the frozenlake was found.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("map_size, seed", [(5, 123), (10, 42), (16, 987)])
|
||||
def test_frozenlake_map_generation_with_seed(map_size: int, seed: int):
|
||||
map1 = generate_random_map(size=map_size, seed=seed)
|
||||
map2 = generate_random_map(size=map_size, seed=seed)
|
||||
assert map1 == map2
|
||||
map1 = generate_random_map(size=map_size, seed=seed)
|
||||
map2 = generate_random_map(size=map_size, seed=seed + 1)
|
||||
assert map1 != map2
|
||||
|
||||
|
||||
def test_taxi_action_mask():
|
||||
env = TaxiEnv()
|
||||
|
||||
|
Reference in New Issue
Block a user