"""The parameter class."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from typing import Any
if TYPE_CHECKING:
from glotaran.parameter import ParameterGroup
import asteval
import numpy as np
# Labels a parameter may not use: every name yielded by iterating asteval's
# default symbol table, plus "group" (the name parameter expressions use to
# reach the parameter group, see ``Parameter.transformed_expression``).
RESERVED_LABELS = [*asteval.make_symbol_table(), "group"]
class Keys:
    """Keys for parameter options.

    These are the dictionary keys read by ``Parameter._set_options_from_dict``.
    """

    # Stored in ``Parameter.expression``.
    EXPR = "expr"
    # Stored in ``Parameter.maximum``.
    MAX = "max"
    # Stored in ``Parameter.minimum``.
    MIN = "min"
    # Stored in ``Parameter.non_negative``.
    NON_NEG = "non-negative"
    # Stored in ``Parameter.vary``.
    VARY = "vary"
class Parameter:
    """A parameter for optimization.

    A parameter stores a numeric value together with its bounds, its standard
    error and the flags which control how it takes part in an optimization.
    In arithmetic expressions and comparisons an instance behaves like its
    plain numeric value.

    Notes
    -----
    The class defines ``__eq__`` without ``__hash__``, so instances are not
    hashable.
    """

    _find_parameter = re.compile(r"(\$[\w\d\.]+)")
    """A regular expression to find and replace parameter names in expressions."""

    _label_validator_regexp = re.compile(r"\W", flags=re.ASCII)
    """A regular expression to validate labels."""

    def __init__(
        self,
        label: str | None = None,
        full_label: str | None = None,
        expression: str | None = None,
        maximum: int | float = np.inf,
        minimum: int | float = -np.inf,
        non_negative: bool = False,
        value: float | None = None,
        vary: bool = True,
    ):
        """Initialize the parameter.

        Parameters
        ----------
        label :
            The label of the parameter.
        full_label :
            The label of the parameter with its path in a parameter group prepended.
        expression :
            An expression for the parameter. References to other parameters are
            written as ``$<full label>``.
        maximum :
            The upper bound of the parameter.
        minimum :
            The lower bound of the parameter.
        non_negative :
            Whether the parameter is transformed logarithmically for optimization.
        value :
            The value of the parameter.
        vary :
            Whether the parameter should be optimized.

        Raises
        ------
        ValueError
            If ``label`` is not a valid label.
        TypeError
            If ``maximum``, ``minimum`` or ``value`` cannot be converted to float.
        """
        # The labels are assigned first because the error messages of the
        # numeric setters below read ``self.full_label``.
        self.label = label
        self.full_label = full_label
        self.expression = expression
        self.maximum = maximum
        self.minimum = minimum
        self.non_negative = non_negative
        self.standard_error = 0.0
        self.value = value
        self.vary = vary

        self._transformed_expression = None

    @classmethod
    def valid_label(cls, label: str) -> bool:
        """Return ``True`` if ``label`` is a valid label.

        A label is valid if it contains only ASCII word characters and is not
        listed in ``RESERVED_LABELS``.
        """
        return cls._label_validator_regexp.search(label) is None and label not in RESERVED_LABELS

    @classmethod
    def from_list_or_value(
        cls,
        value: int | float | list,
        default_options: dict = None,
        label: str = None,
    ) -> Parameter:
        """Create a parameter from a list or numeric value.

        If ``value`` is a list, its first string becomes the label (falling
        back to ``label``), its first number becomes the value (falling back
        to ``0``) and its first dictionary is used as options. Strings in
        scientific notation are converted to numbers beforehand.

        Parameters
        ----------
        value :
            The list or numeric value.
        default_options :
            A dictionary of default options. They are applied before the
            options found in ``value``, which therefore take precedence.
        label :
            The label of the parameter.

        Returns
        -------
        Parameter
            The created parameter.
        """
        param = cls(label=label)
        options = None

        if isinstance(value, list):
            # The helpers below modify the list they are given; work on a
            # copy so the caller's list stays untouched.
            values = _sanatize_parameter_list(list(value))
            param.label = _retrieve_from_list_by_type(values, str, label)
            param.value = float(_retrieve_from_list_by_type(values, (int, float), 0))
            options = _retrieve_from_list_by_type(values, dict, None)
        else:
            param.value = value

        if default_options:
            param._set_options_from_dict(default_options)

        if options:
            param._set_options_from_dict(options)
        return param

    def set_from_group(self, group: ParameterGroup):
        """Copy all values from the parameter in ``group`` with the same full label.

        Notes
        -----
        For internal use.

        Parameters
        ----------
        group :
            The :class:`glotaran.parameter.ParameterGroup`.
        """
        p = group.get(self.full_label)
        self.expression = p.expression
        self.maximum = p.maximum
        self.minimum = p.minimum
        self.non_negative = p.non_negative
        self.standard_error = p.standard_error
        self.value = p.value
        self.vary = p.vary

    def _set_options_from_dict(self, options: dict):
        """Apply the entries of ``options`` whose keys are defined in ``Keys``."""
        if Keys.EXPR in options:
            self.expression = options[Keys.EXPR]
        if Keys.NON_NEG in options:
            self.non_negative = options[Keys.NON_NEG]
        if Keys.MAX in options:
            self.maximum = options[Keys.MAX]
        if Keys.MIN in options:
            self.minimum = options[Keys.MIN]
        if Keys.VARY in options:
            self.vary = options[Keys.VARY]

    def _coerce_to_float(self, name: str, candidate: Any) -> float:
        """Return ``candidate`` as float or raise a ``TypeError`` naming the field ``name``."""
        if isinstance(candidate, float):
            return candidate
        try:
            return float(candidate)
        except Exception as err:
            raise TypeError(
                f"Parameter {name} must be numeric. "
                f"'{self.full_label}' has {name} '{candidate}' of type '{type(candidate)}'"
            ) from err

    @property
    def label(self) -> str:
        """Label of the parameter."""
        return self._label

    @label.setter
    def label(self, label: str):
        # ``None`` is accepted without validation.
        if label is not None and not Parameter.valid_label(label):
            raise ValueError(f"'{label}' is not a valid group label.")
        self._label = label

    @property
    def full_label(self) -> str:
        """The label of the parameter with its path in a parameter group prepended."""
        return self._full_label

    @full_label.setter
    def full_label(self, full_label: str):
        self._full_label = full_label

    @property
    def non_negative(self) -> bool:
        r"""Indicates if the parameter is non-negative.

        If true, the parameter will be transformed with :math:`p' = \log{p}` and
        :math:`p = \exp{p'}`.

        Always `False` if `expression` is not `None`.
        """
        return self._non_negative if self.expression is None else False

    @non_negative.setter
    def non_negative(self, non_negative: bool):
        self._non_negative = non_negative

    @property
    def vary(self) -> bool:
        """Indicates if the parameter should be optimized.

        Always `False` if `expression` is not `None`.
        """
        return self._vary if self.expression is None else False

    @vary.setter
    def vary(self, vary: bool):
        self._vary = vary

    @property
    def maximum(self) -> float:
        """The upper bound of the parameter."""
        return self._maximum

    @maximum.setter
    def maximum(self, maximum: int | float):
        self._maximum = self._coerce_to_float("maximum", maximum)

    @property
    def minimum(self) -> float:
        """The lower bound of the parameter."""
        return self._minimum

    @minimum.setter
    def minimum(self, minimum: int | float):
        self._minimum = self._coerce_to_float("minimum", minimum)

    @property
    def expression(self) -> str:
        """The expression of the parameter, or `None` if it has none.

        While an expression is set, `vary` and `non_negative` report `False`.
        """
        return self._expression

    @expression.setter
    def expression(self, expression: str):
        self._expression = expression
        # Invalidate the cache used by ``transformed_expression``.
        self._transformed_expression = None

    @property
    def transformed_expression(self) -> str:
        """The expression of the parameter transformed for evaluation within a `ParameterGroup`.

        Every ``$<full label>`` reference is replaced by
        ``group.get('<full label>').value``. The result is cached until the
        expression is set again.
        """
        if self.expression is not None and self._transformed_expression is None:
            # Substitute all references in a single pass. Replacing one match
            # after the other with ``str.replace`` corrupts a reference whose
            # text starts with a shorter reference (``$a`` inside ``$a.b``).
            self._transformed_expression = Parameter._find_parameter.sub(
                lambda found: f"group.get('{found.group(0)[1:]}').value",
                self.expression,
            )
        return self._transformed_expression

    @property
    def standard_error(self) -> float:
        """The standard error of the optimized parameter."""
        return self._stderr

    @standard_error.setter
    def standard_error(self, standard_error: float):
        self._stderr = standard_error

    @property
    def value(self) -> float:
        """The value of the parameter."""
        return self._getval()

    @value.setter
    def value(self, value: int | float):
        # ``None`` is stored as is; everything else has to be convertible to float.
        if value is not None:
            value = self._coerce_to_float("value", value)
        self._value = value

    def get_value_and_bounds_for_optimization(self) -> tuple[float, float, float]:
        """Get the parameter value and bounds with the non-negative transformation applied.

        Returns
        -------
        tuple[float, float, float]
            The value, the minimum and the maximum. If `non_negative` is true,
            each of them is passed through ``_log_value``.
        """
        value = self.value
        minimum = self.minimum
        maximum = self.maximum

        if self.non_negative:
            value = _log_value(value)
            minimum = _log_value(minimum)
            maximum = _log_value(maximum)

        return value, minimum, maximum

    def set_value_from_optimization(self, value: float):
        """Set the value from an optimization result and reverse the non-negative transformation."""
        self.value = np.exp(value) if self.non_negative else value

    def __getstate__(self):
        """Get state for pickle."""
        return (
            self.label,
            self.full_label,
            self.expression,
            self.maximum,
            self.minimum,
            self.non_negative,
            self.standard_error,
            self.value,
            self.vary,
        )

    def __setstate__(self, state):
        """Set state from pickle."""
        # The assignments go through the property setters, from left to right.
        (
            self.label,
            self.full_label,
            self.expression,
            self.maximum,
            self.minimum,
            self.non_negative,
            self.standard_error,
            self.value,
            self.vary,
        ) = state

    def _getval(self) -> float:
        """Return the stored value."""
        return self._value

    def __repr__(self):
        """String representation."""
        return (
            f"__{self.label}__: _Value_: {self.value}, _StdErr_: {self.standard_error}, _Min_:"
            f" {self.minimum}, _Max_: {self.maximum}, _Vary_: {self.vary},"
            f" _Non-Negative_: {self.non_negative}, _Expr_: {self.expression}"
        )

    def __array__(self, dtype=None, copy=None):
        """Return the value as a zero-dimensional array.

        Parameters
        ----------
        dtype :
            Requested data type of the array, float if `None`.
        copy :
            Accepted for compatibility with NumPy's array protocol and
            ignored; a new array is created on every call.
        """
        return np.array(float(self._getval()), dtype=float if dtype is None else dtype)

    def __str__(self):
        """string"""
        return self.__repr__()

    def __abs__(self):
        """abs"""
        return abs(self._getval())

    def __neg__(self):
        """neg"""
        return -self._getval()

    def __pos__(self):
        """positive"""
        return +self._getval()

    def __int__(self):
        """int"""
        return int(self._getval())

    def __float__(self):
        """float"""
        return float(self._getval())

    def __trunc__(self):
        """trunc"""
        return self._getval().__trunc__()

    def __add__(self, other):
        """+"""
        return self._getval() + other

    def __sub__(self, other):
        """-"""
        return self._getval() - other

    def __truediv__(self, other):
        """/"""
        return self._getval() / other

    def __floordiv__(self, other):
        """//"""
        return self._getval() // other

    def __divmod__(self, other):
        """divmod"""
        return divmod(self._getval(), other)

    def __mod__(self, other):
        """%"""
        return self._getval() % other

    def __mul__(self, other):
        """*"""
        return self._getval() * other

    def __pow__(self, other):
        """**"""
        return self._getval() ** other

    def __gt__(self, other):
        """>"""
        return self._getval() > other

    def __ge__(self, other):
        """>="""
        return self._getval() >= other

    def __le__(self, other):
        """<="""
        return self._getval() <= other

    def __lt__(self, other):
        """<"""
        return self._getval() < other

    def __eq__(self, other):
        """=="""
        return self._getval() == other

    def __ne__(self, other):
        """!="""
        return self._getval() != other

    def __radd__(self, other):
        """+ (right)"""
        return other + self._getval()

    def __rtruediv__(self, other):
        """/ (right)"""
        return other / self._getval()

    def __rdivmod__(self, other):
        """divmod (right)"""
        return divmod(other, self._getval())

    def __rfloordiv__(self, other):
        """// (right)"""
        return other // self._getval()

    def __rmod__(self, other):
        """% (right)"""
        return other % self._getval()

    def __rmul__(self, other):
        """* (right)"""
        return other * self._getval()

    def __rpow__(self, other):
        """** (right)"""
        return other ** self._getval()

    def __rsub__(self, other):
        """- (right)"""
        return other - self._getval()
