# Source code for glotaran.builtin.megacomplexes.decay.util

from __future__ import annotations

from typing import TYPE_CHECKING

import numba as nb
import numpy as np
import xarray as xr

from glotaran.builtin.megacomplexes.decay.decay_matrix_gaussian_irf import (
    calculate_decay_matrix_gaussian_irf,
)
from glotaran.builtin.megacomplexes.decay.decay_matrix_gaussian_irf import (
    calculate_decay_matrix_gaussian_irf_on_index,
)
from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian
from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian
from glotaran.model import DatasetModel
from glotaran.model import Megacomplex
from glotaran.model import get_dataset_model_model_dimension

if TYPE_CHECKING:
    from glotaran.typing.types import ArrayLike


def index_dependent(dataset_model: DatasetModel) -> bool:
    """Determine if a dataset_model is index dependent.

    Parameters
    ----------
    dataset_model : DatasetModel
        A dataset model instance.

    Returns
    -------
    bool
        Returns True if the dataset_model has an IRF that is index dependent
        (e.g. has dispersion).
    """
    irf = dataset_model.irf
    # Only a (multi-)gaussian IRF can introduce index dependence.
    if not isinstance(irf, IrfMultiGaussian):
        return False
    return irf.is_index_dependent()
def calculate_matrix(
    megacomplex: Megacomplex,
    dataset_model: DatasetModel,
    global_axis: ArrayLike,
    model_axis: ArrayLike,
    **kwargs,
):
    """Calculate the decay concentration matrix for a megacomplex.

    Returns the compartment labels together with the (optionally index
    dependent) concentration matrix, with the A-matrix already applied.
    """
    compartments = megacomplex.get_compartments(dataset_model)
    initial_concentration = megacomplex.get_initial_concentration(dataset_model)
    k_matrix = megacomplex.get_k_matrix()
    rates = k_matrix.rates(compartments, initial_concentration)

    # Index dependent matrices carry an extra leading global-axis dimension.
    is_index_dependent = index_dependent(dataset_model)
    if is_index_dependent:
        matrix_shape = (global_axis.size, model_axis.size, rates.size)
    else:
        matrix_shape = (model_axis.size, rates.size)
    matrix = np.zeros(matrix_shape, dtype=np.float64)

    if is_index_dependent:
        decay_matrix_implementation_index_dependent(
            matrix, rates, global_axis, model_axis, dataset_model
        )
    else:
        decay_matrix_implementation_index_independent(
            matrix, rates, global_axis, model_axis, dataset_model
        )

    # Non-finite entries indicate a bad K-matrix parameterization.
    if not np.all(np.isfinite(matrix)):
        raise ValueError(
            f"Non-finite concentrations for K-Matrix '{k_matrix.label}':\n"
            f"{k_matrix.matrix_as_markdown(fill_parameters=True)}"
        )

    # apply A matrix
    matrix = matrix @ megacomplex.get_a_matrix(dataset_model)

    # done
    return compartments, matrix
def collect_megacomplexes(dataset_model: DatasetModel, as_global: bool) -> list[Megacomplex]:
    """Collect all decay-type megacomplexes attached to ``dataset_model``.

    Looks at the global megacomplexes when ``as_global`` is True, otherwise
    at the regular ones.
    """
    # Imported locally to avoid circular imports.
    from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex
    from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import (
        DecayParallelMegacomplex,
    )
    from glotaran.builtin.megacomplexes.decay.decay_sequential_megacomplex import (
        DecaySequentialMegacomplex,
    )

    decay_types = (DecayMegacomplex, DecayParallelMegacomplex, DecaySequentialMegacomplex)
    candidates = dataset_model.global_megacomplex if as_global else dataset_model.megacomplex
    return [megacomplex for megacomplex in candidates if isinstance(megacomplex, decay_types)]
def finalize_data(
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    is_full_model: bool = False,
    as_global: bool = False,
):
    """Attach decay-specific result variables to ``dataset`` in place."""
    species_dimension = "decay_species" if as_global else "species"
    if species_dimension in dataset.coords:
        # The first decay megacomplex called already finalized the data for
        # all decay megacomplexes.
        return

    decay_megacomplexes = collect_megacomplexes(dataset_model, as_global)
    global_dimension = dataset.attrs["global_dimension"]
    name = "images" if global_dimension == "pixel" else "spectra"

    # Gather species over all decay megacomplexes, deduplicated but keeping
    # first-seen order.
    seen: dict = {}
    for megacomplex in decay_megacomplexes:
        for species in megacomplex.get_compartments(dataset_model):
            seen.setdefault(species, None)
    all_species = list(seen)

    retrieve_species_associated_data(
        dataset_model,
        dataset,
        all_species,
        species_dimension,
        global_dimension,
        name,
        is_full_model,
        as_global,
    )
    retrieve_initial_concentration(
        dataset_model,
        dataset,
        species_dimension,
    )
    retrieve_irf(dataset_model, dataset, global_dimension)

    if not is_full_model:
        for megacomplex in decay_megacomplexes:
            retrieve_decay_associated_data(
                megacomplex,
                dataset_model,
                dataset,
                global_dimension,
                name,
            )
