mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-27 16:57:10 +00:00
Refactor json_encode_np
This commit is contained in:
@@ -8,6 +8,7 @@ import six
|
|||||||
from gym import error, version
|
from gym import error, version
|
||||||
from gym.monitoring import stats_recorder, video_recorder
|
from gym.monitoring import stats_recorder, video_recorder
|
||||||
from gym.utils import atomic_write, closer
|
from gym.utils import atomic_write, closer
|
||||||
|
from gym.utils.json_utils import json_encode_np
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -180,7 +181,7 @@ class MonitorManager(object):
|
|||||||
'videos': [(os.path.basename(v), os.path.basename(m))
|
'videos': [(os.path.basename(v), os.path.basename(m))
|
||||||
for v, m in self.videos],
|
for v, m in self.videos],
|
||||||
'env_info': self._env_info(),
|
'env_info': self._env_info(),
|
||||||
}, f, default=json_encode)
|
}, f, default=json_encode_np)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Flush all monitor data to disk and close any open rending windows."""
|
"""Flush all monitor data to disk and close any open rending windows."""
|
||||||
@@ -408,21 +409,3 @@ def collapse_env_infos(env_infos, training_dir):
|
|||||||
if key not in first:
|
if key not in first:
|
||||||
raise error.Error("env_info {} from training directory {} is missing expected key {}. This is unexpected and likely indicates a bug in gym.".format(first, training_dir, key))
|
raise error.Error("env_info {} from training directory {} is missing expected key {}. This is unexpected and likely indicates a bug in gym.".format(first, training_dir, key))
|
||||||
return first
|
return first
|
||||||
|
|
||||||
|
|
||||||
def json_encode(obj):
|
|
||||||
"""
|
|
||||||
JSON can't serialize numpy types, so convert to pure python
|
|
||||||
"""
|
|
||||||
if isinstance(obj, np.ndarray):
|
|
||||||
return list(obj)
|
|
||||||
elif isinstance(obj, np.float32):
|
|
||||||
return float(obj)
|
|
||||||
elif isinstance(obj, np.float64):
|
|
||||||
return float(obj)
|
|
||||||
elif isinstance(obj, np.int32):
|
|
||||||
return int(obj)
|
|
||||||
elif isinstance(obj, np.int64):
|
|
||||||
return int(obj)
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
@@ -4,6 +4,7 @@ import time
|
|||||||
|
|
||||||
from gym import error
|
from gym import error
|
||||||
from gym.utils import atomic_write
|
from gym.utils import atomic_write
|
||||||
|
from gym.utils.json_utils import json_encode_np
|
||||||
|
|
||||||
class StatsRecorder(object):
|
class StatsRecorder(object):
|
||||||
def __init__(self, directory, file_prefix, autoreset=False, env_id=None):
|
def __init__(self, directory, file_prefix, autoreset=False, env_id=None):
|
||||||
@@ -99,4 +100,4 @@ class StatsRecorder(object):
|
|||||||
'episode_lengths': self.episode_lengths,
|
'episode_lengths': self.episode_lengths,
|
||||||
'episode_rewards': self.episode_rewards,
|
'episode_rewards': self.episode_rewards,
|
||||||
'episode_types': self.episode_types,
|
'episode_types': self.episode_types,
|
||||||
}, f)
|
}, f, default=json_encode_np)
|
||||||
|
18
gym/utils/json_utils.py
Normal file
18
gym/utils/json_utils.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def json_encode_np(obj):
|
||||||
|
"""
|
||||||
|
JSON can't serialize numpy types, so convert to pure python
|
||||||
|
"""
|
||||||
|
if isinstance(obj, np.ndarray):
|
||||||
|
return list(obj)
|
||||||
|
elif isinstance(obj, np.float32):
|
||||||
|
return float(obj)
|
||||||
|
elif isinstance(obj, np.float64):
|
||||||
|
return float(obj)
|
||||||
|
elif isinstance(obj, np.int32):
|
||||||
|
return int(obj)
|
||||||
|
elif isinstance(obj, np.int64):
|
||||||
|
return int(obj)
|
||||||
|
else:
|
||||||
|
return obj
|
Reference in New Issue
Block a user