Source code for glotaran.model.dataset_group

"""This module contains the dataset group."""

from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING
from typing import Literal

import xarray as xr
from attrs import define
from attrs import field

from glotaran.model.dataset_model import DatasetModel
from glotaran.model.dataset_model import get_dataset_model_model_dimension
from glotaran.model.dataset_model import has_dataset_model_global_model
from glotaran.model.item import ModelItem
from glotaran.model.item import fill_item
from glotaran.model.item import item

if TYPE_CHECKING:
    from glotaran.model.model import Model
    from glotaran.parameter import Parameters


[docs] @item class DatasetGroupModel(ModelItem): """A group of datasets which will evaluated independently.""" residual_function: Literal["variable_projection", "non_negative_least_squares"] = ( "variable_projection" ) """The residual function to use.""" link_clp: bool | None = None """Whether to link the clp parameter."""
[docs] @define class DatasetGroup: """A dataset group for optimization.""" residual_function: Literal["variable_projection", "non_negative_least_squares"] """The residual function to use.""" link_clp: bool | None """Whether to link the clp parameter.""" model: Model parameters: Parameters | None = None dataset_models: dict[str, DatasetModel] = field(factory=dict)
[docs] def set_parameters(self, parameters: Parameters): """Set the group parameters. Parameters ---------- parameters : Parameters The parameters. """ self.parameters = parameters for label in self.dataset_models: self.dataset_models[label] = fill_item( self.model.dataset[label], self.model, parameters )
[docs] def is_linkable(self, parameters: Parameters, data: Mapping[str, xr.Dataset]) -> bool: """Check if the group is linkable. Parameters ---------- parameters : Parameters A parameter set parameters. data : Mapping[str, xr.Dataset] A the data to link. Returns ------- bool """ if any(has_dataset_model_global_model(d) for d in self.dataset_models.values()): return False dataset_models = [ fill_item(self.model.dataset[label], self.model, parameters) for label in self.dataset_models ] model_dimensions = {get_dataset_model_model_dimension(d) for d in dataset_models} if len(model_dimensions) != 1: return False global_dimensions = set() for dataset in data.values(): global_dimensions |= { dim for dim in dataset.data.coords if dim not in model_dimensions } return len(global_dimensions) == 1