import os
import time
from enum import Enum
import gymnasium
import matplotlib
import matplotlib.pyplot as plt
from gym_electric_motor.core import ElectricMotorVisualization
from .motor_dashboard_plots import ActionPlot, EpisodePlot, RewardPlot, StatePlot, StepPlot, TimePlot
from .render_modes import RenderMode
class MotorDashboardLegacy(ElectricMotorVisualization):
"""A dashboard to plot the GEM states into graphs.
Every MotorDashboard consists of multiple MotorDashboardPlots that are each responsible for the plots in a single
matplotlib axis.
It handles three different types of plots: the TimePlot, EpisodePlot and StepPlot, which differ mainly in
their x-axis. TimePlots plot every step and have the time on the x-axis. EpisodePlots plot statistics
over the episodes (e.g. mean reward per step in each episode); the episode number is on their x-axis.
StepPlots plot statistics over the last taken steps (e.g. mean reward over the last 1000 steps); their x-axis
is the cumulative number of steps.
With a windowed backend the TimePlots, EpisodePlots and StepPlots are each plotted into their own figure (up to
three figures in total); with a notebook backend all plots are stacked in one figure.
The most common TimePlots (i.e. to plot the states, actions and rewards) can be plotted by just passing the
corresponding arguments in the constructor. Additional plots (e.g. the MeanEpisodeRewardPlot) have to be
initialized manually and passed to the constructor.
Furthermore, completely custom plots can be defined. They have to derive from the TimePlot, EpisodePlot or
StepPlot base classes.
"""
@property
def update_interval(self):
    """Number of environment steps between two refreshes of the visualization."""
    steps_between_updates = self._update_interval
    return steps_between_updates
def __init__(
    self,
    state_plots=(),
    action_plots=(),
    reward_plot=False,
    additional_plots=(),
    update_interval=1000,
    time_plot_width=10000,
    style=None,
    scale_plots=None,
):
    """
    Args:
        state_plots('all'/iterable(str)): An iterable of state names to be shown. If 'all' all states will be shown.
            Default: () (no plotted states)
        action_plots('all'/iterable(int)): If action_plots='all', all actions will be plotted. If more than one
            action can be applied on the environment it can be selected by its index.
            Default: () (no plotted actions).
        reward_plot(boolean): Select if the current reward is to be plotted. Default: False
        additional_plots(iterable((TimePlot/EpisodePlot/StepPlot))): Additional already instantiated plots
            to be shown on the dashboard
        update_interval(int >= 1): Amount of steps after which the plots are updated. Updating each step reduces
            the performance drastically. Floats are truncated to int. Default: 1000
        time_plot_width(int >= 1): Width of the time plots in steps. Floats are truncated to int.
            Default: 10000 steps
            (1 second for continuously controlled environments / 0.1 second for discretely controlled environments)
        style(string): Select one of the matplotlib-styles. e.g. "dark-background".
            Default: None (the already selected style)
        scale_plots: Accepted by the signature but neither stored nor used by this constructor.

    Raises:
        AssertionError: If one of the arguments violates the constraints listed above.
    """
    # Materialize once: the plots are traversed several times below, so a one-shot iterator
    # (e.g. a generator) would be exhausted by the validation and its plots silently dropped.
    additional_plots = list(additional_plots)

    # Basic assertions
    assert isinstance(reward_plot, bool)
    assert all(isinstance(ap, (TimePlot, EpisodePlot, StepPlot)) for ap in additional_plots)
    assert type(update_interval) in [int, float]
    # The value is truncated to int below. A float in (0, 1) would become 0 and make the modulo in
    # on_step_end raise a ZeroDivisionError, hence ">= 1" instead of "> 0".
    assert update_interval >= 1, "update_interval must be at least 1 step."
    assert type(time_plot_width) in [int, float]
    assert time_plot_width >= 1, "time_plot_width must be at least 1 step."
    assert style is None or style in plt.style.available

    super().__init__()

    # Select the matplotlib style
    if style is not None:
        plt.style.use(style)

    # List of the opened figures
    self._figures = []
    # The figures to be opened for the time plots, episodic plots and step plots
    self._time_plot_figure = None
    self._episodic_plot_figure = None
    self._step_plot_figure = None

    # Store the input data
    self._state_plots = state_plots
    self._action_plots = action_plots
    self._reward_plot = reward_plot

    # Separate the additional plots into TimePlots, EpisodePlots and StepPlots
    self._custom_time_plots = [p for p in additional_plots if isinstance(p, TimePlot)]
    self._episodic_plots = [p for p in additional_plots if isinstance(p, EpisodePlot)]
    self._step_plots = [p for p in additional_plots if isinstance(p, StepPlot)]

    # Filled in set_env(), once the state names and the action space are known
    self._time_plots = []
    self._plots = []

    self._update_interval = int(update_interval)
    self._time_plot_width = int(time_plot_width)

    # Cumulative step counter and the flag that tells render() to redraw
    self._k = 0
    self._update_render = False
    # NOTE(review): scale_plots is accepted but not stored here; MotorDashboard.set_env reads
    # env.scale_plots instead. Confirm whether the argument is still needed.
