"""This module introduces predefined callbacks for the GEM environment."""
from gym_electric_motor.reference_generators import (
SubepisodedReferenceGenerator,
SwitchedReferenceGenerator,
)
from .core import Callback
class RampingLimitMargin(Callback):
    """Callback used to adapt the limit margin of a reference generator during runtime.

    Supports all
    :class:`~gym_electric_motor.reference_generators.SubepisodedReferenceGenerator`
    and :class:`~gym_electric_motor.reference_generators.SwitchedReferenceGenerator`
    with only
    :class:`~gym_electric_motor.reference_generators.SubepisodedReferenceGenerator`
    as sub generators.

    Starting from ``initial_limit_margin``, the lower bound is decreased and the upper bound is
    increased by ``step_size`` every ``update_freq`` steps or episodes, until both bounds have
    reached ``maximum_limit_margin``. Every new margin is written to the ``_limit_margin``
    attribute of the supported reference generator(s) of the environment.
    """

    __CLASS_ERROR__ = (
        "The RampingLimitMargin does only support the SubepisodedReferenceGenerator as reference generator or "
        "SwitchedReferenceGenerator with SubepisodedReferenceGenerator as all sub reference generators"
    )

    def __init__(
        self,
        initial_limit_margin=(-0.1, 0.1),
        maximum_limit_margin=(-1, 1),
        step_size=0.1,
        update_time="episode",
        update_freq=10,
    ):
        """
        Args:
            initial_limit_margin(tuple(floats)): The initial limit margin which gets updated by RampingLimitMargin
                until it reaches maximum_limit_margin
            maximum_limit_margin(tuple(floats)): The maximum limit margin. This will be the limit margin after
                RampingLimitMargin's last update
            step_size(float): The value by which each limit gets updated at each step
            update_time(string): When the update happens. "step" for the end of a step, "episode"
                for the end of an episode
            update_freq(int): After how many cumulative units of update_time an update occurs. Must not be zero.

        Raises:
            AssertionError: If update_time is unknown, a limit margin is not ordered (lower < upper) or lies
                outside of [-1, 1], or update_freq is zero.

        Additional Notes:
            All limit_margins should be between -1 and 1
        """
        super().__init__()
        # Validation is kept as assertions (not ValueError) so that callers relying on the established
        # AssertionError keep working.
        assert update_time in [
            "step",
            "episode",
        ], "Chose an option of either 'step' or 'episode' for updating on cumulative steps or episodes"
        assert (
            initial_limit_margin[1] > initial_limit_margin[0]
        ), "First element of limit margin has to be smaller than second"
        assert (
            maximum_limit_margin[1] > maximum_limit_margin[0]
        ), "First element of limit margin has to be smaller than second"
        assert initial_limit_margin[0] >= -1, "Lower limit margin has to be bigger than or equal to -1"
        assert maximum_limit_margin[0] >= -1, "Lower limit margin has to be bigger than or equal to -1"
        assert initial_limit_margin[1] <= 1, "Upper limit margin has to be smaller than or equal to 1"
        assert maximum_limit_margin[1] <= 1, "Upper limit margin has to be smaller than or equal to 1"
        # A zero frequency would only fail later with a ZeroDivisionError in the modulo of the update hooks.
        assert update_freq != 0, "update_freq must not be zero"

        # Store both margins as tuples: the "maximum reached" check in _update_limit_margin compares them
        # with `!=`, which is only a plain boolean for equal container types. Lists would never compare
        # equal to the tuple built during an update, and numpy arrays would compare element-wise and raise
        # when used as a condition.
        self._limit_margin = tuple(initial_limit_margin)
        self._maximum_limit_margin = tuple(maximum_limit_margin)
        self._step_size = step_size
        self._update_time = update_time
        # Only the counter that belongs to the selected update_time is created and used.
        if self._update_time == "step":
            self._step = 0
        else:
            self._episode = 0
        self._update_freq = update_freq

    def set_env(self, env):
        """Store the environment and write the initial limit margin to its reference generator(s).

        Args:
            env: The environment this callback is attached to. Its ``reference_generator`` has to be a
                SubepisodedReferenceGenerator or a SwitchedReferenceGenerator whose sub generators are all
                SubepisodedReferenceGenerators.

        Raises:
            AssertionError: If the reference generator (or one of its sub generators) is not supported.
        """
        # Check for the right reference generator before touching any state.
        for generator in self._target_generators(env.reference_generator):
            assert isinstance(generator, SubepisodedReferenceGenerator), self.__CLASS_ERROR__
        self._env = env
        # Initial limit margin added to the reference generator
        self._apply_limit_margin()

    def on_step_end(self, k, state, reference, reward, terminated):
        """Count the finished step and widen the limit margin every ``update_freq`` steps.

        Does nothing unless ``update_time`` is "step". The arguments are not used by this callback.
        """
        if self._update_time != "step":
            return
        self._step += 1
        if self._step % self._update_freq == 0:
            self._step = 0
            self._update_limit_margin()

    def on_reset_end(self, state, reference):
        """Count the finished reset and widen the limit margin every ``update_freq`` resets.

        Does nothing unless ``update_time`` is "episode". The arguments are not used by this callback.
        """
        if self._update_time != "episode":
            return
        self._episode += 1
        if self._episode % self._update_freq == 0:
            self._episode = 0
            self._update_limit_margin()

    @staticmethod
    def _target_generators(reference_generator):
        """Return the generators whose limit margin is managed: the sub generators of a
        SwitchedReferenceGenerator, otherwise the reference generator itself."""
        if isinstance(reference_generator, SwitchedReferenceGenerator):
            return list(reference_generator._sub_generators)
        return [reference_generator]

    def _apply_limit_margin(self):
        """Writes the current limit margin to all managed reference generators of the environment"""
        for generator in self._target_generators(self._env.reference_generator):
            generator._limit_margin = self._limit_margin

    def _update_limit_margin(self):
        """Updates the limit margin of the environments according to the step size and maximum limit margin"""
        if self._limit_margin == self._maximum_limit_margin:
            # The maximum margin is already reached, nothing left to ramp.
            return
        max_lower, max_upper = self._maximum_limit_margin[0], self._maximum_limit_margin[1]
        # Widen both bounds by one step and clamp them to the maximum margin. The maximum is passed first
        # so that it is the value returned on a tie.
        new_lower_limit = max(max_lower, self._limit_margin[0] - self._step_size)
        new_upper_limit = min(max_upper, self._limit_margin[1] + self._step_size)
        self._limit_margin = (new_lower_limit, new_upper_limit)
        self._apply_limit_margin()