# Source code for glotaran.builtin.models.kinetic_spectrum.spectral_penalties

"""This package contains compartment constraint items."""

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import List
from typing import Tuple

import numpy as np
import xarray as xr

from glotaran.model import model_attribute
from glotaran.parameter import Parameter

if TYPE_CHECKING:
    from typing import Any
    from typing import Sequence

    from glotaran.parameter import ParameterGroup

    from .kinetic_spectrum_model import KineticSpectrumModel


@model_attribute(
    properties={
        "source": str,
        "source_intervals": List[Tuple[float, float]],
        "target": str,
        "target_intervals": List[Tuple[float, float]],
        "parameter": Parameter,
        # NOTE(review): declared as ``str`` although ``apply_spectral_penalties``
        # multiplies it with a numeric penalty -- confirm the intended type.
        "weight": str,
    },
    no_label=True,
)
class EqualAreaPenalty:
    """Penalty which couples the areas of two compartments.

    An equal area penalty adds the difference between the summed amplitudes
    of a source compartment in one or more intervals and the scaled sum of
    the amplitudes of a target compartment in one or more intervals to the
    residual. The additional residual is scaled with the weight.
    """

    def applies(self, index: Any) -> bool:
        """Return whether ``index`` lies inside one of the intervals.

        NOTE(review): this reads ``self.interval``, which is not among the
        properties declared in the ``model_attribute`` decorator of this
        class (only ``source_intervals`` and ``target_intervals`` are).
        Unless the attribute is provided elsewhere, calling this method
        raises ``AttributeError`` -- confirm which intervals are meant.

        Parameters
        ----------
        index :
            The value to test against the interval bounds (both inclusive).

        Returns
        -------
        applies : bool
        """
        span = self.interval
        if isinstance(span, tuple):
            # A single (lower, upper) pair.
            return span[0] <= index <= span[1]
        # Otherwise a collection of (lower, upper) pairs.
        return any([bounds[0] <= index <= bounds[1] for bounds in span])


def has_spectral_penalties(model: KineticSpectrumModel) -> bool: return len(model.equal_area_penalties) != 0 def apply_spectral_penalties( model: KineticSpectrumModel, parameters: ParameterGroup, clp_labels: dict[str, list[str] | list[list[str]]], clps: dict[str, list[np.ndarray]], matrices: dict[str, np.ndarray | list[np.ndarray]], data: dict[str, xr.Dataset], group_tolerance: float, ) -> np.ndarray: penalties = [] for penalty in model.equal_area_penalties: penalty = penalty.fill(model, parameters) source_area = _get_area( model.index_dependent(), model.global_dimension, clp_labels, clps, data, group_tolerance, penalty.source_intervals, penalty.source, ) target_area = _get_area( model.index_dependent(), model.global_dimension, clp_labels, clps, data, group_tolerance, penalty.target_intervals, penalty.target, ) area_penalty = np.abs(np.sum(source_area) - penalty.parameter * np.sum(target_area)) penalties.append(area_penalty * penalty.weight) return np.asarray(penalties) def _get_area( index_dependent: bool, global_dimension: str, clp_labels: dict[str, list[list[str]]], clps: dict[str, list[np.ndarray]], data: dict[str, xr.Dataset], group_tolerance: float, intervals: list[tuple[float, float]], compartment: str, ) -> np.ndarray: area = [] area_indices = [] for label, dataset in data.items(): global_axis = dataset.coords[global_dimension] for interval in intervals: if interval[0] > global_axis[-1]: # interval not in this dataset continue start_idx, end_idx = _get_idx_from_interval(interval, global_axis) for i in range(start_idx, end_idx + 1): index_clp_labels = clp_labels[label][i] if index_dependent else clp_labels[label] if compartment in index_clp_labels: area.append(clps[label][i][index_clp_labels.index(compartment)]) area_indices.append(global_axis[i]) return np.asarray(area) # TODO: normalize for distance on global axis def _get_idx_from_interval( interval: tuple[float, float], axis: Sequence[float] | np.ndarray ) -> tuple[int, int]: """Retrieves start and end 
index of an interval on some axis Parameters ---------- interval : A tuple of floats with begin and end of the interval axis : Array like object which can be cast to np.array Returns ------- start, end : tuple of int """ axis_array = np.array(axis) start = np.abs(axis_array - interval[0]).argmin() if not np.isinf(interval[0]) else 0 end = ( np.abs(axis_array - interval[1]).argmin() if not np.isinf(interval[1]) else axis_array.size - 1 ) return start, end