"""The parameter group class"""
from __future__ import annotations
import pathlib
from copy import copy
from typing import Callable
from typing import Generator
import asteval
import numpy as np
import pandas as pd
import yaml
from .parameter import Parameter
class ParameterNotFoundException(Exception):
    """Raised when a Parameter is not found in the Group.

    Parameters
    ----------
    path :
        The labels of the groups leading to the parameter. May be empty when
        the parameter was looked up directly in a group.
    label :
        The label of the parameter which could not be found.
    """

    def __init__(self, path, label):
        # Join path and label in one go: with an empty path the old
        # "<path>.<label>" formatting produced a stray leading dot (".1").
        full_label = ".".join([*path, label])
        super().__init__(f"Cannot find parameter {full_label}")
[docs]class ParameterGroup(dict):
def __init__(self, label: str = None, root_group: ParameterGroup = None):
    """Represents a group of parameters. Can contain other groups, creating a
    tree-like hierarchy.

    Parameters
    ----------
    label :
        The label of the group.
    root_group :
        The root group. When it is ``None`` the group creates its own
        expression evaluator.

    Raises
    ------
    ValueError
        If a label is given and ``Parameter.valid_label`` rejects it.
    """
    label_is_valid = label is None or Parameter.valid_label(label)
    if not label_is_valid:
        raise ValueError(f"'{label}' is not a valid group label.")

    self._label = label
    self._parameters = {}
    self._root_group = root_group

    # Only groups without a root group own an evaluator; the group itself is
    # exposed to expressions under the symbol name ``group``.
    if root_group is None:
        symbol_table = asteval.make_symbol_table(group=self)
        self._evaluator = asteval.Interpreter(symtable=symbol_table)
    else:
        self._evaluator = None

    super().__init__()
@classmethod
def from_dict(
    cls,
    parameter_dict: dict[str, dict | list],
    label: str = None,
    root_group: ParameterGroup = None,
) -> ParameterGroup:
    """Creates a :class:`ParameterGroup` from a dictionary.

    Every dictionary value becomes a nested group (via :meth:`from_dict`) and
    every list value becomes a group of parameters (via :meth:`from_list`).
    Values of any other type are ignored.

    Parameters
    ----------
    parameter_dict :
        A parameter dictionary containing parameters.
    label :
        The label of root group.
    root_group:
        The root group
    """
    group = cls(label=label, root_group=root_group)

    for key, entry in parameter_dict.items():
        # keys may not be strings (e.g. integers), group labels always are
        child_label = str(key)
        if isinstance(entry, dict):
            child = cls.from_dict(entry, label=child_label, root_group=group)
        elif isinstance(entry, list):
            child = cls.from_list(entry, label=child_label, root_group=group)
        else:
            continue
        group.add_group(child)

    # expressions are only evaluated once, by the group without a root
    if root_group is None:
        group.update_parameter_expression()
    return group
@classmethod
def from_list(
    cls,
    parameter_list: list[float | list],
    label: str = None,
    root_group: ParameterGroup = None,
) -> ParameterGroup:
    """Creates a :class:`ParameterGroup` from a list.

    The first dictionary found in the list is passed on as default options
    for all parameters. Parameters are labelled with their one-based
    position in the list.

    Parameters
    ----------
    parameter_list :
        A parameter list containing parameters
    label :
        The label of the root group.
    root_group:
        The root group
    """
    group = cls(label=label, root_group=root_group)

    # get defaults
    defaults = next((entry for entry in parameter_list if isinstance(entry, dict)), None)

    for position, entry in enumerate(parameter_list, start=1):
        if isinstance(entry, (str, int, float)):
            # scalars are coerced to float when possible, otherwise kept as is
            try:
                entry = float(entry)
            except Exception:
                pass
        if not isinstance(entry, (float, int, list)):
            # e.g. the defaults dictionary or a non-numeric string
            continue
        parameter = Parameter.from_list_or_value(
            entry, label=str(position), default_options=defaults
        )
        group.add_parameter(parameter)

    if root_group is None:
        group.update_parameter_expression()
    return group
@classmethod
def from_file(cls, filepath: str, fmt: str = None):
    """Creates a :class:`ParameterGroup` from a file.

    The loader is looked up in ``cls.known_formats()``, which is used here as
    a mapping from format name to a callable taking the file path.

    Parameters
    ----------
    filepath :
        The path to the parameter file.
    fmt :
        The format of the file. If ``None`` the file extension (without the
        leading dot) is used, falling back to ``"yml"`` if there is none.

    Raises
    ------
    ValueError
        If the format is not one of the known formats.
    """
    if fmt is None:
        suffix = pathlib.Path(filepath).suffix
        fmt = suffix[1:] if suffix != "" else "yml"

    known_formats = cls.known_formats()
    if fmt not in known_formats:
        # Report the requested format; the previous message interpolated the
        # builtin ``format`` function instead of ``fmt``.
        raise ValueError(
            f"Unknown parameter format '{fmt}'. "
            f"Valid Formats are {list(known_formats)}."
        )
    return known_formats[fmt](filepath)
