From 1c872ca8fd61fd660c7f9be3f02ad33e6db65cfb Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Fri, 31 May 2019 15:36:20 -0700 Subject: [PATCH] run test_monitor through pytest; fix the test, add flake8 to bench direectory - like PR 891 (#921) --- baselines/bench/__init__.py | 1 + baselines/bench/benchmarks.py | 1 - baselines/bench/monitor.py | 26 -------------------------- baselines/bench/test_monitor.py | 31 +++++++++++++++++++++++++++++++ setup.cfg | 1 - 5 files changed, 32 insertions(+), 28 deletions(-) create mode 100644 baselines/bench/test_monitor.py diff --git a/baselines/bench/__init__.py b/baselines/bench/__init__.py index 4cbd5bb..fc2e05b 100644 --- a/baselines/bench/__init__.py +++ b/baselines/bench/__init__.py @@ -1,2 +1,3 @@ +# flake8: noqa F403 from baselines.bench.benchmarks import * from baselines.bench.monitor import * diff --git a/baselines/bench/benchmarks.py b/baselines/bench/benchmarks.py index c381935..5d626de 100644 --- a/baselines/bench/benchmarks.py +++ b/baselines/bench/benchmarks.py @@ -1,5 +1,4 @@ import re -import os.path as osp import os SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/baselines/bench/monitor.py b/baselines/bench/monitor.py index 1281f9a..e63e71d 100644 --- a/baselines/bench/monitor.py +++ b/baselines/bench/monitor.py @@ -1,13 +1,11 @@ __all__ = ['Monitor', 'get_monitor_files', 'load_results'] -import gym from gym.core import Wrapper import time from glob import glob import csv import os.path as osp import json -import numpy as np class Monitor(Wrapper): EXT = "monitor.csv" @@ -162,27 +160,3 @@ def load_results(dir): df['t'] -= min(header['t_start'] for header in headers) df.headers = headers # HACK to preserve backwards compatibility return df - -def test_monitor(): - env = gym.make("CartPole-v1") - env.seed(0) - mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4() - menv = Monitor(env, mon_file) - menv.reset() - for _ in range(1000): - _, _, done, _ = menv.step(0) - if done: - menv.reset() - - f = open(mon_file, 'rt') - - firstline = f.readline() - assert firstline.startswith('#') - metadata = json.loads(firstline[1:]) - assert metadata['env_id'] == "CartPole-v1" - assert set(metadata.keys()) == {'env_id', 'gym_version', 't_start'}, "Incorrect keys in monitor metadata" - - last_logline = pandas.read_csv(f, index_col=None) - assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline" - f.close() - os.remove(mon_file) diff --git a/baselines/bench/test_monitor.py b/baselines/bench/test_monitor.py new file mode 100644 index 0000000..093f9c6 --- /dev/null +++ b/baselines/bench/test_monitor.py @@ -0,0 +1,31 @@ +from .monitor import Monitor +import gym +import json + +def test_monitor(): + import pandas + import os + import uuid + + env = gym.make("CartPole-v1") + env.seed(0) + mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4() + menv = Monitor(env, mon_file) + menv.reset() + for _ in range(1000): + _, _, done, _ = menv.step(0) + if done: + menv.reset() + + f = open(mon_file, 'rt') + + firstline = f.readline() + assert firstline.startswith('#') + metadata = json.loads(firstline[1:]) + assert metadata['env_id'] == "CartPole-v1" + assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata" + + last_logline = pandas.read_csv(f, index_col=None) + assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline" + f.close() + os.remove(mon_file) diff --git a/setup.cfg b/setup.cfg index 20d822e..0cd564a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,4 +4,3 @@ exclude = .git, __pycache__, baselines/ppo1, - baselines/bench,