Source code for glotaran.optimization.estimation_provider

"""Module containing the estimation provider classes."""
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr

from glotaran.model import DatasetGroup
from glotaran.model import DatasetModel
from glotaran.model import EqualAreaPenalty
from glotaran.model.dataset_model import has_dataset_model_global_model
from glotaran.model.item import fill_item
from glotaran.optimization.data_provider import DataProvider
from glotaran.optimization.data_provider import DataProviderLinked
from glotaran.optimization.matrix_provider import MatrixProviderLinked
from glotaran.optimization.matrix_provider import MatrixProviderUnlinked
from glotaran.optimization.nnls import residual_nnls
from glotaran.optimization.variable_projection import residual_variable_projection

if TYPE_CHECKING:
    from glotaran.typing.types import ArrayLike
# Registry of the available residual functions, keyed by the name that a dataset
# group stores in its ``residual_function`` attribute. ``EstimationProvider`` looks
# the name up here and raises ``UnsupportedResidualFunctionError`` for unknown keys.
# NOTE: the identifier is misspelled ("RESIUDAL"); it is kept as is because other
# definitions in this module reference it under this name.
SUPPORTED_RESIUDAL_FUNCTIONS = {
    "variable_projection": residual_variable_projection,
    "non_negative_least_squares": residual_nnls,
}


class UnsupportedResidualFunctionError(ValueError):
    """Indicates that the residual function is unsupported."""

    def __init__(self, residual_function: str):
        """Initialize an UnsupportedResidualFunctionError.

        Parameters
        ----------
        residual_function : str
            The name of the unsupported residual function.
        """
        # List the registered names so the message tells the user what is valid.
        supported_names = list(SUPPORTED_RESIUDAL_FUNCTIONS.keys())
        message = (
            f"Unknown residual function '{residual_function}', "
            f"supported functions are: {supported_names}."
        )
        super().__init__(message)