def on_reset_begin(self):
    """Forward the reset-begin notification to every plot of the dashboard.

    Called before the environment is reset.
    """
    for subplot in self._plots:
        subplot.on_reset_begin()
def on_reset_end(self, state, reference):
    """Forward the initial data to every plot after the environment has been reset.

    Args:
        state(array(float)): The initial state :math:`s_0`.
        reference(array(float)): The initial reference for the first time step :math:`s^*_0`.
    """
    initial_data = (state, reference)
    for subplot in self._plots:
        subplot.on_reset_end(*initial_data)
def on_step_begin(self, k, action):
    """Forward the action of the upcoming step to every plot.

    Args:
        k(int): The current episode step.
        action(ndarray(float) / int): The taken action :math:`a_k`.
    """
    for subplot in self._plots:
        subplot.on_step_begin(k, action)
def on_step_end(self, k, state, reference, reward, terminated):
    """Forward the result of a step to every plot and schedule a redraw when one is due.

    The internal cumulative step counter is incremented on every call. Whenever it reaches a
    multiple of the update interval, the flag that makes the next render() call redraw is set.

    Args:
        k(int): The current episode step
        state(array(float)): The state of the env after the step :math:`s_k`.
        reference(array(float)): The reference corresponding to the state :math:`s^*_k`.
        reward(float): The reward that has been received for the last action that lead to the current state
            :math:`r_{k}`.
        terminated(bool): Flag, that indicates, if the last action lead to a terminal state :math:`t_{k}`.
    """
    step_result = (k, state, reference, reward, terminated)
    for subplot in self._plots:
        subplot.on_step_end(*step_result)

    self._k += 1
    steps_past_last_update = self._k % self._update_interval
    if steps_past_last_update == 0:
        self._update_render = True
def render(self):
    """Set up the figures on first use and redraw them if an update is pending.

    The figures are initialized as long as none of the three plot figures exists and at least one
    plot is configured. A redraw only happens when the update flag is set (on_step_end sets it once
    per update interval); the flag is cleared afterwards.
    """
    opened_figures = (self._time_plot_figure, self._episodic_plot_figure, self._step_plot_figure)
    if not any(opened_figures) and len(self._plots) > 0:
        self.initialize()

    if not self._update_render:
        return
    self._update()
    self._update_render = False
def set_env(self, env):
    """Build the plot list for the given environment and pass the environment on to every plot.

    The 'all' shortcuts of the state and action plots are resolved into explicit selections, the
    StatePlots, ActionPlots and the optional RewardPlot are created (the boolean reward flag is
    replaced by the created RewardPlot), the custom time plots are appended and all time plots get
    the configured width.

    Args:
        env(ElectricMotorEnvironment): The environment.
    """
    available_states = env.physical_system.state_names

    # Resolve the 'all' shortcuts
    if self._state_plots == "all":
        self._state_plots = available_states
    if self._action_plots == "all":
        space_type = type(env.action_space)
        if space_type is gymnasium.spaces.Discrete:
            # A discrete action space has a single action to plot
            self._action_plots = [0]
        elif space_type in (gymnasium.spaces.Box, gymnasium.spaces.MultiDiscrete):
            self._action_plots = list(range(env.action_space.shape[0]))

    # The automatically created time plots are rebuilt from scratch on every call
    self._time_plots = []

    if len(self._state_plots) > 0:
        assert all(name in available_states for name in self._state_plots)
        self._time_plots.extend(StatePlot(name) for name in self._state_plots)

    if len(self._action_plots) > 0:
        supported_spaces = (
            gymnasium.spaces.Box,
            gymnasium.spaces.Discrete,
            gymnasium.spaces.MultiDiscrete,
        )
        assert (
            type(env.action_space) in supported_spaces
        ), f"Action space of type {type(env.action_space)} not supported for plotting."
        self._time_plots.extend(ActionPlot(index) for index in self._action_plots)

    if self._reward_plot:
        self._reward_plot = RewardPlot()
        self._time_plots.append(self._reward_plot)

    self._time_plots.extend(self._custom_time_plots)
    self._plots = self._time_plots + self._episodic_plots + self._step_plots

    width = self._time_plot_width
    for time_plot in self._time_plots:
        time_plot.set_width(width)
    for subplot in self._plots:
        subplot.set_env(env)
