Source code for glotaran.builtin.megacomplexes.decay.util

from __future__ import annotations

import numba as nb
import numpy as np
import xarray as xr

from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian
from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian
from glotaran.model import DatasetModel
from glotaran.model import Megacomplex


def index_dependent(dataset_model: DatasetModel) -> bool:
    """Tell whether the matrix of ``dataset_model`` varies along the global index.

    Parameters
    ----------
    dataset_model : DatasetModel
        The dataset model to inspect.

    Returns
    -------
    bool
        Whatever ``irf.is_index_dependent()`` reports when the dataset model's IRF is an
        ``IrfMultiGaussian`` (e.g. one with dispersion); ``False`` for any other IRF,
        including no IRF at all.
    """
    irf = dataset_model.irf
    if not isinstance(irf, IrfMultiGaussian):
        # Only multi-Gaussian IRFs can introduce a dependence on the global index.
        return False
    return irf.is_index_dependent()
def calculate_matrix(
    megacomplex: Megacomplex,
    dataset_model: DatasetModel,
    indices: dict[str, int],
    **kwargs,
):
    """Compute the (A-matrix weighted) decay matrix of a decay megacomplex.

    Parameters
    ----------
    megacomplex : Megacomplex
        Supplies the compartments, the initial concentration, the K-matrix and the A-matrix.
    dataset_model : DatasetModel
        Supplies the global/model axes and the IRF used by the decay kernels.
    indices : dict[str, int]
        Mapping of dimension name to index. The entry for the global dimension is forwarded
        as ``global_index``; ``None`` is forwarded when it is missing.
    **kwargs
        Accepted for interface compatibility; not used.

    Returns
    -------
    tuple
        ``(compartments, matrix)``. ``matrix`` has one row per model-axis point and is the
        decay matrix right-multiplied by the megacomplex's A-matrix.

    Raises
    ------
    ValueError
        If the decay matrix contains NaN or infinite values.
    """
    compartments = megacomplex.get_compartments(dataset_model)
    initial_concentration = megacomplex.get_initial_concentration(dataset_model)
    k_matrix = megacomplex.get_k_matrix()
    rates = k_matrix.rates(compartments, initial_concentration)

    global_index = indices.get(dataset_model.get_global_dimension())
    global_axis = dataset_model.get_global_axis()
    model_axis = dataset_model.get_model_axis()

    # One column per rate; the kernels accumulate into this buffer in place.
    decay_matrix = np.zeros((model_axis.size, rates.size), dtype=np.float64)
    decay_matrix_implementation(
        decay_matrix, rates, global_index, global_axis, model_axis, dataset_model
    )

    if not np.isfinite(decay_matrix).all():
        raise ValueError(
            f"Non-finite concentrations for K-Matrix '{k_matrix.label}':\n"
            f"{k_matrix.matrix_as_markdown(fill_parameters=True)}"
        )

    # Map the per-rate decays onto the megacomplex's compartments.
    return compartments, decay_matrix @ megacomplex.get_a_matrix(dataset_model)
def collect_megacomplexes(dataset_model: DatasetModel, as_global: bool) -> list[Megacomplex]:
    """Return the decay-type megacomplexes of ``dataset_model``, in their original order.

    Parameters
    ----------
    dataset_model : DatasetModel
        The dataset model whose megacomplexes are scanned.
    as_global : bool
        Scan ``dataset_model.global_megacomplex`` when True, otherwise
        ``dataset_model.megacomplex``.

    Returns
    -------
    list[Megacomplex]
        Only the entries that are ``DecayMegacomplex``, ``DecayParallelMegacomplex`` or
        ``DecaySequentialMegacomplex`` instances.
    """
    # Imported locally, presumably to avoid a circular import with the megacomplex modules.
    from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex
    from glotaran.builtin.megacomplexes.decay.decay_parallel_megacomplex import (
        DecayParallelMegacomplex,
    )
    from glotaran.builtin.megacomplexes.decay.decay_sequential_megacomplex import (
        DecaySequentialMegacomplex,
    )

    decay_types = (DecayMegacomplex, DecayParallelMegacomplex, DecaySequentialMegacomplex)
    candidates = dataset_model.global_megacomplex if as_global else dataset_model.megacomplex
    return [candidate for candidate in candidates if isinstance(candidate, decay_types)]