def _log_value(value: float):
    """Return the natural logarithm of ``value``; non-finite values pass through.

    A value of exactly ``1`` is shifted by ``1e-10`` before the logarithm is taken.
    """
    if np.isfinite(value):
        shifted = value + 1e-10 if value == 1 else value
        return np.log(shifted)
    # inf, -inf and nan are returned as they are.
    return value
# A regular expression which ONLY matches numbers in scientific notation
# (the exponent part is mandatory).
_match_scientific = re.compile(r"[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)")


def _sanatize_parameter_list(li: list) -> list:
    """Convert strings in scientific notation to floats, in place.

    Parameters
    ----------
    li :
        The list to sanitize. It is modified in place.

    Returns
    -------
    list
        The same list object as ``li``.
    """
    for i, value in enumerate(li):
        if isinstance(value, str) and _match_scientific.match(value):
            try:
                li[i] = float(value)
            except ValueError:
                # ``match`` only checks the beginning of the string, so text
                # like "1e5abc" gets here although it is not a number.
                # Leave such strings untouched instead of failing.
                pass
    return li
def _retrieve_from_list_by_type(li: list, t: type, default: Any):
    """Remove and return the first item of ``li`` which is an instance of ``t``.

    Parameters
    ----------
    li :
        The list to search. A found item is removed from it.
    t :
        The type (or tuple of types) passed to ``isinstance``.
    default :
        The value to return if no item matches; ``li`` is left unchanged then.
    """
    for candidate in li:
        if isinstance(candidate, t):
            # Returning right after the removal keeps the iteration safe.
            li.remove(candidate)
            return candidate
    return default