@classmethod
def from_yaml_file(cls, filepath: str) -> ParameterGroup:
    """Creates a :class:`ParameterGroup` from a YAML file.

    The open file is handed to :meth:`from_yaml` for parsing.

    Parameters
    ----------
    filepath :
        The path to the YAML file.
    """
    with open(filepath) as yaml_file:
        parameter_group = cls.from_yaml(yaml_file)
    return parameter_group
@classmethod
def from_yaml(cls, yaml_string: str) -> ParameterGroup:
    """Creates a :class:`ParameterGroup` from a YAML string.

    A top-level YAML sequence is loaded with :meth:`from_list`, anything else
    with :meth:`from_dict`. An empty document yields an empty group.

    Parameters
    ----------
    yaml_string :
        The YAML string with the parameters. Anything accepted by
        ``yaml.safe_load`` works; :meth:`from_yaml_file` passes an open file.
    """
    items = yaml.safe_load(yaml_string)
    if items is None:
        # An empty document is parsed to ``None``; treat it as "no
        # parameters" instead of failing inside ``from_dict``.
        items = {}
    if isinstance(items, list):
        return cls.from_list(items)
    return cls.from_dict(items)
@classmethod
def from_dataframe(cls, df: pd.DataFrame, source: str = "DataFrame") -> ParameterGroup:
    """Creates a :class:`ParameterGroup` from a :class:`pandas.DataFrame`

    The columns ``label`` and ``value`` are required. The columns
    ``minimum``, ``maximum``, ``non-negative``, ``vary`` and ``expression``
    are optional. Dots in a label separate nested group labels from the
    parameter label; missing groups are created on the fly.

    Parameters
    ----------
    df :
        The data frame holding one parameter per row.
    source :
        A name for the origin of the data, only used in error messages.

    Raises
    ------
    ValueError
        If a required column is missing, a numeric column holds non numeric
        values or a boolean column holds non boolean values.
    """
    for column_name in ["label", "value"]:
        if column_name not in df:
            raise ValueError(f"Missing column '{column_name}' in '{source}'")

    for column_name in ["minimum", "maximum", "value"]:
        if column_name in df and any(not np.isreal(v) for v in df[column_name]):
            raise ValueError(f"Column '{column_name}' in '{source}' has non numeric values")

    for column_name in ["non-negative", "vary"]:
        if column_name in df and any(not isinstance(v, bool) for v in df[column_name]):
            raise ValueError(f"Column '{column_name}' in '{source}' has non boolean values")

    def cell(column_name, row, default):
        # Rows are addressed by position: ``df[column][row]`` is label based
        # and breaks for frames without the default 0..n-1 index.
        return df[column_name].iloc[row] if column_name in df else default

    root = cls()

    for row, raw_label in enumerate(df["label"]):
        # purely numeric labels (e.g. read from a CSV file) are not strings
        full_label = str(raw_label)
        *group_labels, label = full_label.split(".")

        group = root
        for group_label in group_labels:
            if group_label not in group:
                # use ``cls`` so subclasses build trees of their own type
                group.add_group(cls(label=group_label, root_group=group))
            group = group[group_label]

        expression = cell("expression", row, None)
        if not isinstance(expression, str):
            # missing expressions show up as NaN/None
            expression = None

        parameter = Parameter(
            label=label,
            full_label=full_label,
            value=cell("value", row, None),
            expression=expression,
            maximum=cell("maximum", row, np.inf),
            minimum=cell("minimum", row, -np.inf),
            non_negative=cell("non-negative", row, False),
            vary=cell("vary", row, True),
        )
        group.add_parameter(parameter)

    root.update_parameter_expression()
    return root
@classmethod
def from_csv(cls, filepath: str, delimiter: str = None) -> ParameterGroup:
    """Creates a :class:`ParameterGroup` from a CSV file.

    The file is read with :func:`pandas.read_csv` and converted with
    :meth:`from_dataframe`. The strings ``None`` and ``none`` are read as
    missing values and whitespace following a delimiter is skipped.

    Parameters
    ----------
    filepath :
        The path to the CSV file.
    delimiter :
        The delimiter of the CSV file.
    """
    read_options = {
        "delimiter": delimiter,
        "skipinitialspace": True,
        "na_values": ["None", "none"],
    }
    dataframe = pd.read_csv(filepath, **read_options)
    return cls.from_dataframe(dataframe, source=filepath)
@property
def label(self) -> str:
    """The label this group was created with (read-only)."""
    return self._label