class EstimationProvider:
    """Base class providing the clp estimation for the optimization of a dataset group.

    The base class holds the dataset group, the selected residual function and the
    list of additional clp penalties. ``estimate``, ``get_full_penalty`` and
    ``get_result`` raise ``NotImplementedError`` here and are provided by subclasses.
    """

    def __init__(self, dataset_group: DatasetGroup):
        """Initialize an estimation provider for a dataset group.

        Parameters
        ----------
        dataset_group : DatasetGroup
            The dataset group.

        Raises
        ------
        UnsupportedResidualFunctionError
            Raised when the residual function of the dataset group is unsupported.
        """
        self._group = dataset_group
        self._clp_penalty: list[float] = []
        try:
            self._residual_function = SUPPORTED_RESIUDAL_FUNCTIONS[dataset_group.residual_function]
        except KeyError as e:
            raise UnsupportedResidualFunctionError(dataset_group.residual_function) from e

    @property
    def group(self) -> DatasetGroup:
        """Get the dataset group.

        Returns
        -------
        DatasetGroup
            The dataset group.
        """
        return self._group

    def calculate_residual(
        self, matrix: ArrayLike, data: ArrayLike
    ) -> tuple[ArrayLike, ArrayLike]:
        """Calculate the clps and the residual for a matrix and data.

        The call is delegated to the residual function selected in ``__init__``.

        Parameters
        ----------
        matrix : ArrayLike
            The matrix.
        data : ArrayLike
            The data.

        Returns
        -------
        tuple[ArrayLike, ArrayLike]
            The estimated clp and residual.
        """
        return self._residual_function(matrix, data)

    def retrieve_clps(
        self,
        clp_labels: list[str],
        reduced_clp_labels: list[str],
        reduced_clps: ArrayLike,
        index: int,
    ) -> ArrayLike:
        """Retrieve clp from reduced clp.

        Parameters
        ----------
        clp_labels : list[str]
            The original clp labels.
        reduced_clp_labels : list[str]
            The reduced clp labels.
        reduced_clps : ArrayLike
            The reduced clps.
        index : int
            The index on the global axis.

        Returns
        -------
        ArrayLike
            The retrieved clps.
        """
        model = self.group.model
        parameters = self.group.parameters

        # Without relations and constraints the reduced clps are returned unchanged.
        if not (len(model.clp_relations) or len(model.clp_constraints)):
            return reduced_clps

        # Scatter the reduced values into a zero-initialised vector of full length.
        full_clps = np.zeros(len(clp_labels))
        for position, reduced_label in enumerate(reduced_clp_labels):
            full_clps[clp_labels.index(reduced_label)] = reduced_clps[position]

        # Derive each related target clp from its source clp.
        for relation in model.clp_relations:
            relation = fill_item(relation, model, parameters)  # type:ignore[arg-type]
            if relation.target not in clp_labels:
                continue
            if not relation.applies(index):
                continue
            if relation.source not in clp_labels:
                continue
            source_idx = clp_labels.index(relation.source)
            target_idx = clp_labels.index(relation.target)
            full_clps[target_idx] = relation.parameter * full_clps[source_idx]
        return full_clps

    def get_additional_penalties(self) -> list[float]:
        """Get the additional penalty.

        Returns
        -------
        list[float]
            The additional penalty.
        """
        return self._clp_penalty

    def calculate_clp_penalties(
        self,
        clp_labels: list[list[str]],
        clps: list[np.ndarray],
        global_axis: np.ndarray,
    ) -> list[float]:
        """Calculate the clp penalty.

        Only penalties that are instances of ``EqualAreaPenalty`` are evaluated.

        Parameters
        ----------
        clp_labels : list[list[str]]
            The clp labels.
        clps : list[ArrayLike]
            The clps.
        global_axis : ArrayLike
            The global axis.

        Returns
        -------
        list[float]
            The clp penalty.
        """
        model = self.group.model
        parameters = self.group.parameters

        penalties = []
        for penalty in model.clp_penalties:
            if not isinstance(penalty, EqualAreaPenalty):
                continue
            penalty = fill_item(penalty, model, parameters)  # type:ignore[arg-type]
            source_area = _get_area(
                penalty.source,
                clp_labels,
                clps,
                penalty.source_intervals,
                global_axis,
            )
            target_area = _get_area(
                penalty.target,
                clp_labels,
                clps,
                penalty.target_intervals,
                global_axis,
            )

            target_missing = len(target_area) == 0
            source_missing = len(source_area) == 0
            if target_missing:
                # Warn only when just one side is missing; both missing is skipped silently.
                if not source_missing:
                    warnings.warn(
                        "Ignoring equal area penalty, target clp "
                        f"{penalty.target} not present."
                    )
                continue
            if source_missing:
                warnings.warn(
                    "Ignoring equal area penalty, source clp " f"{penalty.source} not present."
                )
                continue

            area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area))
            penalties.append(area_penalty * penalty.weight)
        return penalties

    def estimate(self):
        """Calculate the estimation.

        Not implemented in the base class.

        .. # noqa: DAR401
        """
        raise NotImplementedError

    def get_full_penalty(self) -> ArrayLike:
        """Get the full penalty.

        Not implemented in the base class.

        Returns
        -------
        ArrayLike
            The clp penalty.

        .. # noqa: DAR202
        .. # noqa: DAR401
        """
        raise NotImplementedError

    def get_result(
        self,
    ) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray],]:
        """Get the results of the estimation.

        Not implemented in the base class.

        Returns
        -------
        tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]
            A tuple of the estimated clps and residuals.

        .. # noqa: DAR202
        .. # noqa: DAR401
        """
        raise NotImplementedError
