from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
from typing import TypeVar
import numpy as np
import xarray as xr
from glotaran.analysis.nnls import residual_nnls
from glotaran.analysis.optimization_group_calculator import OptimizationGroupCalculator
from glotaran.analysis.optimization_group_calculator_linked import (
OptimizationGroupCalculatorLinked,
)
from glotaran.analysis.optimization_group_calculator_unlinked import (
OptimizationGroupCalculatorUnlinked,
)
from glotaran.analysis.util import get_min_max_from_interval
from glotaran.analysis.variable_projection import residual_variable_projection
from glotaran.io.prepare_dataset import add_svd_to_dataset
from glotaran.model import DatasetGroup
from glotaran.model import DatasetModel
from glotaran.model import Model
from glotaran.parameter import ParameterGroup
from glotaran.parameter import ParameterHistory
from glotaran.project import Scheme
if TYPE_CHECKING:
from typing import Hashable
class InitialParameterError(ValueError):
    """Signal that no usable parameter set is available to build a result from.

    ``OptimizationGroup.create_result_data`` raises this when the optimization
    was not successful and the parameter history holds no earlier record to
    fall back on.
    """

    def __init__(self):
        super().__init__("Initial parameters can not be evaluated.")
class ParameterNotInitializedError(ValueError):
    """Signal that a scheme was used before its parameters were set.

    ``OptimizationGroup.__init__`` raises this when ``scheme.parameters`` is ``None``.
    """

    def __init__(self):
        super().__init__("Parameter not initialized")
