Source code for glotaran.model.dataset_model

"""The DatasetModel class."""
from __future__ import annotations

from collections import Counter
from typing import TYPE_CHECKING

import numpy as np
import xarray as xr

from glotaran.model.item import model_item
from glotaran.model.item import model_item_validator

if TYPE_CHECKING:
    from typing import Any
    from typing import Generator
    from typing import Hashable

    from glotaran.model.megacomplex import Megacomplex
    from glotaran.model.model import Model
    from glotaran.parameter import Parameter


[docs]def create_dataset_model_type(properties: dict[str, Any]) -> type[DatasetModel]: @model_item(properties=properties) class ModelDatasetModel(DatasetModel): pass return ModelDatasetModel
[docs]class DatasetModel: """A `DatasetModel` describes a dataset in terms of a glotaran model. It contains references to model items which describe the physical model for a given dataset. A general dataset descriptor assigns one or more megacomplexes and a scale parameter. """
[docs] def iterate_megacomplexes( self, ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: """Iterates of der dataset model's megacomplexes.""" for i, megacomplex in enumerate(self.megacomplex): scale = self.megacomplex_scale[i] if self.megacomplex_scale is not None else None yield scale, megacomplex
[docs] def iterate_global_megacomplexes( self, ) -> Generator[tuple[Parameter | int | None, Megacomplex | str], None, None]: """Iterates of der dataset model's global megacomplexes.""" for i, megacomplex in enumerate(self.global_megacomplex): scale = ( self.global_megacomplex_scale[i] if self.global_megacomplex_scale is not None else None ) yield scale, megacomplex
[docs] def get_model_dimension(self) -> str: """Returns the dataset model's model dimension.""" if not hasattr(self, "_model_dimension"): if len(self.megacomplex) == 0: raise ValueError(f"No megacomplex set for dataset model '{self.label}'") if isinstance(self.megacomplex[0], str): raise ValueError(f"Dataset model '{self.label}' was not filled") self._model_dimension = self.megacomplex[0].dimension if any(self._model_dimension != m.dimension for m in self.megacomplex): raise ValueError( f"Megacomplex dimensions do not match for dataset model '{self.label}'." ) return self._model_dimension
[docs] def finalize_data(self, dataset: xr.Dataset) -> None: is_full_model = self.has_global_model() for megacomplex in self.megacomplex: megacomplex.finalize_data(self, dataset, is_full_model=is_full_model) if is_full_model: for megacomplex in self.global_megacomplex: megacomplex.finalize_data( self, dataset, is_full_model=is_full_model, as_global=True )
[docs] def overwrite_model_dimension(self, model_dimension: str) -> None: """Overwrites the dataset model's model dimension.""" self._model_dimension = model_dimension
# TODO: make explicit we only support 2 dimensions at present # TODO: the global dimension should become a flexible index (MultiIndex) # the user can then specify the name of the MultiIndex global dimension # using the function overwrite_global_dimension # e.g. in FLIM, x, y dimension may get 'flattened' to a MultiIndex 'pixel'
[docs] def get_global_dimension(self) -> str: """Returns the dataset model's global dimension.""" if not hasattr(self, "_global_dimension"): if self.has_global_model(): if isinstance(self.global_megacomplex[0], str): raise ValueError(f"Dataset model '{self.label}' was not filled") self._global_dimension = self.global_megacomplex[0].dimension if any(self._global_dimension != m.dimension for m in self.global_megacomplex): raise ValueError( "Global megacomplex dimensions do not " f"match for dataset model '{self.label}'." ) elif hasattr(self, "_coords"): return next(dim for dim in self._coords if dim != self.get_model_dimension()) else: if not hasattr(self, "_data"): raise ValueError(f"Data not set for dataset model '{self.label}'") self._global_dimension = next( dim for dim in self._data.data.dims if dim != self.get_model_dimension() ) return self._global_dimension
[docs] def overwrite_global_dimension(self, global_dimension: str) -> None: """Overwrites the dataset model's global dimension.""" self._global_dimension = global_dimension
[docs] def swap_dimensions(self) -> None: """Swaps the dataset model's global and model dimension.""" global_dimension = self.get_model_dimension() model_dimension = self.get_global_dimension() self.overwrite_global_dimension(global_dimension) self.overwrite_model_dimension(model_dimension)
[docs] def set_data(self, dataset: xr.Dataset) -> DatasetModel: """Sets the dataset model's data.""" self._coords = {name: dim.values for name, dim in dataset.coords.items()} self._data: np.ndarray = dataset.data.values self._weight: np.ndarray | None = dataset.weight.values if "weight" in dataset else None if self._weight is not None: self._data = self._data * self._weight return self
[docs] def get_data(self) -> np.ndarray: """Gets the dataset model's data.""" return self._data
[docs] def get_weight(self) -> np.ndarray | None: """Gets the dataset model's weight.""" return self._weight
[docs] def is_index_dependent(self) -> bool: """Indicates if the dataset model is index dependent.""" if hasattr(self, "_index_dependent"): return self._index_dependent return any(m.index_dependent(self) for m in self.megacomplex)
[docs] def overwrite_index_dependent(self, index_dependent: bool): """Overrides the index dependency of the dataset""" self._index_dependent = index_dependent
[docs] def has_global_model(self) -> bool: """Indicates if the dataset model can model the global dimension.""" return self.global_megacomplex is not None and len(self.global_megacomplex) != 0
[docs] def set_coordinates(self, coords: dict[str, np.ndarray]): """Sets the dataset model's coordinates.""" self._coords = coords
[docs] def get_coordinates(self) -> dict[Hashable, np.ndarray]: """Gets the dataset model's coordinates.""" return self._coords
[docs] def get_model_axis(self) -> np.ndarray: """Gets the dataset model's model axis.""" return self._coords[self.get_model_dimension()]
[docs] def get_global_axis(self) -> np.ndarray: """Gets the dataset model's global axis.""" return self._coords[self.get_global_dimension()]
[docs] @model_item_validator(False) def ensure_unique_megacomplexes(self, model: Model) -> list[str]: """Ensure that unique megacomplexes Are only used once per dataset. Parameters ---------- model : Model Model object using this dataset model. Returns ------- list[str] Error messages to be shown when the model gets validated. """ glotaran_unique_megacomplex_types = [] for megacomplex_name in self.megacomplex: try: megacomplex_instance = model.megacomplex[megacomplex_name] if type(megacomplex_instance).glotaran_unique() is True: type_name = megacomplex_instance.type or megacomplex_instance.name glotaran_unique_megacomplex_types.append(type_name) except KeyError: pass return [ f"Multiple instances of unique megacomplex type {type_name!r} " f"in dataset {self.label!r}" for type_name, count in Counter(glotaran_unique_megacomplex_types).most_common() if count > 1 ]