def finalize_data(
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    is_full_model: bool = False,
    as_global: bool = False,
):
    """Attach the decay-specific results of all decay megacomplexes to ``dataset``.

    Writes species concentrations, species-associated data, the initial concentration and
    the IRF. Unless ``is_full_model`` is set, it also writes the per-megacomplex
    decay-associated data.

    Parameters
    ----------
    dataset_model : DatasetModel
        The dataset model the results belong to.
    dataset : xr.Dataset
        The result dataset; modified in place.
    is_full_model : bool
        When True, the species- and decay-associated data are not written.
    as_global : bool
        Use the global megacomplexes and the ``"decay_species"`` dimension instead of the
        regular megacomplexes and ``"species"``.
    """
    species_dimension = "decay_species" if as_global else "species"

    # Every decay megacomplex triggers this function, but the first call already finalizes
    # the data for all decay megacomplexes; the species coordinate shows it has run.
    if species_dimension in dataset.coords:
        return

    decay_megacomplexes = collect_megacomplexes(dataset_model, as_global)
    global_dimension = dataset_model.get_global_dimension()
    name = "images" if global_dimension == "pixel" else "spectra"

    # Ordered, duplicate-free union of the compartments of every decay megacomplex.
    all_species = []
    for megacomplex in decay_megacomplexes:
        for compartment in megacomplex.get_compartments(dataset_model):
            if compartment in all_species:
                continue
            all_species.append(compartment)

    retrieve_species_associated_data(
        dataset_model,
        dataset,
        all_species,
        species_dimension,
        global_dimension,
        name,
        is_full_model,
        as_global,
    )
    retrieve_initial_concentration(dataset_model, dataset, species_dimension)
    retrieve_irf(dataset_model, dataset, global_dimension)

    if is_full_model:
        return
    for megacomplex in decay_megacomplexes:
        retrieve_decay_associated_data(
            megacomplex, dataset_model, dataset, global_dimension, name
        )
def decay_matrix_implementation(
    matrix: np.ndarray,
    rates: np.ndarray,
    global_index: int,
    global_axis: np.ndarray,
    model_axis: np.ndarray,
    dataset_model: DatasetModel,
):
    """Fill ``matrix`` in place with the decay of every rate along ``model_axis``.

    With an ``IrfMultiGaussian`` IRF, the decays convolved with each Gaussian component are
    accumulated, weighted by that component's scale. If ``irf.normalize`` is set, the result
    is divided by the sum of all component scales. Without such an IRF, plain exponential
    decays are added.

    Parameters
    ----------
    matrix : np.ndarray
        Output buffer of shape ``(model_axis.size, rates.size)``; modified in place.
    rates : np.ndarray
        Decay rates, one per matrix column.
    global_index : int
        Index on the global axis, forwarded to ``irf.parameter`` (may be ``None``).
    global_axis : np.ndarray
        The global axis, forwarded to ``irf.parameter``.
    model_axis : np.ndarray
        The model axis (matrix rows).
    dataset_model : DatasetModel
        Provides the IRF.
    """
    irf = dataset_model.irf
    if not isinstance(irf, IrfMultiGaussian):
        calculate_decay_matrix_no_irf(matrix, rates, model_axis)
        return

    (
        centers,
        widths,
        irf_scales,
        shift,
        backsweep,
        backsweep_period,
    ) = irf.parameter(global_index, global_axis)

    for center, width, irf_scale in zip(centers, widths, irf_scales):
        calculate_decay_matrix_gaussian_irf(
            matrix,
            rates,
            model_axis,
            center - shift,
            width,
            irf_scale,
            backsweep,
            backsweep_period,
        )

    if irf.normalize:
        # Normalize by the total weight of ALL Gaussian components. The previous code used
        # the leaked loop variable, i.e. only the last component's scale.
        matrix /= np.sum(irf_scales)
