Source code for glotaran.optimization.optimization_group

"""Module containing the optimization group class."""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import xarray as xr

from glotaran.io.prepare_dataset import add_svd_to_dataset
from glotaran.model import DatasetGroup
from glotaran.model.dataset_model import finalize_dataset_model
from glotaran.optimization.data_provider import DataProvider
from glotaran.optimization.data_provider import DataProviderLinked
from glotaran.optimization.estimation_provider import EstimationProvider
from glotaran.optimization.estimation_provider import EstimationProviderLinked
from glotaran.optimization.estimation_provider import EstimationProviderUnlinked
from glotaran.optimization.matrix_provider import MatrixProvider
from glotaran.optimization.matrix_provider import MatrixProviderLinked
from glotaran.optimization.matrix_provider import MatrixProviderUnlinked
from glotaran.parameter import Parameters
from glotaran.project import Scheme

if TYPE_CHECKING:
    from glotaran.typing.types import ArrayLike


[docs] class OptimizationGroup: """A class to optimize a dataset group.""" def __init__( self, scheme: Scheme, dataset_group: DatasetGroup, ): """Initialize an optimization group for a dataset group. Parameters ---------- scheme : Scheme The optimization scheme. dataset_group : DatasetGroup The dataset group. """ self._dataset_group = dataset_group self._dataset_group.set_parameters(scheme.parameters) self._data = scheme.data self._add_svd = scheme.add_svd link_clp = dataset_group.link_clp if link_clp is None: link_clp = dataset_group.is_linkable(scheme.parameters, scheme.data) if link_clp: data_provider = DataProviderLinked(scheme, dataset_group) matrix_provider = MatrixProviderLinked(dataset_group, data_provider) estimation_provider = EstimationProviderLinked( dataset_group, data_provider, matrix_provider ) else: data_provider = DataProvider(scheme, dataset_group) # type:ignore[assignment] matrix_provider = MatrixProviderUnlinked( # type:ignore[assignment] self._dataset_group, data_provider ) estimation_provider = EstimationProviderUnlinked( # type:ignore[assignment] dataset_group, data_provider, matrix_provider # type:ignore[arg-type] ) self._data_provider: DataProvider = data_provider self._matrix_provider: MatrixProvider = matrix_provider self._estimation_provider: EstimationProvider = estimation_provider if self._add_svd: for dataset in self._data.values(): self.add_svd_data( "data", dataset, dataset.data.dims[0], dataset.data.dims[1], )
[docs] def calculate(self, parameters: Parameters): """Calculate the optimization group data. Parameters ---------- parameters : Parameters The parameters. """ self._dataset_group.set_parameters(parameters) self._matrix_provider.calculate() self._estimation_provider.estimate()
[docs] def get_additional_penalties(self) -> list[float]: """Get additional penalties. Returns ------- list[float] The additional penalties. """ return self._estimation_provider.get_additional_penalties()
[docs] def get_full_penalty(self) -> ArrayLike: """Get the full penalty. Returns ------- ArrayLike The full penalty. """ return self._estimation_provider.get_full_penalty()
[docs] def add_weight_to_result_data(self, dataset_label: str, result_dataset: xr.Dataset): """Add weight to result dataset if dataset is weighted. Parameters ---------- dataset_label : str The label of the data. result_dataset : xr.Dataset The label of the data. """ weight = self._data_provider.get_weight(dataset_label) if weight is None: return result_dataset["weighted_residual"] = result_dataset["residual"] result_dataset["residual"] = result_dataset["residual"] / weight if "weight" not in result_dataset: if weight.shape != result_dataset.data.shape: weight = weight.T result_dataset["weight"] = (result_dataset.data.dims, weight)
[docs] def create_result_data(self) -> dict[str, xr.Dataset]: """Create resulting datasets. Returns ------- dict[str, xr.Dataset] The datasets with the results. """ result_datasets = { label: data.copy() for label, data in self._data.items() if label in self._dataset_group.dataset_models.keys() } global_matrices, matrices = self._matrix_provider.get_result() clps, residuals = self._estimation_provider.get_result() for label, dataset_model in self._dataset_group.dataset_models.items(): result_dataset = result_datasets[label] model_dimension = self._data_provider.get_model_dimension(label) result_dataset.attrs["model_dimension"] = model_dimension global_dimension = self._data_provider.get_global_dimension(label) result_dataset.attrs["global_dimension"] = global_dimension result_dataset["residual"] = residuals[label] self.add_weight_to_result_data(label, result_dataset) result_dataset["matrix"] = matrices[label] if label in global_matrices: result_dataset["global_matrix"] = global_matrices[label] result_dataset["clp"] = clps[label] if self._add_svd: self.add_svd_data("residual", result_dataset, model_dimension, global_dimension) if "weighted_residual" in result_dataset: self.add_svd_data( "weighted_residual", result_dataset, model_dimension, global_dimension ) # Calculate RMS size = result_dataset.residual.shape[0] * result_dataset.residual.shape[1] result_dataset.attrs["root_mean_square_error"] = np.sqrt( (result_dataset.residual**2).sum() / size ).data result_dataset.attrs["weighted_root_mean_square_error"] = ( np.sqrt((result_dataset.weighted_residual**2).sum() / size).data if "weighted_residual" in result_dataset else result_dataset.attrs["root_mean_square_error"] ) result_dataset.attrs["dataset_scale"] = ( 1 if dataset_model.scale is None else dataset_model.scale.value # type:ignore[union-attr] ) # reconstruct fitted data result_dataset["fitted_data"] = result_dataset.data - result_dataset.residual finalize_dataset_model(dataset_model, result_dataset) return result_datasets
[docs] @staticmethod def add_svd_data(name: str, dataset: xr.Dataset, lsv_dim: str, rsv_dim: str): """Add the SVD of a data matrix to a dataset. Parameters ---------- name : str Name of the data matrix. dataset : xr.Dataset Dataset containing the data, which will be updated with the SVD values. lsv_dim : str The dimension name of the left singular vectors. rsv_dim : str The dimension name of the right singular vectors. """ add_svd_to_dataset( dataset, name=name, lsv_dim=lsv_dim, rsv_dim=rsv_dim, data_array=dataset[name] )
@property def number_of_clps(self) -> int: """Return number of conditionally linear parameters. Returns ------- int """ return self._matrix_provider.number_of_clps