def reset_figures(self):
    """Reset the data of all plots and forget the opened figures.

    This method can be called when the plots shall be reset after the training and before the test, for
    example. Another use case that requires the user to call this method is running the dashboard within a
    jupyter notebook when the figures shall be plotted below a new cell.
    """
    for subplot in self._plots:
        subplot.reset_data()

    # With no figure stored, the next render() call initializes the figures again
    self._time_plot_figure = None
    self._step_plot_figure = None
    self._episodic_plot_figure = None
    self._figures = []
def initialize(self):
    """Set up the figures and plots; render() calls this while no figure has been opened yet."""
    notebook_backends = ("nbAgg", "module://ipympl.backend_nbagg")

    plt.close()
    self._figures = []

    if plt.get_backend() not in notebook_backends:
        self._initialize_figures_window()
    else:
        self._initialize_figures_notebook()

    plt.pause(0.1)
def _initialize_figures_notebook(self):
    """Create one figure that stacks all plots: first the time, then the episode, then the step plots.

    Nothing is created if no plot is configured.
    """

    def label_and_initialize(plots, group_axes, x_label):
        # Only the lowest axis of a group carries the x-label
        group_axes[-1].set_xlabel(x_label)
        for plot, axis in zip(plots, group_axes):
            plot.initialize(axis)

    n_time = len(self._time_plots)
    n_episodic = len(self._episodic_plots)
    n_step = len(self._step_plots)
    total = n_episodic + n_step + n_time
    if total == 0:
        return

    fig, axes = plt.subplots(total, figsize=(8, 2 * total))
    self._figures = [fig]
    if total == 1:
        # A single subplot is not wrapped in a container by plt.subplots
        axes = [axes]

    # Hand out consecutive slices of the axes to the three plot groups
    time_end = n_time
    episodic_end = time_end + n_episodic

    if n_time > 0:
        self._time_plot_figure = fig
        label_and_initialize(self._time_plots, axes[:time_end], "t/s")
    if n_episodic > 0:
        self._episodic_plot_figure = fig
        label_and_initialize(self._episodic_plots, axes[time_end:episodic_end], "Episode No")
    if n_step > 0:
        self._step_plot_figure = fig
        label_and_initialize(self._step_plots, axes[episodic_end:], "Cumulative Steps")
def _initialize_figures_window(self):
    """Create separate figures (windows) for the episodic, step and time plots.

    A figure is only opened for a plot category that contains at least one plot. The figures are
    appended to the figure list in the order episodic, step, time, and the y-labels of the first
    opened figure are aligned. If no plot is configured at all, nothing is opened.
    """

    def open_figure(plots, window_title, x_label, hspace):
        """Open a figure with one shared-x axis per plot, register it and initialize the plots."""
        fig, axes = plt.subplots(len(plots), sharex=True)
        # A single subplot is not wrapped in a container by plt.subplots
        axes = [axes] if len(plots) == 1 else axes
        fig.subplots_adjust(wspace=0.0, hspace=hspace)
        fig.canvas.manager.set_window_title(window_title)
        # Only the lowest axis carries the x-label
        axes[-1].set_xlabel(x_label)
        self._figures.append(fig)
        for plot, axis in zip(plots, axes):
            plot.initialize(axis)
        return fig

    if len(self._episodic_plots) > 0:
        self._episodic_plot_figure = open_figure(self._episodic_plots, "Episodic Plots", "Episode No", 0.02)
    if len(self._step_plots) > 0:
        self._step_plot_figure = open_figure(self._step_plots, "Step Plots", "Cumulative Steps", 0.02)
    if len(self._time_plots) > 0:
        self._time_plot_figure = open_figure(self._time_plots, "Time Plots", "$t$/s", 0.2)

    # Without any configured plot no figure has been opened; indexing would raise an IndexError
    if self._figures:
        self._figures[0].align_ylabels()
def _update(self):
    """Let every plot render its data, then redraw all opened figures.

    render() calls this whenever an update is pending.
    """
    for subplot in self._plots:
        subplot.render()

    for figure in self._figures:
        figure.canvas.draw()
        figure.canvas.flush_events()
# Proxy Object for Refactoring
class MotorDashboard(MotorDashboardLegacy):
render_mode = None
def __init__(self, render_mode=None, *args, **kwargs):
    """Create the dashboard and remember the requested render mode.

    Args:
        render_mode: The render mode of the dashboard. It is compared against RenderMode.Figure in
            on_step_end and against RenderMode.FigureOnce in on_close. Default: None
        *args: Positional arguments forwarded to MotorDashboardLegacy. Note that the first positional
            argument of this constructor is render_mode, not state_plots.
        **kwargs: Keyword arguments forwarded to MotorDashboardLegacy.
    """
    super().__init__(*args, **kwargs)
    self.render_mode = render_mode
