"""The parameter group class."""
from __future__ import annotations
from copy import copy
from textwrap import indent
from typing import TYPE_CHECKING
from typing import Any
from typing import Generator
import asteval
import numpy as np
import pandas as pd
from tabulate import tabulate
from glotaran.deprecation import deprecate
from glotaran.io import load_parameters
from glotaran.io import save_parameters
from glotaran.parameter.parameter import Parameter
from glotaran.utils.ipython import MarkdownStr
if TYPE_CHECKING:
from glotaran.parameter.parameter_history import ParameterHistory
class ParameterNotFoundException(Exception):
    """Raised when a :class:`Parameter` cannot be found in a :class:`ParameterGroup`."""

    def __init__(self, path, label):
        """Build the error message from the group path and the parameter label.

        Parameters
        ----------
        path : list[str]
            Labels of the groups leading to the parameter (may be empty).
        label : str
            Label of the parameter inside its group.
        """
        full_label = ".".join(path + [label])
        message = f"Cannot find parameter {full_label!r}"
        super().__init__(message)
class ParameterGroup(dict):
    """Represents a group of parameters.

    Can contain other groups, creating a tree-like hierarchy.

    The ``dict`` items of a group map subgroup labels to subgroups, while the parameters
    owned directly by the group are kept separately in ``_parameters`` (label -> Parameter).
    """

    # Alias for :func:`glotaran.io.load_parameters`.
    loader = load_parameters

    def __init__(self, label: str | None = None, root_group: ParameterGroup | None = None):
        """Initialize a :class:`ParameterGroup` instance with ``label``.

        Parameters
        ----------
        label : str | None
            The label of the group.
        root_group : ParameterGroup | None
            The group this group is nested in (its parent), ``None`` for a top-level group.

        Raises
        ------
        ValueError
            Raised if an invalid label is given.
        """
        if label is not None and not Parameter.valid_label(label):
            raise ValueError(f"'{label}' is not a valid group label.")
        self._label = label
        # Parameters owned directly by this group; subgroups live in the dict items.
        self._parameters: dict[str, Parameter] = {}
        self._root_group = root_group
        # Only top-level groups own an expression evaluator. Its symbol table exposes this
        # group as ``group`` (presumably referenced by ``Parameter.transformed_expression``).
        self._evaluator = (
            asteval.Interpreter(symtable=asteval.make_symbol_table(group=self))
            if root_group is None
            else None
        )
        self.source_path = "parameters.csv"
        super().__init__()

    @classmethod
    def from_dict(
        cls,
        parameter_dict: dict[str, dict[str, Any] | list[float | list[Any]]],
        label: str | None = None,
        root_group: ParameterGroup | None = None,
    ) -> ParameterGroup:
        """Create a :class:`ParameterGroup` from a dictionary.

        Dictionary values become subgroups (built recursively with :meth:`from_dict`),
        list values become subgroups built with :meth:`from_list`. Values of any other
        type are ignored.

        Parameters
        ----------
        parameter_dict : dict[str, dict | list]
            A parameter dictionary containing parameters.
        label : str | None
            The label of the group.
        root_group : ParameterGroup | None
            The parent group, ``None`` for a top-level group.

        Returns
        -------
        ParameterGroup
            The created :class:`ParameterGroup`
        """
        root = cls(label=label, root_group=root_group)
        for item_label, item in parameter_dict.items():
            # Keys are not guaranteed to be strings (presumably e.g. ints from YAML files).
            item_label = str(item_label)
            if isinstance(item, dict):
                root.add_group(cls.from_dict(item, label=item_label, root_group=root))
            elif isinstance(item, list):
                root.add_group(cls.from_list(item, label=item_label, root_group=root))
        if root_group is None:
            # Only the top-level call evaluates expressions, once the whole tree exists.
            root.update_parameter_expression()
        return root

    @classmethod
    def from_list(
        cls,
        parameter_list: list[float | list[Any]],
        label: str | None = None,
        root_group: ParameterGroup | None = None,
    ) -> ParameterGroup:
        """Create a :class:`ParameterGroup` from a list.

        The first dictionary found in the list is used as default options for all
        parameters. Parameters are labeled by their 1-based position in the list.
        Strings which cannot be converted to a number are ignored.

        Parameters
        ----------
        parameter_list : list[float | list[Any]]
            A parameter list containing parameters
        label : str | None
            The label of the group.
        root_group : ParameterGroup | None
            The parent group, ``None`` for a top-level group.

        Returns
        -------
        ParameterGroup
            The created :class:`ParameterGroup`.
        """
        root = cls(label=label, root_group=root_group)

        defaults = next((item for item in parameter_list if isinstance(item, dict)), None)

        for i, item in enumerate(parameter_list):
            if isinstance(item, (str, int, float)):
                try:
                    item = float(item)
                except (ValueError, OverflowError):
                    # Non-numeric strings are skipped below; ints too large for a float
                    # are passed on unchanged.
                    pass
            if isinstance(item, (float, int, list)):
                root.add_parameter(
                    Parameter.from_list_or_value(item, label=str(i + 1), default_options=defaults)
                )
        if root_group is None:
            root.update_parameter_expression()
        return root

    @classmethod
    def from_parameter_dict_list(cls, parameter_dict_list: list[dict[str, Any]]) -> ParameterGroup:
        """Create a :class:`ParameterGroup` from a list of parameter dictionaries.

        Each dictionary's ``"label"`` is a dot-separated path; missing intermediate groups
        are created on the fly.

        Parameters
        ----------
        parameter_dict_list : list[dict[str, Any]]
            A list of parameter dictionaries.

        Returns
        -------
        ParameterGroup
            The created :class:`ParameterGroup`.
        """
        parameter_group = cls()
        for parameter_dict in parameter_dict_list:
            group = parameter_group.get_group_for_parameter_by_label(
                parameter_dict["label"], create_if_not_exist=True
            )
            group.add_parameter(Parameter.from_dict(parameter_dict))
        parameter_group.update_parameter_expression()
        return parameter_group

    @classmethod
    def from_dataframe(cls, df: pd.DataFrame, source: str = "DataFrame") -> ParameterGroup:
        """Create a :class:`ParameterGroup` from a :class:`pandas.DataFrame`.

        The passed data frame is not modified.

        Parameters
        ----------
        df : pd.DataFrame
            The source data frame.
        source : str
            Optional name of the source file, used for error messages.

        Returns
        -------
        ParameterGroup
            The created parameter group.

        Raises
        ------
        ValueError
            Raised if the columns 'label' or 'value' don't exist. Also raised if the columns
            'minimum', 'maximum' or 'value' contain non numeric values or if the columns
            'non-negative' or 'vary' are not boolean.
        """
        for column_name in ["label", "value"]:
            if column_name not in df:
                raise ValueError(f"Missing column '{column_name}' in '{source}'")

        for column_name in ["minimum", "maximum", "value"]:
            if column_name in df and any(not np.isreal(v) for v in df[column_name]):
                raise ValueError(f"Column '{column_name}' in '{source}' has non numeric values")

        for column_name in ["non-negative", "vary"]:
            if column_name in df and any(not isinstance(v, bool) for v in df[column_name]):
                raise ValueError(f"Column '{column_name}' in '{source}' has non boolean values")

        if "expression" in df:
            # Empty expression cells are NaN; normalize them to None. ``assign`` returns a
            # new frame, so the caller's DataFrame is left untouched.
            df = df.assign(
                expression=[expr if isinstance(expr, str) else None for expr in df["expression"]]
            )

        return cls.from_parameter_dict_list(df.to_dict(orient="records"))

    @property
    def label(self) -> str | None:
        """Label of the group.

        Returns
        -------
        str | None
            The label of the group.
        """
        return self._label

    @property
    def root_group(self) -> ParameterGroup | None:
        """Group this group is nested in (its parent).

        Returns
        -------
        ParameterGroup | None
            The parent group, ``None`` for a top-level group.
        """
        return self._root_group

    def to_parameter_dict_list(self, as_optimized: bool = True) -> list[dict[str, Any]]:
        """Create list of parameter dictionaries from the group.

        Parameters
        ----------
        as_optimized : bool
            Whether to include properties which are the result of optimization.

        Returns
        -------
        list[dict[str, Any]]
            A list of parameter dictionaries.
        """
        return [parameter.as_dict(as_optimized=as_optimized) for _, parameter in self.all()]

    def to_dataframe(self, as_optimized: bool = True) -> pd.DataFrame:
        """Create a pandas data frame from the group.

        Parameters
        ----------
        as_optimized : bool
            Whether to include properties which are the result of optimization.

        Returns
        -------
        pd.DataFrame
            The created data frame.
        """
        return pd.DataFrame(self.to_parameter_dict_list(as_optimized=as_optimized))

    def get_group_for_parameter_by_label(
        self, parameter_label: str, create_if_not_exist: bool = False
    ) -> ParameterGroup:
        """Get the group for a parameter by its label.

        Parameters
        ----------
        parameter_label : str
            The dot-separated parameter label; everything but the last part is the group path.
        create_if_not_exist : bool
            Create missing parameter groups along the path.

        Returns
        -------
        ParameterGroup
            The group of the parameter.

        Raises
        ------
        KeyError
            Raised if the group does not exist and `create_if_not_exist` is `False`.
        """
        *group_labels, _ = parameter_label.split(".")
        group = self
        for group_label in group_labels:
            if group_label not in group:
                if not create_if_not_exist:
                    raise KeyError(f"Subgroup '{group_label}' does not exist.")
                group.add_group(ParameterGroup(label=group_label, root_group=group))
            group = group[group_label]
        return group

    @deprecate(
        deprecated_qual_name_usage=(
            "glotaran.parameter.ParameterGroup.to_csv(file_name=<parameters.csv>)"
        ),
        new_qual_name_usage=(
            "glotaran.io.save_parameters(parameters, "
            'file_name=<parameters.csv>, format_name="csv")'
        ),
        to_be_removed_in_version="0.7.0",
        importable_indices=(2, 1),
    )
    def to_csv(self, filename: str, delimiter: str = ",") -> None:
        """Save a :class:`ParameterGroup` to a CSV file.

        Warning
        -------
        Deprecated use
        ``glotaran.io.save_parameters(parameters, file_name=<parameters.csv>, format_name="csv")``
        instead.

        Parameters
        ----------
        filename : str
            File to write the parameter specs to.
        delimiter : str
            Character to separate columns, by default ","
        """
        save_parameters(self, file_name=filename, allow_overwrite=True, sep=delimiter)

    def add_parameter(self, parameter: Parameter | list[Parameter]):
        """Add a :class:`Parameter` to the group.

        Each added parameter gets a 1-based ``index``; parameters without a label are
        labeled with that index. A parameter with an already existing label replaces the
        existing one.

        Parameters
        ----------
        parameter : Parameter | list[Parameter]
            The parameter to add.

        Raises
        ------
        TypeError
            If ``parameter`` or any item of it is not an instance of :class:`Parameter`.
        """
        if not isinstance(parameter, list):
            parameter = [parameter]
        # Validate everything first so a bad item does not leave the group half-updated.
        if any(not isinstance(p, Parameter) for p in parameter):
            raise TypeError("Parameter must be instance of glotaran.parameter.Parameter")
        for p in parameter:
            p.index = len(self._parameters) + 1
            if p.label is None:
                p.label = f"{p.index}"
            p.full_label = f"{self.label}.{p.label}" if self.label else p.label
            self._parameters[p.label] = p

    def add_group(self, group: ParameterGroup):
        """Add a :class:`ParameterGroup` to the group.

        Parameters
        ----------
        group : ParameterGroup
            The group to add; it is stored under its own label.

        Raises
        ------
        TypeError
            Raised if the group is not an instance of :class:`ParameterGroup`.
        """
        if not isinstance(group, ParameterGroup):
            raise TypeError("Group must be glotaran.parameter.ParameterGroup")
        self[group.label] = group

    def get_nr_roots(self) -> int:
        """Return the number of ancestor groups (the nesting depth) of the group.

        Returns
        -------
        int
            The number of roots, ``0`` for a top-level group.
        """
        n = 0
        root = self.root_group
        while root is not None:
            n += 1
            root = root.root_group
        return n

    def groups(self) -> Generator[ParameterGroup, None, None]:
        """Return a generator over all groups and their subgroups.

        The groups are yielded depth-first, each group before its own subgroups.
        The group itself is not included.

        Yields
        ------
        ParameterGroup
            A subgroup of :class:`ParameterGroup`.
        """
        # Iterate the values: iterating the dict itself would yield the labels.
        for group in self.values():
            yield group
            yield from group.groups()

    def has(self, label: str) -> bool:
        """Check if a parameter with the given label is in the group or in a subgroup.

        Parameters
        ----------
        label : str
            The label of the parameter, with its path in a :class:`ParameterGroup` prepended.

        Returns
        -------
        bool
            Whether a parameter with the given label exists in the group.
        """
        try:
            self.get(label)
        except ParameterNotFoundException:
            return False
        return True

    def get(self, label: str) -> Parameter:  # type:ignore[override]
        """Get a :class:`Parameter` by its label.

        Parameters
        ----------
        label : str
            The label of the parameter, with its path in a :class:`ParameterGroup` prepended.

        Returns
        -------
        Parameter
            The parameter.

        Raises
        ------
        ParameterNotFoundException
            Raised if no parameter with the given label exists.
        """
        # sometimes the spec parser delivers the labels as int
        label = str(label)
        path = label.split(".")
        label = path.pop()

        group = self
        for element in path:
            try:
                group = group[element]
            except KeyError:
                raise ParameterNotFoundException(path, label) from None
        try:
            return group._parameters[label]
        except KeyError:
            raise ParameterNotFoundException(path, label) from None

    def copy(self) -> ParameterGroup:
        """Create a copy of the :class:`ParameterGroup`.

        Parameters are copied with :func:`copy.copy` and subgroups are copied recursively.
        The copied subgroups are attached to the copied parent, not to the original groups.
        The copy itself keeps the ``root_group`` of the original.

        Returns
        -------
        ParameterGroup :
            A copy of the :class:`ParameterGroup`.
        """
        return self._copy_with_root_group(self.root_group)

    def _copy_with_root_group(self, root_group: ParameterGroup | None) -> ParameterGroup:
        """Recursively copy this group and attach the copy to ``root_group``."""
        group_copy = type(self)(label=self.label, root_group=root_group)
        group_copy.source_path = self.source_path
        for label, parameter in self._parameters.items():
            group_copy._parameters[label] = copy(parameter)
        for label, group in self.items():
            group_copy[label] = group._copy_with_root_group(group_copy)
        return group_copy

    def all(
        self, root: str | None = None, separator: str = "."
    ) -> Generator[tuple[str, Parameter], None, None]:
        """Iterate over all parameters in the group and its subgroups together with their labels.

        Parameters
        ----------
        root : str | None
            The label prefix of the enclosing groups. ``None`` (the default) means the
            labels are relative to this group.
        separator : str
            The separator for the parameter labels.

        Yields
        ------
        tuple[str, Parameter]
            A tuple containing the full label of the parameter and the parameter itself.
        """
        root = f"{root}{self.label}{separator}" if root is not None else ""
        for label, parameter in self._parameters.items():
            yield (f"{root}{label}", parameter)
        for group in self.values():
            yield from group.all(root=root, separator=separator)

    def get_label_value_and_bounds_arrays(
        self, exclude_non_vary: bool = False
    ) -> tuple[list[str], np.ndarray, np.ndarray, np.ndarray]:
        """Return arrays of all parameter labels, values and bounds.

        Expressions are re-evaluated first so that expression parameters are up to date.

        Parameters
        ----------
        exclude_non_vary: bool
            If true, parameters with `vary=False` are excluded.

        Returns
        -------
        tuple[list[str], np.ndarray, np.ndarray, np.ndarray]
            A tuple containing a list of parameter labels and
            an array of the values, lower and upper bounds.
        """
        self.update_parameter_expression()

        labels = []
        values = []
        lower_bounds = []
        upper_bounds = []

        for label, parameter in self.all():
            if not exclude_non_vary or parameter.vary:
                labels.append(label)
                value, minimum, maximum = parameter.get_value_and_bounds_for_optimization()
                values.append(value)
                lower_bounds.append(minimum)
                upper_bounds.append(maximum)

        return labels, np.asarray(values), np.asarray(lower_bounds), np.asarray(upper_bounds)

    def set_from_label_and_value_arrays(self, labels: list[str], values: np.ndarray):
        """Update the parameter values from a list of labels and values.

        Expressions are re-evaluated after the values have been set.

        Parameters
        ----------
        labels : list[str]
            A list of parameter labels.
        values : np.ndarray
            An array of parameter values.

        Raises
        ------
        ValueError
            Raised if the size of the labels does not match the size of values.
        """
        if len(labels) != len(values):
            raise ValueError(
                f"Length of labels({len(labels)}) not equal to length of values({len(values)})."
            )

        for label, value in zip(labels, values):
            self.get(label).set_value_from_optimization(value)

        self.update_parameter_expression()

    def set_from_history(self, history: ParameterHistory, index: int):
        """Update the :class:`ParameterGroup` with values from a parameter history.

        Parameters
        ----------
        history : ParameterHistory
            The parameter history.
        index : int
            The history index.
        """
        self.set_from_label_and_value_arrays(
            history.parameter_labels, history.get_parameters(index)
        )

    def update_parameter_expression(self):
        """Update all parameters which have an expression.

        Expressions are evaluated with the evaluator of the top-level group, so this also
        works when called on a subgroup.

        Raises
        ------
        ValueError
            Raised if an expression evaluates to a non-numeric value.
        """
        # Only groups without a parent own an evaluator (see ``__init__``).
        top_level_group = self
        while top_level_group.root_group is not None:
            top_level_group = top_level_group.root_group
        evaluator = top_level_group._evaluator

        for label, parameter in self.all():
            if parameter.expression is not None:
                value = evaluator(parameter.transformed_expression)
                # numpy scalars are accepted as well; anything else (e.g. ``None``) is an error.
                if not isinstance(value, (int, float, np.integer, np.floating)):
                    raise ValueError(
                        f"Expression '{parameter.expression}' of parameter '{label}' evaluates to "
                        f"non numeric value '{value}'."
                    )
                parameter.value = value

    def markdown(self, float_format: str = ".3e") -> MarkdownStr:
        """Format the :class:`ParameterGroup` as markdown string.

        This is done by recursing the nested :class:`ParameterGroup` tree.

        Parameters
        ----------
        float_format: str
            Format string for floating point numbers, by default ".3e"

        Returns
        -------
        MarkdownStr :
            The markdown representation as string.
        """
        node_indentation = "  " * self.get_nr_roots()
        return_string = ""
        table_header = [
            "_Label_",
            "_Value_",
            "_Standard Error_",
            "_Minimum_",
            "_Maximum_",
            "_Vary_",
            "_Non-Negative_",
            "_Expression_",
        ]
        if self.label is not None:
            return_string += f"{node_indentation}* __{self.label}__:\n"
        if len(self._parameters):
            parameter_rows = [
                [
                    parameter.label,
                    parameter.value,
                    parameter.standard_error,
                    parameter.minimum,
                    parameter.maximum,
                    parameter.vary,
                    parameter.non_negative,
                    f"`{parameter.expression}`",
                ]
                for _, parameter in self._parameters.items()
            ]
            parameter_table = indent(
                tabulate(
                    parameter_rows,
                    headers=table_header,
                    tablefmt="github",
                    missingval="None",
                    floatfmt=float_format,
                ),
                f"  {node_indentation}",
            )
            return_string += f"\n{parameter_table}\n\n"
        for _, child_group in sorted(self.items()):
            return_string += f"{child_group.markdown(float_format=float_format)}"
        return MarkdownStr(return_string)

    def _repr_markdown_(self) -> str:
        """Create a markdown representation.

        Special method used by ``ipython`` to render markdown.

        Returns
        -------
        str :
            The markdown representation as string.
        """
        return str(self.markdown())

    def __repr__(self) -> str:
        """Representation used by repl and tracebacks.

        Returns
        -------
        str :
            A string representation of the :class:`ParameterGroup`.
        """
        parameter_short_notations = [
            [str(parameter.label), parameter.value] for parameter in self._parameters.values()
        ]
        if self.label is None:
            if len(self._parameters) == 0:
                return f"{type(self).__name__}.from_dict({super().__repr__()})"
            else:
                return f"{type(self).__name__}.from_list({parameter_short_notations})"
        if len(self._parameters):
            return parameter_short_notations.__repr__()
        else:
            return super().__repr__()

    def __str__(self) -> str:
        """Representation used by print and str."""
        return str(self.markdown())