class EstimationProviderUnlinked(EstimationProvider):
    """A class to provide estimation for optimization of an unlinked dataset group."""

    def __init__(
        self,
        dataset_group: DatasetGroup,
        data_provider: DataProvider,
        matrix_provider: MatrixProviderUnlinked,
    ):
        """Initialize an estimation provider for an unlinked dataset group.

        Parameters
        ----------
        dataset_group : DatasetGroup
            The dataset group.
        data_provider : DataProvider
            The data provider.
        matrix_provider : MatrixProviderUnlinked
            The matrix provider.
        """
        super().__init__(dataset_group)
        self._data_provider = data_provider
        self._matrix_provider = matrix_provider
        # Per dataset label: a list with one entry per global index, or (for datasets
        # with a global model) whatever the residual function returns for the full matrix.
        self._clps: dict[str, list[ArrayLike] | ArrayLike] = {
            label: [] for label in self.group.dataset_models
        }
        self._residuals: dict[str, list[ArrayLike] | ArrayLike] = {
            label: [] for label in self.group.dataset_models
        }

    def estimate(self):
        """Calculate the estimation.

        The clp penalties are cleared first. Datasets with a global model are
        estimated on the full matrix, all others per index of their global axis.
        """
        self._clp_penalty.clear()

        for dataset_model in self.group.dataset_models.values():
            if has_dataset_model_global_model(dataset_model):
                self.calculate_full_model_estimation(dataset_model)
            else:
                self.calculate_estimation(dataset_model)

    def get_full_penalty(self) -> ArrayLike:
        """Get the full penalty.

        The residuals of all datasets are concatenated; the clp penalties are
        appended when there are any.

        Returns
        -------
        ArrayLike
            The clp penalty.
        """
        full_penalty = np.concatenate(
            [
                # Full-model residuals are stored as returned by the residual function,
                # per-index residuals are stored as a list and joined here.
                self._residuals[label]
                if has_dataset_model_global_model(dataset_model)
                else np.concatenate(self._residuals[label])
                for label, dataset_model in self.group.dataset_models.items()
            ]
        )
        if len(self._clp_penalty) != 0:
            full_penalty = np.concatenate([full_penalty, self._clp_penalty])
        return full_penalty

    def get_result(
        self,
    ) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]:
        """Get the results of the estimation.

        Returns
        -------
        tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]
            A tuple of the estimated clps and residuals, each keyed by dataset label.
        """
        # Each value is a single DataArray per dataset label (not a list of them).
        clps: dict[str, xr.DataArray] = {}
        residuals: dict[str, xr.DataArray] = {}
        for label, dataset_model in self.group.dataset_models.items():
            model_dimension = self._data_provider.get_model_dimension(label)
            model_axis = self._data_provider.get_model_axis(label)
            global_dimension = self._data_provider.get_global_dimension(label)
            global_axis = self._data_provider.get_global_axis(label)

            if has_dataset_model_global_model(dataset_model):
                residuals[label] = xr.DataArray(
                    np.array(self._residuals[label]).T.reshape(model_axis.size, global_axis.size),
                    coords={global_dimension: global_axis, model_dimension: model_axis},
                    dims=[model_dimension, global_dimension],
                )
                clp_labels = self._matrix_provider.get_matrix_container(label).clp_labels
                global_clp_labels = self._matrix_provider.get_global_matrix_container(
                    label
                ).clp_labels
                clps[label] = xr.DataArray(
                    np.array(self._clps[label]).reshape((len(global_clp_labels), len(clp_labels))),
                    coords={"global_clp_label": global_clp_labels, "clp_label": clp_labels},
                    dims=["global_clp_label", "clp_label"],
                )

            else:
                # The residuals are collected per global index, so the stacked array
                # is transposed to get the model dimension first.
                residuals[label] = xr.DataArray(
                    np.array(self._residuals[label]).T,
                    coords={global_dimension: global_axis, model_dimension: model_axis},
                    dims=[model_dimension, global_dimension],
                )
                clps[label] = xr.DataArray(
                    self._clps[label],
                    coords=(
                        (global_dimension, global_axis),
                        (
                            "clp_label",
                            self._matrix_provider.get_matrix_container(label).clp_labels,
                        ),
                    ),
                )
        return clps, residuals

    def calculate_full_model_estimation(self, dataset_model: DatasetModel):
        """Calculate the estimation for a dataset with a full model.

        Parameters
        ----------
        dataset_model : DatasetModel
            The dataset model.
        """
        label = dataset_model.label

        full_matrix = self._matrix_provider.get_full_matrix(label)
        data = self._data_provider.get_flattened_data(label)
        self._clps[label], self._residuals[label] = self.calculate_residual(full_matrix, data)

    def calculate_estimation(self, dataset_model: DatasetModel):
        """Calculate the estimation for a dataset.

        The clps and the residual are calculated for every index of the global
        axis; afterwards the clp penalties of the dataset are added to the
        additional penalties.

        Parameters
        ----------
        dataset_model : DatasetModel
            The dataset model.
        """
        label = dataset_model.label
        # Drop the results of the previous evaluation (the lists are reused in place).
        self._clps[label].clear()  # type:ignore[union-attr]
        self._residuals[label].clear()  # type:ignore[union-attr]

        global_axis = self._data_provider.get_global_axis(label)
        data = self._data_provider.get_data(label)
        clp_labels = []

        for index, global_index_value in enumerate(global_axis):
            matrix_container = self._matrix_provider.get_prepared_matrix_container(label, index)
            reduced_clps, residual = self.calculate_residual(
                matrix_container.matrix, data[:, index]
            )
            clp_labels.append(self._matrix_provider.get_matrix_container(label).clp_labels)
            # Note that the value on the global axis (not the position) is passed on.
            clp = self.retrieve_clps(
                clp_labels[index], matrix_container.clp_labels, reduced_clps, global_index_value
            )

            self._clps[label].append(clp)  # type:ignore[union-attr]
            self._residuals[label].append(residual)  # type:ignore[union-attr]

        self._clp_penalty += self.calculate_clp_penalties(
            clp_labels, self._clps[label], global_axis  # type:ignore[arg-type]
        )
