Source code for glotaran.optimization.data_provider

"""Module containing the data provider classes."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr

from glotaran.model import DatasetGroup
from glotaran.model import Model
from glotaran.model.dataset_model import get_dataset_model_model_dimension
from glotaran.model.dataset_model import has_dataset_model_global_model
from glotaran.project import Scheme

if TYPE_CHECKING:
    from typing import Literal

    from glotaran.typing.types import ArrayLike


class AlignDatasetError(ValueError):
    """Error signalling that datasets could not be aligned unambiguously."""

    def __init__(self):
        """Initialize an AlignDatasetError with its fixed explanatory message."""
        message = (
            "Cannot link datasets, aligning is ambiguous. \n\n"
            "Try to lower link tolerance or change the alignment method."
        )
        super().__init__(message)
class DataProvider:
    """A class to provide prepared data for optimization.

    For every dataset of a dataset group the provider stores the (weighted)
    data, the weight, the model/global axes and the names of the model/global
    dimensions, each keyed by the dataset label.
    """

    def __init__(self, scheme: Scheme, dataset_group: DatasetGroup):
        """Initialize a data provider for a scheme and a dataset_group.

        Parameters
        ----------
        scheme : Scheme
            The optimization scheme.
        dataset_group : DatasetGroup
            The dataset group.
        """
        self._data: dict[str, ArrayLike] = {}
        self._weight: dict[str, ArrayLike | None] = {}
        self._flattened_data: dict[str, ArrayLike] = {}
        self._flattened_weight: dict[str, ArrayLike | None] = {}
        self._model_axes: dict[str, ArrayLike] = {}
        self._model_dimensions: dict[str, str] = {}
        self._global_axes: dict[str, ArrayLike] = {}
        self._global_dimensions: dict[str, str] = {}

        for label, dataset_model in dataset_group.dataset_models.items():
            dataset = scheme.data[label]

            model_dimension = get_dataset_model_model_dimension(dataset_model)
            self._model_axes[label] = dataset.coords[model_dimension].data
            self._model_dimensions[label] = model_dimension

            global_dimension = self.infer_global_dimension(model_dimension, dataset.data.dims)
            self._global_axes[label] = dataset.coords[global_dimension].data
            self._global_dimensions[label] = global_dimension

            # The dataset weight is read first so that ``add_model_weight`` can
            # detect it and skip the model weight in that case.
            self._weight[label] = self.get_from_dataset(
                dataset, "weight", model_dimension, global_dimension
            )
            self.add_model_weight(scheme.model, label, model_dimension, global_dimension)

            self._data[label] = self.get_from_dataset(  # type:ignore[assignment]
                dataset, "data", model_dimension, global_dimension
            )
            # The stored data is pre-multiplied by the weight (on the copy
            # returned by ``get_from_dataset``, not on the dataset itself).
            if self._weight[label] is not None:
                self._data[label] *= self._weight[label]

            # Flattened variants are only stored for datasets with a global model.
            if has_dataset_model_global_model(dataset_model):
                self._flattened_data[label] = self._data[label].T.flatten()
                self._flattened_weight[label] = (
                    self._weight[label].T.flatten()  # type:ignore[union-attr]
                    if self._weight[label] is not None
                    else None
                )

    @staticmethod
    def infer_global_dimension(model_dimension: str, dimensions: tuple[str]) -> str:
        """Infer the name of the global dimension from tuple of dimensions.

        The global dimension is the first entry of ``dimensions`` that differs
        from ``model_dimension``.

        Parameters
        ----------
        model_dimension : str
            The model dimension.
        dimensions : tuple[str]
            The dimensions tuple to infer from.

        Returns
        -------
        str
            The inferred name of the global dimension.
        """
        return next(dim for dim in dimensions if dim != model_dimension)

    @staticmethod
    def get_from_dataset(
        dataset: xr.Dataset, name: str, model_dimension: str, global_dimension: str
    ) -> ArrayLike | None:
        """Get a copy of data from a dataset with dimensions (model, global).

        Parameters
        ----------
        dataset : xr.Dataset
            The dataset to retrieve from.
        name : str
            The name of the data to retrieve.
        model_dimension : str
            The model dimension.
        global_dimension : str
            The global dimension.

        Returns
        -------
        ArrayLike | None
            The copy of the data. None if name is not present in dataset.
        """
        data = None
        if name in dataset:
            data = dataset[name].data.copy()
            # Transpose when the stored dimension order is not (model, global).
            if dataset[name].dims != (model_dimension, global_dimension):
                data = data.T
        return data

    @staticmethod
    def get_axis_slice_from_interval(interval: tuple[float, float], axis: ArrayLike) -> slice:
        """Get a slice of indices from a min max tuple and for an axis.

        Each finite bound is mapped to the index of the nearest axis point and
        the upper bound is inclusive. An infinite lower bound selects from the
        start of the axis, an infinite upper bound selects through the last
        axis point. A reversed interval is swapped before it is evaluated.

        Parameters
        ----------
        interval : tuple[float, float]
            The min max tuple.
        axis : ArrayLike
            The axis to slice.

        Returns
        -------
        slice
            The slice of indices.
        """
        interval_min, interval_max = interval[0], interval[1]
        if interval_min > interval_max:
            interval_min, interval_max = interval_max, interval_min

        minimum = 0 if np.isinf(interval_min) else np.abs(axis - interval_min).argmin()
        # The slice stop is exclusive: ``axis.size`` (not ``axis.size - 1``) is
        # needed so that an open upper bound keeps the last axis point, just
        # like a finite bound at the axis end does via ``argmin() + 1``.
        maximum = (
            axis.size if np.isinf(interval_max) else np.abs(axis - interval_max).argmin() + 1
        )
        return slice(minimum, maximum)

    def add_model_weight(
        self,
        model: Model,
        dataset_label: str,
        model_dimension: str,
        global_dimension: str,
    ):
        """Add model weight to data.

        Does nothing if the model defines no weight for the dataset. If the
        dataset already supplies a weight, the model weight is ignored and a
        warning is emitted. Otherwise a weight array of ones is created and
        multiplied by the value of every applicable model weight inside its
        model/global interval.

        Parameters
        ----------
        model : Model
            The model.
        dataset_label : str
            The label of the data.
        model_dimension : str
            The model dimension.
        global_dimension : str
            The global dimension.
        """
        model_weights = [weight for weight in model.weights if dataset_label in weight.datasets]
        if not model_weights:
            return

        # Compare against None explicitly: the weight is an array and its truth
        # value is ambiguous (raises ValueError) for more than one element.
        if self._weight[dataset_label] is not None:
            warnings.warn(
                f"Ignoring model weight for dataset '{dataset_label}'"
                " because weight is already supplied by dataset."
            )
            return

        model_axis = self._model_axes[dataset_label]
        global_axis = self._global_axes[dataset_label]
        weight = xr.DataArray(
            np.ones((model_axis.size, global_axis.size)),
            coords=(
                (model_dimension, model_axis),
                (global_dimension, global_axis),
            ),
        )
        for model_weight in model_weights:
            # An empty indexer selects the whole array, i.e. a weight without
            # intervals applies everywhere.
            idx = {}
            if model_weight.global_interval is not None:
                idx[global_dimension] = self.get_axis_slice_from_interval(
                    model_weight.global_interval, global_axis
                )
            if model_weight.model_interval is not None:
                idx[model_dimension] = self.get_axis_slice_from_interval(
                    model_weight.model_interval, model_axis
                )
            weight[idx] *= model_weight.value

        self._weight[dataset_label] = weight.data

    def get_data(self, dataset_label: str) -> ArrayLike:
        """Get data for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        ArrayLike
            The data.
        """
        return self._data[dataset_label]

    def get_weight(self, dataset_label: str) -> ArrayLike | None:
        """Get weight for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        ArrayLike | None
            The weight.
        """
        return self._weight[dataset_label]

    def get_flattened_data(self, dataset_label: str) -> ArrayLike:
        """Get flattened data for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        ArrayLike
            The flattened data.
        """
        return self._flattened_data[dataset_label]

    def get_flattened_weight(self, dataset_label: str) -> ArrayLike | None:
        """Get flattened weight for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        ArrayLike | None
            The flattened weight.
        """
        return self._flattened_weight[dataset_label]

    def get_model_axis(self, dataset_label: str) -> ArrayLike:
        """Get the model axis for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        ArrayLike
            The model axis.
        """
        return self._model_axes[dataset_label]

    def get_model_dimension(self, dataset_label: str) -> str:
        """Get the model dimension for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        str
            The model dimension.
        """
        return self._model_dimensions[dataset_label]

    def get_global_axis(self, dataset_label: str) -> ArrayLike:
        """Get the global axis for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        ArrayLike
            The global axis.
        """
        return self._global_axes[dataset_label]

    def get_global_dimension(self, dataset_label: str) -> str:
        """Get the global dimension for a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the data.

        Returns
        -------
        str
            The global dimension.
        """
        return self._global_dimensions[dataset_label]
class DataProviderLinked(DataProvider):
    """A class to provide aligned data for optimization.

    On top of the per-dataset data of :class:`DataProvider`, the datasets'
    global axes are aligned onto one common global axis. For every point on
    that axis the provider stores the concatenated data, the contributing
    datasets (group), their indices on their own global axis and the
    concatenated weights.
    """

    def __init__(
        self,
        scheme: Scheme,
        dataset_group: DatasetGroup,
    ):
        """Initialize a linked data provider for a scheme and a dataset_group.

        Parameters
        ----------
        scheme : Scheme
            The optimization scheme.
        dataset_group : DatasetGroup
            The dataset group.
        """
        super().__init__(scheme, dataset_group)
        aligned_global_axes = self.create_aligned_global_axes(scheme)
        # Order matters: ``align_dataset_indices`` and ``align_weights`` read
        # ``self._aligned_global_axis``; ``align_weights`` additionally reads
        # the group labels and definitions.
        self._aligned_global_axis, self._aligned_data = self.align_data(aligned_global_axes)
        self._aligned_dataset_indices = self.align_dataset_indices(aligned_global_axes)
        self._aligned_group_labels, self._group_definitions = self.align_groups(
            aligned_global_axes
        )
        self._aligned_weights = self.align_weights(aligned_global_axes)

    @staticmethod
    def align_index(
        index: int,
        target_axis: ArrayLike,
        tolerance: float,
        method: Literal["nearest", "backward", "forward"],
    ) -> int:
        """Align an index on a target axis.

        Despite its name, ``index`` is a coordinate value: it is compared with
        the values of ``target_axis`` and, if a candidate lies within
        ``tolerance``, replaced by that axis value. ``"forward"`` only
        considers axis values greater than or equal to ``index``,
        ``"backward"`` only values less than or equal to it; any other method
        considers all values. If no candidate is within tolerance, ``index``
        is returned unchanged.

        Parameters
        ----------
        index : int
            The index to align.
        target_axis : ArrayLike
            The axis to align the index on.
        tolerance : float
            The alignment tolerance.
        method : Literal["nearest", "backward", "forward"]
            The alignment method.

        Returns
        -------
        int
            The aligned index.
        """
        target_axis = np.asarray(target_axis)
        diff = target_axis - index

        if method == "forward":
            candidates = diff >= 0
        elif method == "backward":
            candidates = diff <= 0
        else:
            candidates = np.ones(diff.shape, dtype=bool)

        if not candidates.any():
            return index

        # Filter the axis values together with the differences so that the
        # position of the smallest distance refers to the same element in
        # both; indexing the unfiltered axis with the filtered position picks
        # the wrong value for the "forward" method.
        candidate_values = target_axis[candidates]
        distances = np.abs(diff[candidates])
        nearest = distances.argmin()
        if distances[nearest] <= tolerance:
            index = candidate_values[nearest]
        return index

    @property
    def aligned_global_axis(self) -> ArrayLike:
        """Get the aligned global axis for the dataset group.

        Returns
        -------
        ArrayLike
            The aligned global axis.
        """
        return self._aligned_global_axis

    @property
    def group_definitions(self) -> dict[str, list[str]]:
        """Get the group definitions for the dataset group.

        Returns
        -------
        dict[str, list[str]]
            The group definitions.
        """
        return self._group_definitions

    def get_aligned_group_label(self, index: int) -> str:
        """Get the group label for an index.

        Parameters
        ----------
        index : int
            The index on the aligned global axis.

        Returns
        -------
        str
            The aligned group label.
        """
        return self._aligned_group_labels[index]

    def get_aligned_dataset_indices(self, index: int) -> ArrayLike:
        """Get the aligned dataset indices for an index.

        Parameters
        ----------
        index : int
            The index on the aligned global axis.

        Returns
        -------
        ArrayLike
            The aligned dataset indices.
        """
        return self._aligned_dataset_indices[index]

    def get_aligned_data(self, index: int) -> ArrayLike:
        """Get the aligned data for an index.

        Parameters
        ----------
        index : int
            The index on the aligned global axis.

        Returns
        -------
        ArrayLike
            The aligned data.
        """
        return self._aligned_data[index]

    def get_aligned_weight(self, index: int) -> ArrayLike | None:
        """Get the aligned weight for an index.

        Parameters
        ----------
        index : int
            The index on the aligned global axis.

        Returns
        -------
        ArrayLike | None
            The aligned weight.
        """
        return self._aligned_weights[index]

    def create_aligned_global_axes(self, scheme: Scheme) -> dict[str, ArrayLike]:
        """Create aligned global axes for the dataset group.

        The global axis of the first dataset is taken as is. Every value of
        each following axis is aligned (see ``align_index``) onto the sorted
        union of all axis values collected so far, using the scheme's
        ``clp_link_tolerance`` and ``clp_link_method``.

        Parameters
        ----------
        scheme : Scheme
            The optimization scheme.

        Returns
        -------
        dict[str, ArrayLike]
            The aligned global axes.

        Raises
        ------
        AlignDatasetError
            Raised when dataset alignment is ambiguous.
        """
        aligned_axis_values = None
        aligned_global_axes = {}
        for label, global_axis in self._global_axes.items():
            aligned_global_axis = global_axis
            if aligned_axis_values is None:
                aligned_axis_values = aligned_global_axis
            else:
                # Stored as an array (not a list) so every dataset's aligned
                # axis has the same type as the first one.
                aligned_global_axis = np.asarray(
                    [
                        self.align_index(
                            index,
                            aligned_axis_values,
                            scheme.clp_link_tolerance,
                            scheme.clp_link_method,
                        )
                        for index in aligned_global_axis
                    ]
                )
                # Two points of one dataset collapsing onto the same aligned
                # value means the alignment is ambiguous.
                if len(np.unique(aligned_global_axis)) != len(aligned_global_axis):
                    raise AlignDatasetError()
                aligned_axis_values = np.unique(
                    np.concatenate([aligned_axis_values, aligned_global_axis])
                )
            aligned_global_axes[label] = aligned_global_axis
        return aligned_global_axes

    def align_data(
        self, aligned_global_axes: dict[str, ArrayLike]
    ) -> tuple[ArrayLike, list[ArrayLike]]:
        """Align the data in a dataset group.

        The data of all datasets are concatenated along the model dimension on
        a common global axis. For each point of that axis the entries that are
        NaN after the concatenation are dropped.

        Parameters
        ----------
        aligned_global_axes : dict[str, ArrayLike]
            The aligned global axes.

        Returns
        -------
        tuple[ArrayLike, list[ArrayLike]]
            The aligned global axis and data.
        """
        aligned_data = xr.concat(
            [
                xr.DataArray(
                    self.get_data(label),
                    dims=["model", "global"],
                    coords={"global": axis},
                )
                for label, axis in aligned_global_axes.items()
            ],
            dim="model",
        )
        aligned_global_axis = aligned_data.coords["global"].data
        return (
            aligned_global_axis,
            [
                aligned_data.isel({"global": i}).dropna(dim="model").data
                for i in range(aligned_global_axis.size)
            ],
        )

    def align_dataset_indices(self, aligned_global_axes: dict[str, ArrayLike]) -> list[ArrayLike]:
        """Align the global indices in a dataset group.

        For each point of the aligned global axis, the positions of that point
        on the global axes of the contributing datasets are collected.

        Parameters
        ----------
        aligned_global_axes : dict[str, ArrayLike]
            The aligned global axes.

        Returns
        -------
        list[ArrayLike]
            The aligned dataset indices.
        """
        aligned_indices = xr.concat(
            [
                xr.DataArray(
                    np.arange(len(axis), dtype=int),
                    dims=["global"],
                    coords={"global": axis},
                )
                for axis in aligned_global_axes.values()
            ],
            dim="dataset",
        )
        # Entries that are NaN after the concatenation are dropped; the
        # remaining values are cast back to int afterwards.
        return [
            aligned_indices.isel({"global": i}).dropna(dim="dataset").data.astype(int)
            for i in range(self._aligned_global_axis.size)
        ]

    @staticmethod
    def align_groups(
        aligned_global_axes: dict[str, ArrayLike]
    ) -> tuple[ArrayLike, dict[str, list[str]]]:
        """Align the groups in a dataset group.

        A group label is the concatenation of the labels of all datasets that
        contribute to a point of the aligned global axis. The group definition
        maps each group label to the list of those dataset labels.

        Parameters
        ----------
        aligned_global_axes : dict[str, ArrayLike]
            The aligned global axes.

        Returns
        -------
        tuple[ArrayLike, dict[str, list[str]]]
            The aligned grouplabels and group definitions.
        """
        aligned_groups = xr.concat(
            [
                xr.DataArray(np.full(len(axis), label), dims=["global"], coords={"global": axis})
                for label, axis in aligned_global_axes.items()
            ],
            dim="dataset",
            fill_value="",
        )
        # for every element along the global axis, concatenate all dataset labels
        # into an ndarray of shape (len(global),)
        # as an alternative to the more elegant xarray built-in which is limited to 32 datasets
        # aligned_group_labels = aligned_groups.str.join(dim="dataset").data
        aligned_group_labels = np.asarray(
            ["".join(sub_arr.values) for _, sub_arr in aligned_groups.groupby("global")]
        )
        group_definitions: dict[str, list[str]] = {}
        for i, group_label in enumerate(aligned_group_labels):
            if group_label not in group_definitions:
                # Empty strings are the fill value of non-contributing datasets.
                group_definitions[group_label] = list(
                    filter(lambda label: label != "", aligned_groups.isel({"global": i}).data)
                )
        return aligned_group_labels, group_definitions

    def align_weights(self, aligned_global_axes: dict[str, ArrayLike]) -> list[ArrayLike | None]:
        """Align the weights in a dataset group.

        For each point of the aligned global axis the weights of the
        contributing datasets are concatenated. Datasets without a weight
        contribute ones if at least one dataset of the group has a weight;
        otherwise the entry for that point stays None.

        Parameters
        ----------
        aligned_global_axes : dict[str, ArrayLike]
            The aligned global axes.

        Returns
        -------
        list[ArrayLike | None]
            The aligned weights.
        """
        all_weights = {
            label: xr.DataArray(
                weight,
                dims=["model", "global"],
                coords={"global": aligned_global_axes[label]},
            )
            for label, weight in self._weight.items()
            if weight is not None
        }
        aligned_weights: list[ArrayLike | None] = [None] * self._aligned_global_axis.size
        if all_weights:
            for i, group_label in enumerate(self._aligned_group_labels):
                group_dataset_labels = self._group_definitions[group_label]
                if any(label in all_weights for label in group_dataset_labels):
                    index_weights = []
                    for label in group_dataset_labels:
                        if label in all_weights:
                            index_weights.append(
                                all_weights[label]
                                .sel({"global": self._aligned_global_axis[i]})
                                .data
                            )
                        else:
                            # Unweighted dataset in a weighted group: neutral weight.
                            size = self.get_model_axis(label).size
                            index_weights.append(np.ones(size))
                    aligned_weights[i] = np.concatenate(index_weights)
        return aligned_weights