@property
def root_group(self) -> ParameterGroup:
    """The root group this group was created with (read-only), may be ``None``."""
    return self._root_group
def to_dataframe(self) -> pd.DataFrame:
    """Creates a :class:`pandas.DataFrame` with one row per parameter.

    The rows follow the order of :meth:`all`. The columns are ``label``,
    ``value``, ``minimum``, ``maximum``, ``vary``, ``non-negative`` and
    ``expression``.
    """
    # column name -> name of the parameter attribute stored in that column
    attribute_of_column = {
        "value": "value",
        "minimum": "minimum",
        "maximum": "maximum",
        "vary": "vary",
        "non-negative": "non_negative",
        "expression": "expression",
    }
    entries = list(self.all())

    parameter_dict = {"label": [full_label for full_label, _ in entries]}
    for column_name, attribute_name in attribute_of_column.items():
        parameter_dict[column_name] = [
            getattr(parameter, attribute_name) for _, parameter in entries
        ]
    return pd.DataFrame(parameter_dict)
def to_csv(self, filename: str, delimiter: str = ","):
    """Writes a :class:`ParameterGroup` to a CSV file.

    The group is converted with :meth:`to_dataframe` and written without the
    index column; missing values are written as ``None``.

    Parameters
    ----------
    filename :
        The path to the CSV file.
    delimiter : str
        The delimiter of the CSV file.
    """
    dataframe = self.to_dataframe()
    dataframe.to_csv(filename, sep=delimiter, na_rep="None", index=False)
def add_parameter(self, parameter: Parameter | list[Parameter]):
    """Adds a :class:`Parameter` to the group.

    Each added parameter gets its one-based position in the group as index,
    an unlabelled parameter is labelled with that index, and its full label
    is prefixed with the label of this group (if the group has one).

    Parameters
    ----------
    parameter :
        The parameter to add. A list of parameters is added one by one.

    Raises
    ------
    TypeError
        If any of the given objects is not a :class:`Parameter`.
    """
    parameters = parameter if isinstance(parameter, list) else [parameter]
    if any(not isinstance(candidate, Parameter) for candidate in parameters):
        raise TypeError("Parameter must be instance of glotaran.model.Parameter")

    for new_parameter in parameters:
        new_parameter.index = len(self._parameters) + 1
        if new_parameter.label is None:
            new_parameter.label = f"{new_parameter.index}"
        if self.label:
            new_parameter.full_label = f"{self.label}.{new_parameter.label}"
        else:
            new_parameter.full_label = new_parameter.label
        self._parameters[new_parameter.label] = new_parameter
[docs] def add_group(self, group: ParameterGroup):
"""Adds a :class:`ParameterGroup` to the group.
Parameters
----------
group :
The group to add.
"""
if not isinstance(group, ParameterGroup):
raise TypeError("Group must be glotaran.model.ParameterGroup")
self[group.label] = group
def get_nr_roots(self) -> int:
    """Returns the number of roots of the group.

    This is the length of the ``root_group`` chain starting at this group,
    i.e. ``0`` for a group without a root group.
    """
    depth = 0
    ancestor = self.root_group
    while ancestor is not None:
        ancestor = ancestor.root_group
        depth += 1
    return depth
def groups(self) -> Generator[ParameterGroup, None, None]:
    """Returns a generator over all groups and their subgroups.

    Each directly contained group is yielded first, followed by all groups
    nested inside it.
    """
    # Iterate the stored groups, not the dict keys: the keys are label
    # strings which have no ``groups`` method.
    for group in self.values():
        yield group
        yield from group.groups()
def has(self, label: str) -> bool:
    """Checks if a parameter with the given label is in the group or in a subgroup.

    Any exception raised by :meth:`get` counts as "not found".

    Parameters
    ----------
    label :
        The label of the parameter.
    """
    try:
        self.get(label)
    except Exception:
        return False
    return True
def get(self, label: str) -> Parameter:
    """Gets a :class:`Parameter` by its label.

    Dots in the label separate the labels of nested groups from the label of
    the parameter, e.g. ``"group.subgroup.parameter"``.

    Parameters
    ----------
    label :
        The label of the parameter.

    Raises
    ------
    ParameterNotFoundException
        If a group on the path or the parameter itself does not exist.
    """
    # sometimes the spec parser delivers the labels as int
    *path, label = str(label).split(".")

    # TODO: audit this code
    group = self
    try:
        # walk down the groups, then look the parameter up in the last one
        for element in path:
            group = group[element]
        return group._parameters[label]
    except KeyError:
        raise ParameterNotFoundException(path, label)