class EstimationProviderLinked(EstimationProvider):
    """A class to provide estimation for optimization of a linked dataset group."""

    def __init__(
        self,
        dataset_group: DatasetGroup,
        data_provider: DataProviderLinked,
        matrix_provider: MatrixProviderLinked,
    ):
        """Initialize an estimation provider for a linked dataset group.

        Parameters
        ----------
        dataset_group : DatasetGroup
            The dataset group.
        data_provider : DataProviderLinked
            The data provider.
        matrix_provider : MatrixProviderLinked
            The matrix provider.
        """
        super().__init__(dataset_group)
        self._data_provider = data_provider
        self._matrix_provider = matrix_provider
        # One slot per index of the aligned global axis, filled by ``estimate``.
        self._clps: list[ArrayLike] = [
            None  # type:ignore[list-item]
        ] * self._data_provider.aligned_global_axis.size
        self._residuals: list[ArrayLike] = [
            None  # type:ignore[list-item]
        ] * self._data_provider.aligned_global_axis.size

    def estimate(self):
        """Calculate the estimation.

        The clps and the residual are calculated for every index of the aligned
        global axis; afterwards the clp penalties are recalculated.
        """
        for index, global_index_value in enumerate(self._data_provider.aligned_global_axis):
            matrix_container = self._matrix_provider.get_aligned_matrix_container(index)
            data = self._data_provider.get_aligned_data(index)
            reduced_clps, residual = self.calculate_residual(matrix_container.matrix, data)
            # Note that the value on the global axis (not the position) is passed on.
            self._clps[index] = self.retrieve_clps(
                self._matrix_provider.aligned_full_clp_labels[index],
                matrix_container.clp_labels,
                reduced_clps,
                global_index_value,
            )
            self._residuals[index] = residual

        self._clp_penalty = self.calculate_clp_penalties(
            self._matrix_provider.aligned_full_clp_labels,
            self._clps,
            self._data_provider.aligned_global_axis,
        )

    def get_full_penalty(self) -> ArrayLike:
        """Get the full penalty.

        Returns
        -------
        ArrayLike
            The concatenated residuals followed by the clp penalties.
        """
        return np.concatenate((np.concatenate(self._residuals), self._clp_penalty))

    def get_result(
        self,
    ) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]:
        """Get the results of the estimation.

        The aligned clps and residuals are split up again per dataset.

        Returns
        -------
        tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]
            A tuple of the estimated clps and residuals.
        """
        clps: dict[str, xr.DataArray] = {}
        residuals: dict[str, xr.DataArray] = {}
        for dataset_label in self.group.dataset_models:
            dataset_clps, dataset_residual = [], []
            for index in range(self._data_provider.aligned_global_axis.size):
                group_label = self._data_provider.get_aligned_group_label(index)
                if dataset_label not in group_label:
                    continue

                group_datasets = self._data_provider.group_definitions[group_label]
                # The check against the group label alone is not sufficient: if the
                # label is a string, ``in`` is a substring test, so a dataset label
                # that is a substring of another dataset's label would pass it and
                # ``group_datasets.index`` below would raise a ValueError.
                if dataset_label not in group_datasets:
                    continue
                dataset_index = group_datasets.index(dataset_label)

                clp_labels = self._matrix_provider.get_matrix_container(dataset_label).clp_labels

                # Pick the clps of this dataset out of the clps of the whole group.
                dataset_clps.append(
                    xr.DataArray(
                        [
                            self._clps[index][
                                self._matrix_provider.aligned_full_clp_labels[index].index(label)
                            ]
                            for label in clp_labels
                        ],
                        coords={"clp_label": clp_labels},
                    )
                )

                # The group residual is the concatenation of the residuals of its
                # datasets; the offset is the summed model axis size of the
                # datasets that precede this one in the group.
                start = sum(
                    self._data_provider.get_model_axis(label).size
                    for label in group_datasets[:dataset_index]
                )
                end = start + self._data_provider.get_model_axis(dataset_label).size
                dataset_residual.append(self._residuals[index][start:end])

            model_dimension = self._data_provider.get_model_dimension(dataset_label)
            model_axis = self._data_provider.get_model_axis(dataset_label)
            global_dimension = self._data_provider.get_global_dimension(dataset_label)
            global_axis = self._data_provider.get_global_axis(dataset_label)

            clps[dataset_label] = xr.concat(
                dataset_clps,
                dim=global_dimension,
            )
            clps[dataset_label].coords[global_dimension] = global_axis
            residuals[dataset_label] = xr.DataArray(
                np.array(dataset_residual).T,
                coords={global_dimension: global_axis, model_dimension: model_axis},
                dims=[model_dimension, global_dimension],
            )
        return clps, residuals