@nb.jit(nopython=True, parallel=True)
def calculate_decay_matrix_no_irf(matrix, rates, times):
    """Add plain exponential decays (no IRF) to ``matrix`` in place.

    For every column ``j`` and row ``i``, ``exp(-rates[j] * times[i])`` is added to
    ``matrix[i, j]``. The loop over rates is a ``prange``, so columns may be processed in
    parallel; each iteration writes only to its own column.
    """
    n_times = times.size
    for rate_idx in nb.prange(rates.size):
        rate = rates[rate_idx]
        for time_idx in range(n_times):
            matrix[time_idx, rate_idx] += np.exp(-rate * times[time_idx])
# Module-level constant read by the numba Gaussian-IRF kernel.
sqrt2 = np.sqrt(2.0)
@nb.jit(nopython=True, parallel=True)
def calculate_decay_matrix_gaussian_irf(
    matrix, rates, times, center, width, scale, backsweep, backsweep_period
):
    """Calculates a decay matrix with a gaussian irf.

    Adds, in place, ``scale`` times the exponential decay of each rate convolved with a
    Gaussian of the given ``center`` and ``width`` to ``matrix[n_t, n_r]``. With
    ``alpha = r * width / sqrt2`` and ``beta = (t - center) / (width * sqrt2)``, both
    branches below evaluate ``0.5 * exp(alpha**2 - 2*alpha*beta) * erfc(alpha - beta)``.

    If ``backsweep`` is set, an additional periodic term built from ``backsweep_period`` is
    added for each rate whose ``|rate| * backsweep_period`` exceeds 0.001.

    Parameters
    ----------
    matrix : np.ndarray
        Output buffer indexed ``[time, rate]``; accumulated into with ``+=``.
    rates, times : np.ndarray
        Decay rates (columns) and model-axis points (rows).
    center, width, scale : float
        Gaussian center, width and weight of this IRF component.
    backsweep : bool
        Whether to add the backsweep term.
    backsweep_period : float
        Period used by the backsweep term.
    """
    for n_r in nb.prange(rates.size):
        r_n = rates[n_r]
        # Skip the backsweep term for near-zero r*period, where 1 - exp(-r*period)
        # (the denominator below) would approach 0.
        backsweep_valid = abs(r_n) * backsweep_period > 0.001
        alpha = (r_n * width) / sqrt2
        # NOTE: numba parallelizes only the outermost prange; this inner one presumably
        # runs serially.
        for n_t in nb.prange(times.size):
            t_n = times[n_t]
            beta = (t_n - center) / (width * sqrt2)
            thresh = beta - alpha
            if thresh < -1:
                # Equivalent erfcx form: erfc(x) = erfcx(x) * exp(-x**2) with x = alpha - beta,
                # which reduces the prefactor to exp(-beta**2). It avoids multiplying a
                # vanishing (1 + erf) by a potentially huge exponential.
                matrix[n_t, n_r] += scale * 0.5 * erfcx(-thresh) * np.exp(-beta * beta)
            else:
                matrix[n_t, n_r] += (
                    scale * 0.5 * (1 + erf(thresh)) * np.exp(alpha * (alpha - 2 * beta))
                )
            if backsweep and backsweep_valid:
                x1 = np.exp(-r_n * (t_n - center + backsweep_period))
                x2 = np.exp(-r_n * ((backsweep_period / 2) - (t_n - center)))
                x3 = np.exp(-r_n * backsweep_period)
                matrix[n_t, n_r] += scale * (x1 + x2) / (1 - x3)
