mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-19 21:42:02 +00:00
66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
|
|
# EXPERIMENTAL: everything in this module may be removed soon
|
||
|
|
|
||
|
|
import collections
|
||
|
|
import gym.envs
|
||
|
|
import logging
|
||
|
|
|
||
|
|
from gym import error
|
||
|
|
|
||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
class Task(object):
    """Plain record describing one task of a benchmark.

    Every constructor argument is stored unchanged on the instance under
    the attribute of the same name: ``env_id``, ``seeds``, ``timesteps``,
    ``reward_floor`` and ``reward_ceiling``.
    """

    def __init__(self, env_id, seeds, timesteps, reward_floor, reward_ceiling):
        # Which environment the task refers to, and its run settings.
        self.env_id, self.seeds, self.timesteps = env_id, seeds, timesteps
        # Reward bounds attached to the task.
        self.reward_floor, self.reward_ceiling = reward_floor, reward_ceiling
|
||
|
|
|
||
|
|
class Benchmark(object):
    """A named set of tasks, grouped by environment id, plus a scorer.

    Args:
        id: Identifier of the benchmark; stored as given and used in the
            error message raised by ``task_spec``.
        scorer: Object providing ``score_evaluation`` and
            ``score_benchmark``; both are called with this benchmark as
            their first argument.
        task_groups: Mapping of env_id to an iterable of task dicts. Each
            dict must contain ``'seeds'`` and ``'timesteps'``;
            ``'reward_floor'`` (default 0) and ``'reward_ceiling'``
            (default 100) are optional.
        description: Optional description, stored as given.
    """

    def __init__(self, id, scorer, task_groups, description=None):
        self.id = id
        self.scorer = scorer
        self.description = description

        # Convert every raw task dict into a Task, keeping the env_id grouping.
        self.task_groups = {
            env_id: [
                Task(
                    env_id=env_id,
                    seeds=task['seeds'],
                    timesteps=task['timesteps'],
                    reward_floor=task.get('reward_floor', 0),
                    reward_ceiling=task.get('reward_ceiling', 100),
                )
                for task in tasks
            ]
            for env_id, tasks in task_groups.items()
        }

    def task_spec(self, env_id):
        """Return the list of tasks registered for ``env_id``.

        Raises:
            error.Unregistered: if this benchmark has no tasks for ``env_id``.
        """
        try:
            return self.task_groups[env_id]
        except KeyError:
            # Interpolate the ids into the message. Passing them as extra
            # exception arguments leaves the '{}' placeholders unfilled.
            raise error.Unregistered(
                'No task with env_id {} registered for benchmark {}'.format(env_id, self.id))

    def score_evaluation(self, env_id, episode_lengths, episode_rewards, episode_types, timestamps, initial_reset_timestamp):
        """Delegate scoring of one evaluation to the scorer and return its result."""
        return self.scorer.score_evaluation(
            self,
            env_id,
            episode_lengths,
            episode_rewards,
            episode_types,
            timestamps,
            initial_reset_timestamp,
        )

    def score_benchmark(self, score_map):
        """Delegate scoring of the whole benchmark to the scorer and return its result."""
        return self.scorer.score_benchmark(self, score_map)
|
||
|
|
|
||
|
|
class Registry(object):
    """Keeps benchmarks in an ordered mapping keyed by benchmark id."""

    def __init__(self):
        # id -> Benchmark, in insertion order.
        self.benchmarks = collections.OrderedDict()

    def register_benchmark(self, id, **kwargs):
        """Build a Benchmark from ``id`` and ``kwargs`` and store it under ``id``.

        An existing entry with the same id is overwritten.
        """
        benchmark = Benchmark(id=id, **kwargs)
        self.benchmarks[id] = benchmark

    def benchmark_spec(self, id):
        """Return the benchmark stored under ``id``.

        Raises:
            error.UnregisteredBenchmark: if no benchmark is stored under ``id``.
        """
        try:
            spec = self.benchmarks[id]
        except KeyError:
            message = 'No registered benchmark with id: {}'.format(id)
            raise error.UnregisteredBenchmark(message)
        else:
            return spec
|
||
|
|
|
||
|
|
# Single module-wide registry instance.
registry = Registry()

# Module-level aliases for the shared registry's bound methods.
benchmark_spec = registry.benchmark_spec
register_benchmark = registry.register_benchmark
|