def _get_area(
    clp_label: str,
    clp_labels: list[str] | list[list[str]],
    clps: list[ArrayLike],
    intervals: list[tuple[float, float]],
    global_axis: ArrayLike,
) -> ArrayLike:
    """Collect the values of a clp on intervals of the global axis.

    Parameters
    ----------
    clp_label : str
        The label of the clp.
    clp_labels : list[str] | list[list[str]]
        The clp labels, either one list shared by all indices or one list per
        index of the global axis.
    clps : list[ArrayLike]
        The clps, one entry per index of the global axis.
    intervals : list[tuple[float, float]]
        The intervals on the global axis.
    global_axis : ArrayLike
        The global axis.

    Returns
    -------
    ArrayLike
        The concatenated slices; empty when the clp is not found on any interval.
    """
    collected = []

    for interval in intervals:
        # Intervals that start beyond the last point of the axis are skipped.
        # NOTE(review): there is no matching check for intervals that end before the
        # first point, and ``global_axis[-1]`` is only the maximum for an ascending
        # axis - confirm that both are intended.
        if interval[0] > global_axis[-1]:
            continue

        # Clip the interval to the range that the axis covers.
        lower = max(interval[0], np.min(global_axis))
        upper = min(interval[1], np.max(global_axis))
        axis_slice = DataProvider.get_axis_slice_from_interval((lower, upper), global_axis)

        for position in range(axis_slice.start, axis_slice.stop):
            # The labels are either given per index (list of lists) or shared.
            if isinstance(clp_labels[0], list):
                labels_at_position: list[str] = clp_labels[position]
            else:
                labels_at_position = clp_labels  # type:ignore[assignment]

            if clp_label in labels_at_position:
                collected.append(clps[position][labels_at_position.index(clp_label)])

    return np.asarray(collected)  # TODO: normalize for distance on global axis