def set_env(self, env):
    """Pass a reduced, attribute-only view of the environment on to the plots.

    Only the attributes copied below are taken from the environment, so the plots do not hold a
    reference to the environment object itself.

    Args:
        env(ElectricMotorEnvironment): The environment.
    """
    from types import SimpleNamespace

    # This is the only data we need from the environment
    source_system = env.physical_system
    physical_system = SimpleNamespace(
        state_names=source_system.state_names,
        tau=source_system.tau,
        state_positions=source_system.state_positions,
        limits=source_system.limits,
        state_space=source_system.state_space,
        action_space=source_system.action_space,
    )
    env_view = SimpleNamespace(
        physical_system=physical_system,
        reference_generator=env.reference_generator,
        reward_range=env.reward_range,
        scale_plots=env.scale_plots,
        action_space=env.action_space,
    )
    super().set_env(env_view)
def on_close(self):
    """Render one last time if the render mode is RenderMode.FigureOnce, then delegate to the parent class.

    NOTE(review): render() only redraws when an update is pending (see on_step_end of the parent
    class), so with fewer steps than the update interval nothing is drawn here - confirm that this
    is intended.
    """
    if self.render_mode == RenderMode.FigureOnce:
        self.render()
    super().on_close()
def on_step_end(self, k, state, reference, reward, terminated):
    """Process the step result in the parent class and render directly in RenderMode.Figure.

    Args:
        k(int): The current episode step
        state(array(float)): The state of the env after the step :math:`s_k`.
        reference(array(float)): The reference corresponding to the state :math:`s^*_k`.
        reward(float): The reward that has been received for the last action that lead to the current state
            :math:`r_{k}`.
        terminated(bool): Flag, that indicates, if the last action lead to a terminal state :math:`t_{k}`.
    """
    super().on_step_end(k, state, reference, reward, terminated)
    if self.render_mode == RenderMode.Figure:
        self.render()
def show(self):
    """Show the matplotlib figures without blocking the caller."""
    blocking = False
    plt.show(block=blocking)
def show_and_hold(self):
    """Force a redraw of the plots, then show the matplotlib figures with a blocking call."""
    self.force_render()
    blocking = True
    plt.show(block=blocking)
def force_render(self):
    """Render immediately, regardless of the update interval.

    The pending-update flag is set first, so that the following render() call redraws the plots.
    """
    self._update_render = True
    self.render()
def save_to_file(self, filename=None, academic_mode=False):
    """Force a redraw of the plots and save the figure to a file.

    Args:
        filename(str): File name without extension. Default: None ("gem_plot_" followed by the
            current local time formatted as YYYYmmdd-HHMMSS).
        academic_mode(bool): If set, matplotlib's rcParams are switched to LaTeX text rendering with
            the Helvetica sans-serif font and the figure is saved as pdf instead of png. This needs
            some prerequisites to be installed. The rcParams are not restored afterwards.
    """
    if filename is None:
        stamp = time.strftime("%Y%m%d-%H%M%S")
        filename = f"gem_plot_{stamp}"

    if academic_mode:
        # Academic Mode (latex font), needs some prerequisites to be installed
        latex_settings = {
            "text.usetex": True,
            "font.family": "sans-serif",
            "font.sans-serif": "Helvetica",
        }
        matplotlib.rcParams.update(latex_settings)
        filetype = "pdf"
    else:
        filetype = "png"

    self.force_render()
    self._save_fig(filename, filetype=filetype)
def _save_fig(self, filename, filetype):
    """Save the figure returned by ``self.figure()`` to ``saved_plots/<filename>.<filetype>``.

    The output folder "saved_plots" is created in the current working directory if it does not
    exist yet. The success message is only printed after the figure has actually been written.

    Args:
        filename(str): File name without extension.
        filetype(str): File extension that is appended to the file name, e.g. "png" or "pdf".
    """
    # Create the output folder if it does not exist
    output_folder_name = "saved_plots"
    if not os.path.exists(output_folder_name):
        print(f"Creating output folder for plots: {output_folder_name}")
        # exist_ok: do not fail if the folder appears between the check and the creation
        os.makedirs(output_folder_name, exist_ok=True)

    filepath = f"{output_folder_name}/{filename}.{filetype}"
    # NOTE(review): figure() is not defined in the visible part of this file; presumably it is
    # provided by a base class or elsewhere - confirm.
    self.figure().savefig(filepath, dpi=300)
    print(f"Saved figure to file: {filepath}")