def decay_matrix_implementation_index_independent(
    matrix: np.ndarray,
    rates: np.ndarray,
    global_axis: np.ndarray,
    model_axis: np.ndarray,
    dataset_model: DatasetModel,
):
    """Fill ``matrix`` in place with index-independent decay traces."""
    irf = dataset_model.irf
    if not isinstance(irf, IrfMultiGaussian):
        # No gaussian IRF: plain exponential decays.
        calculate_decay_matrix_no_irf(matrix, rates, model_axis)
        return

    (
        centers,
        widths,
        irf_scales,
        shift,
        backsweep,
        backsweep_period,
    ) = irf.parameter(None, global_axis)
    calculate_decay_matrix_gaussian_irf_on_index(
        matrix,
        rates,
        model_axis,
        centers - shift,
        widths,
        irf_scales,
        backsweep,
        backsweep_period,
    )
    if irf.normalize:
        matrix /= np.sum(irf_scales)
def decay_matrix_implementation_index_dependent(
    matrix: np.ndarray,
    rates: np.ndarray,
    global_axis: np.ndarray,
    model_axis: np.ndarray,
    dataset_model: DatasetModel,
):
    """Fill ``matrix`` in place with index-dependent decay traces.

    IRF centers and widths are collected per global index; only the scales
    and backsweep values from the last index are forwarded (NOTE(review):
    this mirrors the loop below — presumably these are index independent).
    """
    shifted_centers: list = []
    all_widths: list = []
    irf_scales: list = []
    backsweep, backsweep_period = False, None

    for global_index in range(global_axis.size):
        (
            centers,
            widths,
            irf_scales,
            shift,
            backsweep,
            backsweep_period,
        ) = dataset_model.irf.parameter(global_index, global_axis)
        shifted_centers.append(centers - shift)
        all_widths.append(widths)

    calculate_decay_matrix_gaussian_irf(
        matrix,
        rates,
        model_axis,
        np.array(shifted_centers),
        np.array(all_widths),
        irf_scales,
        backsweep,
        backsweep_period,
    )
    if dataset_model.irf.normalize:
        matrix /= np.sum(irf_scales)
@nb.jit(nopython=True, parallel=True)
def calculate_decay_matrix_no_irf(matrix, rates, times):
    """Accumulate ``exp(-rate * time)`` decays into ``matrix`` in place.

    Parallelized over the rate axis; ``matrix`` is indexed as (time, rate).
    """
    # Loop over rates in the outer (parallel) loop so each thread owns a
    # distinct matrix column.
    for rate_index in nb.prange(rates.size):
        rate = rates[rate_index]
        for time_index in range(times.size):
            matrix[time_index, rate_index] += np.exp(-rate * times[time_index])
def retrieve_species_associated_data(
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    species: list[str],
    species_dimension: str,
    global_dimension: str,
    name: str,
    is_full_model: bool,
    as_global: bool,
):
    """Add species concentrations (and associated spectra/images) to ``dataset``."""
    model_dimension = get_dataset_model_model_dimension(dataset_model)
    if as_global:
        # In global mode the roles of the two dimensions are swapped.
        model_dimension, global_dimension = global_dimension, model_dimension
    dataset.coords[species_dimension] = species

    matrix = dataset.global_matrix if as_global else dataset.matrix
    clp_dim = "global_clp_label" if as_global else "clp_label"
    concentrations = matrix.sel({clp_dim: species}).values

    # A 3-dimensional matrix means index dependent concentrations.
    if len(dataset.matrix.shape) == 3:
        dims = (global_dimension, model_dimension, species_dimension)
    else:
        dims = (model_dimension, species_dimension)
    dataset["species_concentration"] = (dims, concentrations)

    if not is_full_model:
        dataset[f"species_associated_{name}"] = (
            (
                global_dimension,
                species_dimension,
            ),
            dataset.clp.sel(clp_label=species).data,
        )
