mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-20 05:52:03 +00:00
Live Rewards Graph Option (#80)
* Adding an option to display a realtime plot of rewards using matplotlib * Updating monitor back to where it was * Adding a live_plot tool, also added an example (fee free to remove it)
This commit is contained in:
committed by
Greg Brockman
parent
32ecb74aa8
commit
7c530804cc
65
examples/agents/random_agent_live_plot.py
Normal file
65
examples/agents/random_agent_live_plot.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
import os, sys
|
||||
|
||||
import gym
|
||||
from gym.monitoring.live_plot import LivePlot
|
||||
|
||||
# The world's simplest agent!
|
||||
class RandomAgent(object):
|
||||
def __init__(self, action_space):
|
||||
self.action_space = action_space
|
||||
|
||||
def act(self, observation, reward, done):
|
||||
return self.action_space.sample()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# You can optionally set up the logger. Also fine to set the level
|
||||
# to logging.DEBUG or logging.WARN if you want to change the
|
||||
# amount of output.
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
env = gym.make('CartPole-v0' if len(sys.argv)<2 else sys.argv[1])
|
||||
|
||||
# You provide the directory to write to (can be an existing
|
||||
# directory, including one with existing data -- all monitor files
|
||||
# will be namespaced). You can also dump to a tempdir if you'd
|
||||
# like: tempfile.mkdtemp().
|
||||
outdir = '/tmp/random-agent-results'
|
||||
env.monitor.start(outdir, force=True, seed=0)
|
||||
|
||||
# You may optionally include a LivePlot so that you can see
|
||||
# how your agent is performing. Use plotter.plot() to update
|
||||
# the graph.
|
||||
plotter = LivePlot(outdir)
|
||||
|
||||
# This declaration must go *after* the monitor call, since the
|
||||
# monitor's seeding creates a new action_space instance with the
|
||||
# appropriate pseudorandom number generator.
|
||||
agent = RandomAgent(env.action_space)
|
||||
|
||||
episode_count = 100
|
||||
max_steps = 200
|
||||
reward = 0
|
||||
done = False
|
||||
|
||||
for i in range(episode_count):
|
||||
ob = env.reset()
|
||||
|
||||
for j in range(max_steps):
|
||||
action = agent.act(ob, reward, done)
|
||||
ob, reward, done, _ = env.step(action)
|
||||
if done:
|
||||
break
|
||||
|
||||
plotter.plot()
|
||||
env.render()
|
||||
|
||||
|
||||
# Dump result info to disk
|
||||
env.monitor.close()
|
||||
|
||||
# Upload to the scoreboard. We could also do this from another
|
||||
# process if we wanted.
|
||||
logger.info("Successfully ran RandomAgent. Now trying to upload results to the scoreboard. If it breaks, you can always just try re-uploading the same results.")
|
||||
gym.upload(outdir, algorithm_id='random')
|
@@ -1,3 +1,4 @@
|
||||
from gym.monitoring.monitor import Monitor, load_results, _open_monitors
|
||||
from gym.monitoring.stats_recorder import StatsRecorder
|
||||
from gym.monitoring.video_recorder import VideoRecorder
|
||||
from gym.monitoring.live_plot import LivePlot
|
||||
|
39
gym/monitoring/live_plot.py
Normal file
39
gym/monitoring/live_plot.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import gym
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class LivePlot(object):
|
||||
def __init__(self, outdir, data_key='episode_rewards', line_color='blue'):
|
||||
"""
|
||||
Liveplot renders a graph of either episode_rewards or episode_lengths
|
||||
|
||||
Args:
|
||||
outdir (outdir): Monitor output file location used to populate the graph
|
||||
data_key (Optional[str]): The key in the json to graph (episode_rewards or episode_lengths).
|
||||
line_color (Optional[dict]): Color of the plot.
|
||||
"""
|
||||
self.outdir = outdir
|
||||
self._last_data = None
|
||||
self.data_key = data_key
|
||||
self.line_color = line_color
|
||||
|
||||
#styling options
|
||||
matplotlib.rcParams['toolbar'] = 'None'
|
||||
plt.style.use('ggplot')
|
||||
plt.xlabel("")
|
||||
plt.ylabel(data_key)
|
||||
fig = plt.gcf().canvas.set_window_title('')
|
||||
|
||||
def plot(self):
|
||||
results = gym.monitoring.monitor.load_results(self.outdir)
|
||||
data = results[self.data_key]
|
||||
|
||||
#only update plot if data is different (plot calls are expensive)
|
||||
if data != self._last_data:
|
||||
self._last_data = data
|
||||
plt.plot(data, color=self.line_color)
|
||||
|
||||
# pause so matplotlib will display
|
||||
# may want to figure out matplotlib animation or use a different library in the future
|
||||
plt.pause(0.000001)
|
@@ -1,3 +1,4 @@
|
||||
numpy>=1.10.4
|
||||
requests>=2.0
|
||||
six
|
||||
matplotlib
|
Reference in New Issue
Block a user