# Type variable constrained to the two xarray containers; used to annotate
# ``OptimizationGroup._transpose_dataset`` so that it returns the same container type it receives.
XrDataContainer = TypeVar("XrDataContainer", xr.DataArray, xr.Dataset)
# Lookup table from the ``residual_function`` name configured on a dataset group model
# to the function implementing it. ``OptimizationGroup.__init__`` rejects any other name.
residual_functions = {
"variable_projection": residual_variable_projection,
"non_negative_least_squares": residual_nnls,
}
class OptimizationGroup:
    """Bundle a model, parameters and the datasets of one dataset group for optimization.

    The heavy lifting (matrices, conditionally linear parameters, residuals and
    penalties) is delegated to a calculator object that is created with a
    reference to this group. The corresponding properties are cached in private
    attributes which start out as ``None`` and are filled on first access.
    """

    def __init__(
        self,
        scheme: Scheme,
        dataset_group: DatasetGroup,
    ):
        """Create an OptimizationGroup from a scheme and one of its dataset groups.

        Args:
            scheme (Scheme): An instance of :class:`glotaran.project.Scheme`
                which defines your model, parameters, and data.
            dataset_group (DatasetGroup): The group whose dataset models are
                optimized together. Its model supplies the residual function
                name and the ``link_clp`` setting.

        Raises:
            ParameterNotInitializedError: If ``scheme.parameters`` is ``None``.
            ValueError: If the residual function name configured on the dataset
                group model is not a key of ``residual_functions``.
        """
        self._model = scheme.model
        if scheme.parameters is None:
            raise ParameterNotInitializedError
        # Work on a copy so that the scheme's own parameters are left untouched.
        self._parameters = scheme.parameters.copy()
        self._dataset_group_model = dataset_group.model
        self._clp_link_tolerance = scheme.clp_link_tolerance

        try:
            self._residual_function = residual_functions[dataset_group.model.residual_function]
        except KeyError as error:
            raise ValueError(
                f"Unknown residual function '{dataset_group.model.residual_function}', "
                f"allowed functions are: {list(residual_functions.keys())}."
            ) from error
        self._dataset_models = dataset_group.dataset_models

        self._overwrite_index_dependent = self.model.need_index_dependent()
        self._model.validate(raise_exception=True)

        # Fills ``self._data`` and rebuilds ``self._dataset_models`` for the group's labels.
        self._prepare_data(scheme, list(dataset_group.dataset_models.keys()))
        self._dataset_labels = list(self.data.keys())

        # ``None`` means "decide automatically": ask the model whether the data can be grouped.
        link_clp = dataset_group.model.link_clp
        if link_clp is None:
            link_clp = self.model.is_groupable(self.parameters, self.data)

        self._calculator: OptimizationGroupCalculator = (
            OptimizationGroupCalculatorLinked(self)
            if link_clp
            else OptimizationGroupCalculatorUnlinked(self)
        )

        # Everything assigned above is always set. The result caches below start
        # as ``None`` and are filled lazily through the matching properties.
        self._matrices = None
        self._reduced_matrices = None
        self._reduced_clps = None
        self._clps = None
        self._weighted_residuals = None
        self._residuals = None
        self._additional_penalty = None
        self._full_penalty = None

    @property
    def model(self) -> Model:
        """Model used by this optimization group, taken from ``scheme.model``.

        Returns:
            Model: The model instance passed in through the scheme.
        """
        return self._model

    @property
    def data(self) -> dict[str, xr.Dataset]:
        """Prepared datasets of this group, keyed by dataset label."""
        return self._data

    @property
    def parameters(self) -> ParameterGroup:
        """Parameters currently used to fill the dataset models."""
        return self._parameters

    @parameters.setter
    def parameters(self, parameters: ParameterGroup):
        # New parameters invalidate the filled dataset models and every cached result.
        self._parameters = parameters
        self.reset()

    @property
    def dataset_models(self) -> dict[str, DatasetModel]:
        """Filled dataset models of this group, keyed by dataset label."""
        return self._dataset_models

    # NOTE(review): the lazy properties below rely on the calculator writing its
    # results back into this group's private attributes (it is constructed with
    # ``self``) -- confirm against the calculator implementations.

    @property
    def matrices(
        self,
    ) -> dict[str, np.ndarray | list[np.ndarray]]:
        """Matrices; triggers ``calculate_matrices`` on the calculator if not cached."""
        if self._matrices is None:
            self._calculator.calculate_matrices()
        return self._matrices

    @property
    def reduced_matrices(
        self,
    ) -> dict[str, np.ndarray] | dict[str, list[np.ndarray]] | list[np.ndarray]:
        """Reduced matrices; triggers ``calculate_matrices`` on the calculator if not cached."""
        if self._reduced_matrices is None:
            self._calculator.calculate_matrices()
        return self._reduced_matrices

    @property
    def reduced_clps(
        self,
    ) -> dict[str, list[np.ndarray]]:
        """Reduced clps; triggers ``calculate_residual`` on the calculator if not cached."""
        if self._reduced_clps is None:
            self._calculator.calculate_residual()
        return self._reduced_clps

    @property
    def clps(
        self,
    ) -> dict[str, list[np.ndarray]]:
        """Clps; triggers ``calculate_residual`` on the calculator if not cached."""
        if self._clps is None:
            self._calculator.calculate_residual()
        return self._clps

    @property
    def weighted_residuals(
        self,
    ) -> dict[str, list[np.ndarray]]:
        """Weighted residuals; triggers ``calculate_residual`` if not cached."""
        if self._weighted_residuals is None:
            self._calculator.calculate_residual()
        return self._weighted_residuals

    @property
    def residuals(
        self,
    ) -> dict[str, list[np.ndarray]]:
        """Residuals; triggers ``calculate_residual`` on the calculator if not cached."""
        if self._residuals is None:
            self._calculator.calculate_residual()
        return self._residuals

    @property
    def additional_penalty(
        self,
    ) -> dict[str, list[float]]:
        """Additional penalty; triggers ``calculate_residual`` if not cached."""
        if self._additional_penalty is None:
            self._calculator.calculate_residual()
        return self._additional_penalty

    @property
    def full_penalty(self) -> np.ndarray:
        """Full penalty vector; triggers ``calculate_full_penalty`` if not cached."""
        if self._full_penalty is None:
            self._calculator.calculate_full_penalty()
        return self._full_penalty

    @property
    def cost(self) -> float:
        """Half the squared Euclidean norm of :attr:`full_penalty`."""
        return 0.5 * np.dot(self.full_penalty, self.full_penalty)

    def reset(self):
        """Resets all results and `DatasetModels`. Use after updating parameters."""
        # Refill only the dataset models that belong to this group.
        self._dataset_models = {
            label: dataset_model.fill(self._model, self._parameters).set_data(self.data[label])
            for label, dataset_model in self.model.dataset.items()
            if label in self._dataset_labels
        }
        if self._overwrite_index_dependent:
            for d in self._dataset_models.values():
                d.overwrite_index_dependent(self._overwrite_index_dependent)
        self._reset_results()

    def _reset_results(self):
        """Drop every cached result so that it is recalculated on next access."""
        self._matrices = None
        self._reduced_matrices = None
        self._reduced_clps = None
        self._clps = None
        self._weighted_residuals = None
        self._residuals = None
        self._additional_penalty = None
        self._full_penalty = None

    def _prepare_data(self, scheme: Scheme, labels: list[str]):
        """Fill ``self._data`` and ``self._dataset_models`` for the given labels.

        Parameters
        ----------
        scheme : Scheme
            Scheme providing the raw data (``scheme.data``) and the ``add_svd`` flag.
        labels : list[str]
            Labels of the datasets to prepare; other entries of ``scheme.data`` are skipped.
        """
        self._data = {}
        self._dataset_models = {}
        for label, dataset in scheme.data.items():
            if label not in labels:
                continue
            # Bare data arrays are wrapped into a dataset with a ``data`` variable.
            if isinstance(dataset, xr.DataArray):
                dataset = dataset.to_dataset(name="data")

            dataset_model = self._model.dataset[label]
            dataset_model = dataset_model.fill(self.model, self.parameters)
            dataset_model.set_data(dataset)
            if self._overwrite_index_dependent:
                dataset_model.overwrite_index_dependent(self._overwrite_index_dependent)
            self._dataset_models[label] = dataset_model
            global_dimension = dataset_model.get_global_dimension()
            model_dimension = dataset_model.get_model_dimension()

            # Model dimension first, global dimension second.
            dataset = self._transpose_dataset(
                dataset, ordered_dims=[model_dimension, global_dimension]
            )

            if scheme.add_svd:
                add_svd_to_dataset(dataset, lsv_dim=model_dimension, rsv_dim=global_dimension)

            self._add_weight(label, dataset)
            self._data[label] = dataset

    def _transpose_dataset(
        self, datacontainer: XrDataContainer, ordered_dims: list[Hashable]
    ) -> XrDataContainer:
        """Reorder dimension of the datacontainer with the order provided by ``ordered_dims``.

        Entries of ``ordered_dims`` which are not dimensions of ``datacontainer``
        are ignored; dimensions not mentioned in ``ordered_dims`` are appended
        in their existing order.

        Parameters
        ----------
        datacontainer: XrDataContainer
            Data array or dataset to be reordered.
        ordered_dims: list[Hashable]
            Order the dimensions should be in.

        Returns
        -------
        XrDataContainer
            Datacontainer with reordered dimensions.
        """
        leading_dims = [dim for dim in ordered_dims if dim in datacontainer.dims]
        trailing_dims = [dim for dim in datacontainer.dims if dim not in leading_dims]
        return datacontainer.transpose(*leading_dims, *trailing_dims)

    def _add_weight(self, label: str, dataset: xr.Dataset):
        """Add the weights defined in the model to ``dataset`` in place.

        Parameters
        ----------
        label : str
            Label of the dataset, used to select the applicable model weights.
        dataset : xr.Dataset
            Dataset which gets (or already has) a ``weight`` variable.
        """
        # if the user supplies a weight we ignore modeled weights
        if "weight" in dataset:
            if any(label in weight.datasets for weight in self.model.weights):
                warnings.warn(
                    f"Ignoring model weight for dataset '{label}'"
                    " because weight is already supplied by dataset."
                )
            return
        dataset_model = self.dataset_models[label]
        dataset_model.set_data(dataset)
        global_dimension = dataset_model.get_global_dimension()
        model_dimension = dataset_model.get_model_dimension()

        global_axis = dataset.coords[global_dimension]
        model_axis = dataset.coords[model_dimension]

        for weight in self.model.weights:
            if label in weight.datasets:
                # Start from a weight of one everywhere, created on first use.
                if "weight" not in dataset:
                    dataset["weight"] = xr.DataArray(
                        np.ones_like(dataset.data), coords=dataset.data.coords
                    )

                # Restrict the weight to the configured intervals; a missing
                # interval leaves that dimension unrestricted.
                idx = {}
                if weight.global_interval is not None:
                    idx[global_dimension] = get_min_max_from_interval(
                        weight.global_interval, global_axis
                    )
                if weight.model_interval is not None:
                    idx[model_dimension] = get_min_max_from_interval(
                        weight.model_interval, model_axis
                    )
                dataset.weight[idx] *= weight.value

    def create_result_data(
        self,
        parameter_history: ParameterHistory | None = None,
        copy: bool = True,
        success: bool = True,
        add_svd: bool = True,
    ) -> dict[str, xr.Dataset]:
        """Create the result datasets for all dataset models of this group.

        Parameters
        ----------
        parameter_history : ParameterHistory | None
            History used to restore the second to last parameter record when
            ``success`` is ``False``.
        copy : bool
            Whether each result is built on a copy of the prepared dataset.
        success : bool
            Whether the optimization succeeded.
        add_svd : bool
            Whether the SVD of the residuals is added to each result dataset.

        Returns
        -------
        dict[str, xr.Dataset]
            Result datasets keyed by dataset label.

        Raises
        ------
        InitialParameterError
            If ``success`` is ``False`` and ``parameter_history`` is ``None`` or
            holds at most one record.
        """
        if not success:
            if parameter_history is not None and parameter_history.number_of_records > 1:
                self.parameters.set_from_history(parameter_history, -2)
            else:
                raise InitialParameterError()

        self.reset()
        self._calculator.prepare_result_creation()
        result_data = {}
        for label, dataset_model in self.dataset_models.items():
            # Forward ``add_svd``; it used to be accepted here but silently dropped.
            result_data[label] = self.create_result_dataset(label, copy=copy, add_svd=add_svd)
            dataset_model.finalize_data(result_data[label])

        return result_data

    def create_result_dataset(
        self, label: str, copy: bool = True, add_svd: bool = True
    ) -> xr.Dataset:
        """Create the result dataset for a single dataset.

        Parameters
        ----------
        label : str
            Label of the dataset to create the result for.
        copy : bool
            Whether to work on a copy of the prepared dataset.
        add_svd : bool
            Whether to add the SVD of ``weighted_residual`` and ``residual``.

        Returns
        -------
        xr.Dataset
            Dataset extended by the calculator's results, the (weighted) root
            mean square error and ``dataset_scale`` attributes and ``fitted_data``.
        """
        dataset = self.data[label]
        dataset_model = self.dataset_models[label]
        global_dimension = dataset_model.get_global_dimension()
        model_dimension = dataset_model.get_model_dimension()
        if copy:
            dataset = dataset.copy()
        if dataset_model.is_index_dependent():
            dataset = self._calculator.create_index_dependent_result_dataset(label, dataset)
        else:
            dataset = self._calculator.create_index_independent_result_dataset(label, dataset)

        # TODO: adapt tests to handle add_svd=False
        if add_svd:
            self._create_svd("weighted_residual", dataset, model_dimension, global_dimension)
            self._create_svd("residual", dataset, model_dimension, global_dimension)

        # Calculate RMS; ``size`` is the total element count, so the divisor no
        # longer assumes exactly two dimensions.
        dataset.attrs["root_mean_square_error"] = np.sqrt(
            (dataset.residual ** 2).sum() / dataset.residual.size
        ).values
        dataset.attrs["weighted_root_mean_square_error"] = np.sqrt(
            (dataset.weighted_residual ** 2).sum() / dataset.weighted_residual.size
        ).values

        if dataset_model.scale is not None:
            dataset.attrs["dataset_scale"] = dataset_model.scale.value
        else:
            dataset.attrs["dataset_scale"] = 1

        # reconstruct fitted data
        dataset["fitted_data"] = dataset.data - dataset.residual
        return dataset

    def _create_svd(self, name: str, dataset: xr.Dataset, lsv_dim: str, rsv_dim: str):
        """Calculate the SVD of a data matrix in the dataset and add it to the dataset.

        Parameters
        ----------
        name : str
            Name of the data matrix.
        dataset : xr.Dataset
            Dataset containing the data, which will be updated with the SVD values.
        lsv_dim : str
            Dimension passed on as ``lsv_dim``; the data matrix is transposed
            so that it comes first.
        rsv_dim : str
            Dimension passed on as ``rsv_dim``; the data matrix is transposed
            so that it comes second.
        """
        data_array: xr.DataArray = self._transpose_dataset(
            dataset[name],
            ordered_dims=[lsv_dim, rsv_dim],
        )

        add_svd_to_dataset(
            dataset, name=name, lsv_dim=lsv_dim, rsv_dim=rsv_dim, data_array=data_array
        )