Files
Gymnasium/gym/benchmarks/registration.py
Greg Brockman 934b2acbb7 Add benchmark support (#338)
* Warn if seed doesn't return a list

* Add preliminary BenchmarkRun support

* Add experimental benchmark registration

* Flesh out interface

* Add preliminary BenchmarkRun support

* Warn if seed doesn't return a list

* Add experimental benchmark registration

* Flesh out interface

* Make benchmarkrun upload recursive

* Add evaluation episodes

* Add benchmark scoring

* Tweak reward locations

* Tweak scoring

* Clear default metadata in Wrapper

* Improve scoring

* Expose registry; fix test

* Add initial_reset_timestamp

* Add back algorithm; fix tests
2016-09-23 01:04:26 -07:00

66 lines
2.2 KiB
Python

# EXPERIMENTAL: all may be removed soon
import collections
import gym.envs
import logging
from gym import error
logger = logging.getLogger(__name__)
class Task(object):
def __init__(self, env_id, seeds, timesteps, reward_floor, reward_ceiling):
self.env_id = env_id
self.seeds = seeds
self.timesteps = timesteps
self.reward_floor = reward_floor
self.reward_ceiling = reward_ceiling
class Benchmark(object):
def __init__(self, id, scorer, task_groups, description=None):
self.id = id
self.scorer = scorer
self.description = description
task_map = {}
for env_id, tasks in task_groups.items():
task_map[env_id] = []
for task in tasks:
task_map[env_id].append(Task(
env_id=env_id,
seeds=task['seeds'],
timesteps=task['timesteps'],
reward_floor=task.get('reward_floor', 0),
reward_ceiling=task.get('reward_ceiling', 100),
))
self.task_groups = task_map
def task_spec(self, env_id):
try:
return self.task_groups[env_id]
except KeyError:
raise error.Unregistered('No task with env_id {} registered for benchmark {}', env_id, self.id)
def score_evaluation(self, env_id, episode_lengths, episode_rewards, episode_types, timestamps, initial_reset_timestamp):
return self.scorer.score_evaluation(self, env_id, episode_lengths, episode_rewards, episode_types, timestamps, initial_reset_timestamp)
def score_benchmark(self, score_map):
return self.scorer.score_benchmark(self, score_map)
class Registry(object):
def __init__(self):
self.benchmarks = collections.OrderedDict()
def register_benchmark(self, id, **kwargs):
self.benchmarks[id] = Benchmark(id=id, **kwargs)
def benchmark_spec(self, id):
try:
return self.benchmarks[id]
except KeyError:
raise error.UnregisteredBenchmark('No registered benchmark with id: {}'.format(id))
registry = Registry()
register_benchmark = registry.register_benchmark
benchmark_spec = registry.benchmark_spec