diff --git a/gym/monitoring/monitor_manager.py b/gym/monitoring/monitor_manager.py index b79153161..e53e32cf9 100644 --- a/gym/monitoring/monitor_manager.py +++ b/gym/monitoring/monitor_manager.py @@ -8,6 +8,7 @@ import six from gym import error, version from gym.monitoring import stats_recorder, video_recorder from gym.utils import atomic_write, closer +from gym.utils.json_utils import json_encode_np logger = logging.getLogger(__name__) @@ -180,7 +181,7 @@ class MonitorManager(object): 'videos': [(os.path.basename(v), os.path.basename(m)) for v, m in self.videos], 'env_info': self._env_info(), - }, f, default=json_encode) + }, f, default=json_encode_np) def close(self): """Flush all monitor data to disk and close any open rending windows.""" @@ -408,21 +409,3 @@ def collapse_env_infos(env_infos, training_dir): if key not in first: raise error.Error("env_info {} from training directory {} is missing expected key {}. This is unexpected and likely indicates a bug in gym.".format(first, training_dir, key)) return first - - -def json_encode(obj): - """ - JSON can't serialize numpy types, so convert to pure python - """ - if isinstance(obj, np.ndarray): - return list(obj) - elif isinstance(obj, np.float32): - return float(obj) - elif isinstance(obj, np.float64): - return float(obj) - elif isinstance(obj, np.int32): - return int(obj) - elif isinstance(obj, np.int64): - return int(obj) - else: - return obj diff --git a/gym/monitoring/stats_recorder.py b/gym/monitoring/stats_recorder.py index 44c87c87d..998c89b1b 100644 --- a/gym/monitoring/stats_recorder.py +++ b/gym/monitoring/stats_recorder.py @@ -4,6 +4,7 @@ import time from gym import error from gym.utils import atomic_write +from gym.utils.json_utils import json_encode_np class StatsRecorder(object): def __init__(self, directory, file_prefix, autoreset=False, env_id=None): @@ -99,4 +100,4 @@ class StatsRecorder(object): 'episode_lengths': self.episode_lengths, 'episode_rewards': self.episode_rewards, 'episode_types': self.episode_types, - }, f) + }, f, default=json_encode_np) diff --git a/gym/utils/json_utils.py b/gym/utils/json_utils.py new file mode 100644 index 000000000..6088d4ea9 --- /dev/null +++ b/gym/utils/json_utils.py @@ -0,0 +1,18 @@ +import numpy as np + +def json_encode_np(obj): + """ + JSON can't serialize numpy types, so convert to pure python + """ + if isinstance(obj, np.ndarray): + return list(obj) + elif isinstance(obj, np.float32): + return float(obj) + elif isinstance(obj, np.float64): + return float(obj) + elif isinstance(obj, np.int32): + return int(obj) + elif isinstance(obj, np.int64): + return int(obj) + else: + return obj