Source code for glotaran.builtin.megacomplexes.decay.util

from __future__ import annotations

from typing import TYPE_CHECKING

import numba as nb
import numpy as np
import xarray as xr

from glotaran.builtin.megacomplexes.decay.irf import IrfMultiGaussian
from glotaran.builtin.megacomplexes.decay.irf import IrfSpectralMultiGaussian
from glotaran.model import DatasetModel

if TYPE_CHECKING:
    from glotaran.builtin.megacomplexes.decay.decay_megacomplex import DecayMegacomplex


[docs]def decay_matrix_implementation( matrix: np.ndarray, rates: np.ndarray, global_index: int, global_axis: np.ndarray, model_axis: np.ndarray, dataset_model: DatasetModel, ): if isinstance(dataset_model.irf, IrfMultiGaussian): ( centers, widths, irf_scales, shift, backsweep, backsweep_period, ) = dataset_model.irf.parameter(global_index, global_axis) for center, width, irf_scale in zip(centers, widths, irf_scales): calculate_decay_matrix_gaussian_irf( matrix, rates, model_axis, center - shift, width, irf_scale, backsweep, backsweep_period, ) if dataset_model.irf.normalize: matrix /= np.sum(irf_scale) else: calculate_decay_matrix_no_irf(matrix, rates, model_axis)
[docs]@nb.jit(nopython=True, parallel=True) def calculate_decay_matrix_no_irf(matrix, rates, times): for n_r in nb.prange(rates.size): r_n = rates[n_r] for n_t in range(times.size): t_n = times[n_t] matrix[n_t, n_r] += np.exp(r_n * t_n)
sqrt2 = np.sqrt(2)
[docs]@nb.jit(nopython=True, parallel=True) def calculate_decay_matrix_gaussian_irf( matrix, rates, times, center, width, scale, backsweep, backsweep_period ): """Calculates a decay matrix with a gaussian irf.""" for n_r in nb.prange(rates.size): r_n = -rates[n_r] backsweep_valid = abs(r_n) * backsweep_period > 0.001 alpha = (r_n * width) / sqrt2 for n_t in nb.prange(times.size): t_n = times[n_t] beta = (t_n - center) / (width * sqrt2) thresh = beta - alpha if thresh < -1: matrix[n_t, n_r] += scale * 0.5 * erfcx(-thresh) * np.exp(-beta * beta) else: matrix[n_t, n_r] += ( scale * 0.5 * (1 + erf(thresh)) * np.exp(alpha * (alpha - 2 * beta)) ) if backsweep and backsweep_valid: x1 = np.exp(-r_n * (t_n - center + backsweep_period)) x2 = np.exp(-r_n * ((backsweep_period / 2) - (t_n - center))) x3 = np.exp(-r_n * backsweep_period) matrix[n_t, n_r] += scale * (x1 + x2) / (1 - x3)
import ctypes # noqa: E402 # This is a work around to use scipy.special function with numba from numba.extending import get_cython_function_address # noqa: E402 _dble = ctypes.c_double functype = ctypes.CFUNCTYPE(_dble, _dble) erf_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erf") erfcx_addr = get_cython_function_address("scipy.special.cython_special", "__pyx_fuse_1erfcx") erf = functype(erf_addr) erfcx = functype(erfcx_addr)
[docs]def retrieve_species_associated_data( dataset_model: DatasetModel, dataset: xr.Dataset, species_dimension: str, global_dimension: str, name: str, is_full_model: bool, as_global: bool, ): species = dataset_model.initial_concentration.compartments model_dimension = dataset_model.get_model_dimension() if as_global: model_dimension, global_dimension = global_dimension, model_dimension dataset.coords[species_dimension] = species matrix = dataset.global_matrix if as_global else dataset.matrix clp_dim = "global_clp_label" if as_global else "clp_label" if len(dataset.matrix.shape) == 3: # index dependent dataset["species_concentration"] = ( ( global_dimension, model_dimension, species_dimension, ), matrix.sel({clp_dim: species}).values, ) else: # index independent dataset["species_concentration"] = ( ( model_dimension, species_dimension, ), matrix.sel({clp_dim: species}).values, ) if not is_full_model: dataset[f"species_associated_{name}"] = ( ( global_dimension, species_dimension, ), dataset.clp.sel(clp_label=species).data, )
[docs]def retrieve_decay_associated_data( megacomplex: DecayMegacomplex, dataset_model: DatasetModel, dataset: xr.Dataset, global_dimension: str, name: str, multiple_complexes: bool, ): k_matrix = megacomplex.full_k_matrix() species = dataset_model.initial_concentration.compartments species = [c for c in species if c in k_matrix.involved_compartments()] matrix = k_matrix.full(species) matrix_reduced = k_matrix.reduced(species) a_matrix = k_matrix.a_matrix(dataset_model.initial_concentration.normalized()) rates = k_matrix.rates(dataset_model.initial_concentration) lifetimes = 1 / rates das = dataset[f"species_associated_{name}"].sel(species=species).values @ a_matrix.T component_coords = { "component": np.arange(rates.size), "rate": ("component", rates), "lifetime": ("component", lifetimes), } das_coords = component_coords.copy() das_coords[global_dimension] = dataset.coords[global_dimension] das_name = f"decay_associated_{name}" das = xr.DataArray(das, dims=(global_dimension, "component"), coords=das_coords) a_matrix_coords = component_coords.copy() a_matrix_coords["species"] = species a_matrix_name = "a_matrix" a_matrix = xr.DataArray(a_matrix, coords=a_matrix_coords, dims=("component", "species")) k_matrix_name = "k_matrix" k_matrix = xr.DataArray(matrix, coords=[("to_species", species), ("from_species", species)]) k_matrix_reduced_name = "k_matrix_reduced" k_matrix_reduced = xr.DataArray( matrix_reduced, coords=[("to_species", species), ("from_species", species)] ) if multiple_complexes: das_name = f"decay_associated_{name}_{megacomplex.label}" das = das.rename(component=f"component_{megacomplex.label}") a_matrix_name = f"a_matrix_{megacomplex.label}" a_matrix = a_matrix.rename(component=f"component_{megacomplex.label}") k_matrix_name = f"k_matrix_{megacomplex.label}" k_matrix_reduced_name = f"k_matrix_reduced_{megacomplex.label}" dataset[das_name] = das dataset[a_matrix_name] = a_matrix dataset[k_matrix_name] = k_matrix dataset[k_matrix_reduced_name] = k_matrix_reduced
[docs]def retrieve_irf(dataset_model: DatasetModel, dataset: xr.Dataset, global_dimension: str): irf = dataset_model.irf model_dimension = dataset_model.get_model_dimension() dataset["irf"] = ( (model_dimension), irf.calculate( index=0, global_axis=dataset.coords[global_dimension].values, model_axis=dataset.coords[model_dimension].values, ).data, ) center = irf.center if isinstance(irf.center, list) else [irf.center] width = irf.width if isinstance(irf.width, list) else [irf.width] dataset["irf_center"] = ("irf_nr", center) if len(center) > 1 else center[0] dataset["irf_width"] = ("irf_nr", width) if len(width) > 1 else width[0] if isinstance(irf, IrfSpectralMultiGaussian) and irf.dispersion_center: dataset["irf_center_location"] = ( ("irf_nr", global_dimension), irf.calculate_dispersion(dataset.coords["spectral"].values), ) # center_dispersion_1 for backwards compatibility (0.3-0.4.1) dataset["center_dispersion_1"] = dataset["irf_center_location"].sel(irf_nr=0)