import ctypes # noqa: E402 # This is a work around to use scipy.special function with numba from numba.extending import get_cython_function_address # noqa: E402 _dble = ctypes.c_double functype = ctypes.CFUNCTYPE(_dble, _dble) erf_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erf") erfcx_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erfcx") erf = functype(erf_addr) erfcx = functype(erfcx_addr)
def retrieve_species_associated_data(
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    species: list[str],
    species_dimension: str,
    global_dimension: str,
    name: str,
    is_full_model: bool,
    as_global: bool,
):
    """Store species concentrations and species-associated data on ``dataset``.

    Parameters
    ----------
    dataset_model : DatasetModel
        Provides the model dimension name.
    dataset : xr.Dataset
        Result dataset; modified in place.
    species : list[str]
        Species labels to select from the matrix and the clp.
    species_dimension : str
        Name of the species coordinate to create.
    global_dimension : str
        Name of the global dimension.
    name : str
        Suffix for the species-associated variable (e.g. ``"spectra"``).
    is_full_model : bool
        When True, ``species_associated_{name}`` is not written.
    as_global : bool
        Read ``dataset.global_matrix`` / ``"global_clp_label"`` instead of
        ``dataset.matrix`` / ``"clp_label"``, with the model and global dimensions swapped.
    """
    model_dimension = dataset_model.get_model_dimension()
    if as_global:
        # For the global matrix the roles of the model and global axes are swapped.
        model_dimension, global_dimension = global_dimension, model_dimension

    dataset.coords[species_dimension] = species
    matrix = dataset.global_matrix if as_global else dataset.matrix
    clp_dim = "global_clp_label" if as_global else "clp_label"

    # Decide index dependence from the matrix actually being read. The previous code always
    # inspected ``dataset.matrix``, which is the wrong array when ``as_global`` is set.
    if len(matrix.shape) == 3:
        # index dependent
        concentration_dims = (global_dimension, model_dimension, species_dimension)
    else:
        # index independent
        concentration_dims = (model_dimension, species_dimension)
    dataset["species_concentration"] = (
        concentration_dims,
        matrix.sel({clp_dim: species}).values,
    )

    if not is_full_model:
        dataset[f"species_associated_{name}"] = (
            (global_dimension, species_dimension),
            dataset.clp.sel(clp_label=species).data,
        )
def retrieve_initial_concentration(
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    species_dimension: str,
):
    """Store the dataset-wide initial concentration on ``dataset``, if there is one.

    Nothing is written when ``dataset_model`` lacks an ``initial_concentration``
    attribute or it is ``None``. Otherwise its ``parameters`` are stored as
    ``dataset["initial_concentration"]`` along ``species_dimension``.
    """
    initial_concentration = getattr(dataset_model, "initial_concentration", None)
    # Parallel and sequential decays carry no dataset-wide initial concentration unless
    # they are mixed with general decays, so a missing value is expected here.
    if initial_concentration is not None:
        dataset["initial_concentration"] = (
            (species_dimension,),
            initial_concentration.parameters,
        )