def copy(self) -> ParameterGroup:
    """Returns a copy of the group.

    Parameters are copied with :func:`copy.copy` and subgroups are copied
    recursively. The copy keeps the label and the root group of the original.
    """
    root = ParameterGroup(label=self.label, root_group=self.root_group)

    for label, parameter in self._parameters.items():
        root._parameters[label] = copy(parameter)

    for label, group in self.items():
        subgroup = group.copy()
        # A subgroup rooted in this group must be rooted in the copy;
        # otherwise the copied tree keeps pointing into the original one.
        if group.root_group is self:
            subgroup._root_group = root
        root[label] = subgroup
    return root
def all(
    self, root: str = None, separator: str = "."
) -> Generator[tuple[str, Parameter], None, None]:
    """Returns a generator over all parameters in the group and its subgroups together
    with their labels.

    The parameters of the group itself come first, followed by those of each
    subgroup. When ``root`` is ``None`` the labels of the group's own
    parameters carry no prefix.

    Parameters
    ----------
    root :
        The label of the root group
    separator:
        The separator for the parameter labels.
    """
    prefix = "" if root is None else f"{root}{self.label}{separator}"

    for label, parameter in self._parameters.items():
        yield f"{prefix}{label}", parameter

    for subgroup in self.values():
        yield from subgroup.all(root=prefix, separator=separator)
def get_label_value_and_bounds_arrays(
    self, exclude_non_vary: bool = False
) -> tuple[list[str], np.ndarray, np.ndarray, np.ndarray]:
    """Returns the labels, values and bounds of all parameters.

    Parameter expressions are updated first. The values and bounds are the
    ones reported by ``Parameter.get_value_and_bounds_for_optimization``.

    Parameters
    ----------
    exclude_non_vary: bool = False
        If true, parameters with `vary=False` are excluded.

    Returns
    -------
    tuple
        The list of labels and the arrays of values, lower bounds and upper
        bounds, all in the order of :meth:`all`.
    """
    self.update_parameter_expression()

    labels, values, lower_bounds, upper_bounds = [], [], [], []
    for label, parameter in self.all():
        if exclude_non_vary and not parameter.vary:
            continue
        value, minimum, maximum = parameter.get_value_and_bounds_for_optimization()
        labels.append(label)
        values.append(value)
        lower_bounds.append(minimum)
        upper_bounds.append(maximum)

    return labels, np.asarray(values), np.asarray(lower_bounds), np.asarray(upper_bounds)
def set_from_label_and_value_arrays(self, labels: list[str], values: np.ndarray):
    """Updates the parameter values from a list of labels and values.

    Each value is handed to ``set_value_from_optimization`` of the parameter
    with the matching label; afterwards all expressions are updated.

    Parameters
    ----------
    labels :
        The labels of the parameters to update.
    values :
        The new values, in the same order as ``labels``.

    Raises
    ------
    ValueError
        If ``labels`` and ``values`` differ in length.
    """
    n_labels, n_values = len(labels), len(values)
    if n_labels != n_values:
        raise ValueError(
            f"Length of labels({n_labels}) not equal to length of values({n_values})."
        )

    for label, value in zip(labels, values):
        parameter = self.get(label)
        parameter.set_value_from_optimization(value)

    self.update_parameter_expression()
def update_parameter_expression(self):
    """Updates all parameters which have an expression.

    The transformed expression of each such parameter is evaluated and the
    result is stored as the parameter's value.

    Raises
    ------
    ValueError
        If an expression does not evaluate to an ``int`` or ``float``.
    """
    # Only a group without a root group owns an evaluator. Walk up the root
    # chain to find it, so the method also works when called on a subgroup
    # (e.g. through ``get_label_value_and_bounds_arrays``).
    owner = self
    while owner._evaluator is None and owner.root_group is not None:
        owner = owner.root_group
    evaluator = owner._evaluator

    for label, parameter in self.all():
        if parameter.expression is None:
            continue
        value = evaluator(parameter.transformed_expression)
        if not isinstance(value, (int, float)):
            raise ValueError(
                f"Expression '{parameter.expression}' of parameter '{label}' evaluates to "
                f"non numeric value '{value}'."
            )
        parameter.value = value
def markdown(self) -> str:
    """Formats the :class:`ParameterGroup` as markdown string.

    The group becomes a nested bullet list, indented by one space per root
    (see :meth:`get_nr_roots`). A group labelled ``"p"`` gets no header line.
    """
    indent = " " * self.get_nr_roots()

    parts = []
    if self.label != "p":
        parts.append(f"{indent}* __{self.label}__:\n")
    parts.extend(f"{indent} * {parameter}\n" for parameter in self._parameters.values())
    # subgroups render themselves, including their own indentation
    parts.extend(str(subgroup) for subgroup in self.values())
    return "".join(parts)
def __repr__(self):
return self.markdown()
def __str__(self):
return self.__repr__()