import numpy as np
from .utils import set_state_array
class Constraint:
    """Base class for all constraints in the ConstraintMonitor."""

    def __call__(self, state):
        """Check the constraint against the current state.

        Subclasses must override this method; the base implementation always raises.

        Args:
            state(numpy.ndarray(float)): The current state of the physical system.

        Returns:
            float in [0.0, 1.0]: Degree to which the constraint has been violated.
                0.0: No violation.
                (0.0, 1.0): Undesired zone close to a full violation. No episode termination.
                1.0: Full violation and episode termination.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def set_modules(self, ps):
        """Called by the environment so that the Constraint can read information from the PhysicalSystem.

        The base implementation does nothing; subclasses override it when they need data from ``ps``.

        Args:
            ps(PhysicalSystem): PhysicalSystem of the environment.
        """
class LimitConstraint(Constraint):
    """Constraint to observe the limits on one or more system states.

    This constraint reports a violation as soon as the magnitude of any observed state entry exceeds 1.0.
    ``__call__`` compares the state entries against 1.0 directly (it does not divide by the limits itself), so the
    state passed in is expected to be normalized by the limits of the PhysicalSystem already.

    .. math::
        1.0 >= |s_i / s_{i,max}|

    For all :math:`i` in the set of observed PhysicalSystems states :math:`S`.
    """

    def __init__(self, observed_state_names="all_states"):
        """
        Args:
            observed_state_names('all_states'/['all_states']/iterable(str)/str/None): The states to observe.

                - 'all_states' or ['all_states']: Shortcut for observing all states.
                - iterable(str): Pass an iterable containing all state names of the states to observe.
                - str: A single state name.
                - None: Observe no state at all.
        """
        self._observed_state_names = observed_state_names
        self._limits = None
        # Boolean mask over the state vector; built in set_modules().
        self._observed_states = None

    def __call__(self, state):
        """Return 1.0 if any observed state entry has a magnitude above 1.0, otherwise 0.0.

        Args:
            state(numpy.ndarray(float)): The current state of the physical system.

        Returns:
            float: 1.0 on violation, 0.0 otherwise.
        """
        observed = state[self._observed_states]
        # np.any / np.abs stay vectorized and, unlike the builtin any(), do not depend on the array being 1-D.
        return float(np.any(np.abs(observed) > 1.0))

    def set_modules(self, ps):
        """Read the limits and state names from the PhysicalSystem and build the mask of observed states.

        Args:
            ps(PhysicalSystem): PhysicalSystem of the environment.
        """
        self._limits = ps.limits
        names = self._observed_state_names
        # Normalize the configured names first, so that None and plain strings are handled before any
        # membership test is performed on them.
        if names is None:
            names = []
        elif isinstance(names, str):
            # A bare string is one name, not an iterable of characters.
            names = [names]
        else:
            names = list(names)
        if "all_states" in names:
            names = ps.state_names
        self._observed_state_names = names
        self._observed_states = set_state_array(dict.fromkeys(names, 1), ps.state_names).astype(bool)
class SquaredConstraint(Constraint):
    r"""A squared constraint on multiple states as it is required oftentimes for the dq-currents in synchronous and
    asynchronous electric motors.

    The constraint is satisfied as long as

    .. math::
        1.0 >= \sum_{i \in S} (s_i / s_{i,max})^2

    and is reported as fully violated (1.0) once the sum is strictly greater than 1.0.

    :math:`S`: Set of the observed PhysicalSystems states
    """

    def __init__(self, states=()):
        """
        Args:
            states(iterable(str)): Names of all states to be observed within the SquaredConstraint.
        """
        # Stored as a tuple so that a one-shot iterable is not exhausted by the first set_modules() call.
        self._states = tuple(states)
        self._state_indices = ()
        self._limits = ()
        self._normalized = False

    def set_modules(self, ps):
        """Look up the positions and limits of the observed states in the PhysicalSystem.

        Also determines whether the states handed to ``__call__`` have to be divided by their limits: if the upper
        bounds of the state space equal the limits, the states are treated as not normalized.

        Args:
            ps(PhysicalSystem): PhysicalSystem of the environment.
        """
        self._state_indices = [ps.state_positions[state] for state in self._states]
        self._limits = ps.limits[self._state_indices]
        upper_bounds = ps.state_space.high[self._state_indices]
        # Tolerant comparison: an exact == can fail spuriously when the two arrays have different float dtypes,
        # which would skip the division by the limits in __call__.
        self._normalized = not np.allclose(upper_bounds, self._limits)

    def __call__(self, state):
        """Return 1.0 if the sum of the squared, normalized observed states exceeds 1.0, otherwise 0.0.

        Args:
            state(numpy.ndarray(float)): The current state of the physical system.

        Returns:
            float: 1.0 on violation, 0.0 otherwise.
        """
        observed = state[self._state_indices]
        if not self._normalized:
            observed = observed / self._limits
        return float(np.sum(observed**2) > 1.0)