def retrieve_decay_associated_data(
    megacomplex: Megacomplex,
    dataset_model: DatasetModel,
    dataset: xr.Dataset,
    global_dimension: str,
    name: str,
):
    """Store the decay-associated data and the A/K matrices of ``megacomplex``.

    Writes four variables, all suffixed with the megacomplex label:
    ``decay_associated_{name}_{label}``, ``a_matrix_{label}``, ``k_matrix_{label}`` and
    ``k_matrix_reduced_{label}``. The decay-associated data are the
    species-associated data times the transposed A-matrix. Components are numbered from 1
    and carry the rates and lifetimes (``1 / rate``) as coordinates.

    Parameters
    ----------
    megacomplex : Megacomplex
        The decay megacomplex to report on.
    dataset_model : DatasetModel
        Dataset model the megacomplex belongs to.
    dataset : xr.Dataset
        Result dataset; must already hold ``species_associated_{name}``. Modified in place.
    global_dimension : str
        Name of the global dimension.
    name : str
        Result-name suffix (e.g. ``"spectra"``).
    """
    species = megacomplex.get_compartments(dataset_model)
    initial_concentration = megacomplex.get_initial_concentration(dataset_model)
    k_matrix = megacomplex.get_k_matrix()
    full_k = k_matrix.full(species)
    reduced_k = k_matrix.reduced(species)
    a_matrix = megacomplex.get_a_matrix(dataset_model)
    rates = k_matrix.rates(species, initial_concentration)

    sas = dataset[f"species_associated_{name}"].sel(species=species).values
    das_values = sas @ a_matrix.T

    label = megacomplex.label
    component_dim = f"component_{label}"
    species_dim = f"species_{label}"
    to_species_dim = f"to_species_{label}"
    from_species_dim = f"from_species_{label}"

    component_coords = {
        component_dim: np.arange(1, rates.size + 1),
        f"rate_{label}": (component_dim, rates),
        f"lifetime_{label}": (component_dim, 1 / rates),
    }

    das = xr.DataArray(
        das_values,
        dims=(global_dimension, component_dim),
        coords={**component_coords, global_dimension: dataset.coords[global_dimension]},
    )

    unnormalized_initial_concentration = megacomplex.get_initial_concentration(
        dataset_model, normalized=False
    )
    a_matrix_da = xr.DataArray(
        a_matrix,
        coords={
            **component_coords,
            species_dim: species,
            f"initial_concentration_{label}": (
                species_dim,
                unnormalized_initial_concentration,
            ),
        },
        dims=(component_dim, species_dim),
    )

    k_matrix_da = xr.DataArray(
        full_k, coords=[(to_species_dim, species), (from_species_dim, species)]
    )
    k_matrix_reduced_da = xr.DataArray(
        reduced_k, coords=[(to_species_dim, species), (from_species_dim, species)]
    )

    dataset[f"decay_associated_{name}_{label}"] = das
    dataset[f"a_matrix_{label}"] = a_matrix_da
    dataset[f"k_matrix_{label}"] = k_matrix_da
    dataset[f"k_matrix_reduced_{label}"] = k_matrix_reduced_da
def retrieve_irf(dataset_model: DatasetModel, dataset: xr.Dataset, global_dimension: str):
    """Store the multi-Gaussian IRF and its parameters on ``dataset``.

    Does nothing if the IRF is not an ``IrfMultiGaussian`` or ``dataset`` already holds
    ``"irf"``. Otherwise writes:

    - ``irf``: the IRF evaluated at global index 0 along the model axis;
    - ``irf_center`` / ``irf_width``: a scalar for a single component, or a variable
      along ``irf_nr`` for several;
    - ``irf_shift``: if the IRF defines shifts, ``center[0] - shift`` per global point;
    - ``irf_center_location`` (and its legacy alias ``center_dispersion_1``): for
      ``IrfSpectralMultiGaussian`` with a dispersion center, the dispersed center
      positions over the global axis.
    """
    irf = dataset_model.irf
    if not isinstance(irf, IrfMultiGaussian) or "irf" in dataset:
        return

    model_dimension = dataset_model.get_model_dimension()
    global_axis = dataset.coords[global_dimension].values
    dataset["irf"] = (
        (model_dimension,),
        irf.calculate(
            index=0,
            global_axis=global_axis,
            model_axis=dataset.coords[model_dimension].values,
        ).data,
    )

    center = irf.center if isinstance(irf.center, list) else [irf.center]
    width = irf.width if isinstance(irf.width, list) else [irf.width]
    dataset["irf_center"] = ("irf_nr", center) if len(center) > 1 else center[0]
    dataset["irf_width"] = ("irf_nr", width) if len(width) > 1 else width[0]

    if irf.shift is not None:
        dataset["irf_shift"] = (global_dimension, [center[0] - p.value for p in irf.shift])

    if isinstance(irf, IrfSpectralMultiGaussian) and irf.dispersion_center:
        # Evaluate on the same axis the result is dimensioned on. This was previously
        # hard-coded to the "spectral" coordinate, which fails or mismatches for any other
        # global dimension.
        dataset["irf_center_location"] = (
            ("irf_nr", global_dimension),
            irf.calculate_dispersion(global_axis),
        )
        # center_dispersion_1 for backwards compatibility (0.3-0.4.1)
        dataset["center_dispersion_1"] = dataset["irf_center_location"].sel(irf_nr=0)