Refactor json_encode_np

This commit is contained in:
Trevor Blackwell
2016-12-30 12:06:07 -08:00
parent 1059ccdf0a
commit ae8956921b
3 changed files with 22 additions and 20 deletions

View File

@@ -8,6 +8,7 @@ import six
from gym import error, version
from gym.monitoring import stats_recorder, video_recorder
from gym.utils import atomic_write, closer
from gym.utils.json_utils import json_encode_np
logger = logging.getLogger(__name__)
@@ -180,7 +181,7 @@ class MonitorManager(object):
'videos': [(os.path.basename(v), os.path.basename(m))
for v, m in self.videos],
'env_info': self._env_info(),
}, f, default=json_encode)
}, f, default=json_encode_np)
def close(self):
"""Flush all monitor data to disk and close any open rending windows."""
@@ -408,21 +409,3 @@ def collapse_env_infos(env_infos, training_dir):
if key not in first:
raise error.Error("env_info {} from training directory {} is missing expected key {}. This is unexpected and likely indicates a bug in gym.".format(first, training_dir, key))
return first
def json_encode(obj):
"""
JSON can't serialize numpy types, so convert to pure python
"""
if isinstance(obj, np.ndarray):
return list(obj)
elif isinstance(obj, np.float32):
return float(obj)
elif isinstance(obj, np.float64):
return float(obj)
elif isinstance(obj, np.int32):
return int(obj)
elif isinstance(obj, np.int64):
return int(obj)
else:
return obj

View File

@@ -4,6 +4,7 @@ import time
from gym import error
from gym.utils import atomic_write
from gym.utils.json_utils import json_encode_np
class StatsRecorder(object):
def __init__(self, directory, file_prefix, autoreset=False, env_id=None):
@@ -99,4 +100,4 @@ class StatsRecorder(object):
'episode_lengths': self.episode_lengths,
'episode_rewards': self.episode_rewards,
'episode_types': self.episode_types,
}, f)
}, f, default=json_encode_np)

18
gym/utils/json_utils.py Normal file
View File

@@ -0,0 +1,18 @@
import numpy as np
def json_encode_np(obj):
"""
JSON can't serialize numpy types, so convert to pure python
"""
if isinstance(obj, np.ndarray):
return list(obj)
elif isinstance(obj, np.float32):
return float(obj)
elif isinstance(obj, np.float64):
return float(obj)
elif isinstance(obj, np.int32):
return int(obj)
elif isinstance(obj, np.int64):
return int(obj)
else:
return obj