Don't check exact seeds for now

This commit is contained in:
Jie Tang
2016-10-18 12:05:01 -07:00
parent 8aa22a13b9
commit 2dc3e56ac2

View File

@@ -31,12 +31,12 @@ def upload(training_dir, algorithm_id=None, writeup=None, benchmark_id=None, api
# We're uploading a benchmark run. # We're uploading a benchmark run.
directories = [] directories = []
env_ids_and_seeds = [] env_ids = []
for name, _, files in os.walk(training_dir): for name, _, files in os.walk(training_dir):
manifests = monitoring.detect_training_manifests(name, files=files) manifests = monitoring.detect_training_manifests(name, files=files)
if manifests: if manifests:
env_info, main_seeds = monitoring.load_env_seed_info_from_manifests(manifests, training_dir) env_info, main_seeds = monitoring.load_env_seed_info_from_manifests(manifests, training_dir)
env_ids_and_seeds.extend((env_info['env_id'], seed) for seed in main_seeds) env_ids.extend(env_info['env_id'] for seed in main_seeds)
directories.append(name) directories.append(name)
# Validate against benchmark spec # Validate against benchmark spec
@@ -45,11 +45,12 @@ def upload(training_dir, algorithm_id=None, writeup=None, benchmark_id=None, api
except error.UnregisteredBenchmark as e: except error.UnregisteredBenchmark as e:
raise error.Error("Invalid benchmark id: {}. Are you using a benchmark registered in gym/benchmarks/__init__.py?".format(benchmark_id)) raise error.Error("Invalid benchmark id: {}. Are you using a benchmark registered in gym/benchmarks/__init__.py?".format(benchmark_id))
spec_env_ids_and_seeds = [(task[0].env_id, seed) for task in spec.task_groups.values() for seed in range(task[0].seeds)] # just verify that the number of seeds match for now
spec_env_ids = [task[0].env_id for task in spec.task_groups.values() for _ in range(task[0].seeds)]
# This could be more stringent about mixing evaluations # This could be more stringent about mixing evaluations
if set(env_ids_and_seeds) != set(spec_env_ids_and_seeds): if sorted(env_ids) != sorted(spec_env_ids):
raise error.Error("Evaluations do not match spec for benchmark {}. We found {}, expected {}".format(benchmark_id, sorted(env_ids_and_seeds), sorted(spec_env_ids_and_seeds))) raise error.Error("Evaluations do not match spec for benchmark {}. We found {}, expected {}".format(benchmark_id, sorted(env_ids), sorted(spec_env_ids)))
benchmark_run = resource.BenchmarkRun.create(benchmark_id=benchmark_id, algorithm_id=algorithm_id) benchmark_run = resource.BenchmarkRun.create(benchmark_id=benchmark_id, algorithm_id=algorithm_id)
benchmark_run_id = benchmark_run.id benchmark_run_id = benchmark_run.id