def retrieve_initial_concentration(
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    species_dimension: str,
):
    """Add the dataset-wide initial concentration to ``dataset`` if present."""
    initial_concentration = getattr(dataset_model, "initial_concentration", None)
    if initial_concentration is None:
        # For parallel and sequential decay we don't have dataset wide initial concentration
        # unless mixed with general decays
        return
    dataset["initial_concentration"] = (
        (species_dimension,),
        initial_concentration.parameters,
    )
def retrieve_decay_associated_data(
    megacomplex: Megacomplex,
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    global_dimension: str,
    name: str,
):
    """Add per-megacomplex DAS, A-matrix and K-matrix variables to ``dataset``."""
    species = megacomplex.get_compartments(dataset_model)
    initial_concentration = megacomplex.get_initial_concentration(dataset_model)
    k_matrix = megacomplex.get_k_matrix()

    full_k = k_matrix.full(species)
    reduced_k = k_matrix.reduced(species)
    a_matrix = megacomplex.get_a_matrix(dataset_model)
    rates = k_matrix.rates(species, initial_concentration)
    lifetimes = 1 / rates

    # Decay associated spectra/images: SAS projected through the A-matrix.
    das_values = (
        dataset[f"species_associated_{name}"].sel(species=species).values @ a_matrix.T
    )

    component_name = f"component_{megacomplex.label}"
    component_coords = {
        component_name: np.arange(1, rates.size + 1),
        f"rate_{megacomplex.label}": (component_name, rates),
        f"lifetime_{megacomplex.label}": (component_name, lifetimes),
    }

    das_coords = component_coords.copy()
    das_coords[global_dimension] = dataset.coords[global_dimension]
    das_name = f"decay_associated_{name}_{megacomplex.label}"
    das = xr.DataArray(
        das_values, dims=(global_dimension, component_name), coords=das_coords
    )

    # The A-matrix is annotated with the un-normalized initial concentration.
    initial_concentration = megacomplex.get_initial_concentration(
        dataset_model, normalized=False
    )
    species_name = f"species_{megacomplex.label}"
    a_matrix_coords = component_coords.copy()
    a_matrix_coords[species_name] = species
    a_matrix_coords[f"initial_concentration_{megacomplex.label}"] = (
        species_name,
        initial_concentration,
    )
    a_matrix_name = f"a_matrix_{megacomplex.label}"
    a_matrix = xr.DataArray(
        a_matrix, coords=a_matrix_coords, dims=(component_name, species_name)
    )

    to_species_name = f"to_species_{megacomplex.label}"
    from_species_name = f"from_species_{megacomplex.label}"
    k_matrix_name = f"k_matrix_{megacomplex.label}"
    k_matrix = xr.DataArray(
        full_k, coords=[(to_species_name, species), (from_species_name, species)]
    )
    k_matrix_reduced_name = f"k_matrix_reduced_{megacomplex.label}"
    k_matrix_reduced = xr.DataArray(
        reduced_k, coords=[(to_species_name, species), (from_species_name, species)]
    )

    dataset[das_name] = das
    dataset[a_matrix_name] = a_matrix
    dataset[k_matrix_name] = k_matrix
    dataset[k_matrix_reduced_name] = k_matrix_reduced
def retrieve_irf(dataset_model: DatasetModel, dataset: xr.Dataset, global_dimension: str):
    """Add the simulated IRF and its parameters to ``dataset`` if applicable."""
    irf = dataset_model.irf
    if not isinstance(irf, IrfMultiGaussian):
        return
    if "irf" in dataset:
        # Already added by an earlier megacomplex.
        return

    model_dimension = get_dataset_model_model_dimension(dataset_model)
    dataset["irf"] = (
        (model_dimension),
        irf.calculate(
            index=0,
            global_axis=dataset.coords[global_dimension].values,
            model_axis=dataset.coords[model_dimension].values,
        ).data,
    )

    center = irf.center if isinstance(irf.center, list) else [irf.center]
    width = irf.width if isinstance(irf.width, list) else [irf.width]
    # Scalars for single-gaussian IRFs, an "irf_nr" dimension otherwise.
    dataset["irf_center"] = ("irf_nr", center) if len(center) > 1 else center[0]
    dataset["irf_width"] = ("irf_nr", width) if len(width) > 1 else width[0]

    if irf.shift is not None:
        dataset["irf_shift"] = (global_dimension, [center[0] - p.value for p in irf.shift])

    if isinstance(irf, IrfSpectralMultiGaussian) and irf.dispersion_center:
        dataset["irf_center_location"] = (
            ("irf_nr", global_dimension),
            irf.calculate_dispersion(dataset.coords["spectral"].values),
        )
        # center_dispersion_1 for backwards compatibility (0.3-0.4.1)
        dataset["center_dispersion_1"] = dataset["irf_center_location"].sel(irf_nr=0)