Source code for glotaran.analysis.util

from __future__ import annotations

import itertools
from typing import Any
from typing import NamedTuple

import numba as nb
import numpy as np
import xarray as xr

from glotaran.model import DatasetModel
from glotaran.model import Model
from glotaran.parameter import ParameterGroup


class CalculatedMatrix(NamedTuple):
    """Bundle a calculated matrix with the labels of its columns.

    Being a ``NamedTuple``, instances can be unpacked positionally as
    ``clp_labels, matrix = calculated_matrix``.
    """

    # Labels identifying the columns of ``matrix`` (one label per column).
    clp_labels: list[str]
    # The matrix itself; columns are addressed via the position of their
    # label in ``clp_labels``.
    matrix: np.ndarray
def find_overlap(a, b, rtol=1e-05, atol=1e-08):
    """Find the positions at which two sequences hold (almost) equal values.

    Values are compared with ``np.isclose`` using ``rtol`` and ``atol``;
    NaN never matches (``equal_nan=False``).

    The early exit and the skipping of already-passed items of ``b`` rely on
    both inputs being sorted in ascending order.

    Parameters
    ----------
    a, b :
        Iterables of numbers; ``b`` is iterated repeatedly.
    rtol, atol :
        Relative and absolute tolerance forwarded to ``np.isclose``.

    Returns
    -------
    tuple of (list of int, list of int)
        Indices into ``a`` and the matching indices into ``b``, pairwise.
    """
    indices_a = []
    indices_b = []
    skip_b = 0  # number of leading items of ``b`` known to be too small
    for pos_a, value_a in enumerate(a):
        candidates = itertools.islice(enumerate(b), skip_b, None)
        for pos_b, value_b in candidates:
            if np.isclose(value_a, value_b, rtol=rtol, atol=atol, equal_nan=False):
                indices_a.append(pos_a)
                indices_b.append(pos_b)
                continue
            if value_b > value_a:
                # More than the tolerance above ``value_a``: everything that
                # follows in ``b`` is even farther away.
                break
            # ``value_b`` is below ``value_a`` by more than the tolerance, so
            # it cannot match any later item of ``a`` either.
            skip_b += 1
    return (indices_a, indices_b)
def find_closest_index(index: float, axis: np.ndarray):
    """Return the position of the entry of ``axis`` closest to ``index``.

    When several entries are equally close, the first one wins (``argmin``
    semantics).
    """
    distances = np.abs(axis - index)
    return distances.argmin()
def get_min_max_from_interval(interval, axis):
    """Translate a value interval into an index ``slice`` on ``axis``.

    Parameters
    ----------
    interval :
        Pair ``(lower, upper)`` of axis values; an infinite bound selects
        the respective end of the axis.
    axis :
        Object exposing ``.values`` (array of axis values) and ``.size``.

    Returns
    -------
    slice
        Starts at the entry closest to ``lower`` and stops one past the
        entry closest to ``upper``, i.e. the upper entry is included.
    """
    lower = interval[0]
    if np.isinf(lower):
        start = 0
    else:
        start = np.abs(axis.values - lower).argmin()

    upper = interval[1]
    if np.isinf(upper):
        stop = axis.size
    else:
        # +1 because a slice excludes its stop index.
        stop = np.abs(axis.values - upper).argmin() + 1

    return slice(start, stop)
def calculate_matrix(
    dataset_model: DatasetModel,
    indices: dict[str, int],
    as_global_model: bool = False,
) -> CalculatedMatrix:
    """Calculate the combined matrix of all megacomplexes of a dataset model.

    Each megacomplex yielded by the dataset model contributes a matrix with
    its own clp labels. A contribution is multiplied by its scale when one
    is given, and contributions are merged with ``combine_matrix``.

    Parameters
    ----------
    dataset_model :
        Provides the megacomplex iterators and ``swap_dimensions``.
    indices :
        Forwarded unchanged to ``megacomplex.calculate_matrix``.
    as_global_model :
        If true, iterate the global megacomplexes instead and swap the
        dataset model's dimensions for the duration of the calculation.

    Returns
    -------
    CalculatedMatrix
        Merged labels and matrix; both are ``None`` if the iterator yields
        no megacomplex.
    """
    clp_labels = None
    matrix = None

    megacomplex_iterator = dataset_model.iterate_megacomplexes
    if as_global_model:
        megacomplex_iterator = dataset_model.iterate_global_megacomplexes
        dataset_model.swap_dimensions()

    try:
        for scale, megacomplex in megacomplex_iterator():
            this_clp_labels, this_matrix = megacomplex.calculate_matrix(dataset_model, indices)

            if scale is not None:
                # NOTE: scales the array returned by the megacomplex in place.
                this_matrix *= scale

            if matrix is None:
                clp_labels = this_clp_labels
                matrix = this_matrix
            else:
                clp_labels, matrix = combine_matrix(
                    matrix, this_matrix, clp_labels, this_clp_labels
                )
    finally:
        # Always undo the swap, even if a megacomplex raised; otherwise the
        # dataset model would be left with swapped dimensions.
        if as_global_model:
            dataset_model.swap_dimensions()

    return CalculatedMatrix(clp_labels, matrix)
def combine_matrix(matrix, this_matrix, clp_labels, this_clp_labels):
    """Merge two labelled matrices column-wise by their clp labels.

    The result has one column per label of ``clp_labels`` followed by the
    labels of ``this_clp_labels`` not already present. Columns that share a
    label in both inputs are added together.

    Parameters
    ----------
    matrix, this_matrix :
        2D arrays whose columns are described by ``clp_labels`` and
        ``this_clp_labels`` respectively.
    clp_labels, this_clp_labels :
        Lists of labels (``+`` and ``.index`` are used on them).

    Returns
    -------
    tuple of (list, np.ndarray)
        The merged labels and the merged ``float64`` matrix, which has as
        many rows as ``matrix``.
    """
    new_labels = [label for label in this_clp_labels if label not in clp_labels]
    combined_labels = clp_labels + new_labels

    combined = np.zeros((matrix.shape[0], len(combined_labels)), dtype=np.float64)
    contributions = ((clp_labels, matrix), (this_clp_labels, this_matrix))
    for column, label in enumerate(combined_labels):
        for labels, values in contributions:
            if label in labels:
                combined[:, column] += values[:, labels.index(label)]

    return combined_labels, combined
@nb.jit(nopython=True, parallel=True)
def apply_weight(matrix, weight):
    """Multiply every column of ``matrix`` by ``weight``, in place.

    Nothing is returned; ``matrix`` itself is modified. The columns are
    processed in a numba ``prange`` loop.

    NOTE(review): ``weight`` presumably holds one entry per row of
    ``matrix`` - confirm against the callers.
    """
    for column in nb.prange(matrix.shape[1]):
        matrix[:, column] *= weight
def reduce_matrix(
    matrix: CalculatedMatrix,
    model: Model,
    parameters: ParameterGroup,
    index: Any | None,
) -> CalculatedMatrix:
    """Reduce a matrix by the model's clp relations, then its clp constraints.

    Parameters
    ----------
    matrix :
        The labelled matrix to reduce.
    model :
        Supplies ``clp_relations`` and ``clp_constraints``.
    parameters :
        Used to fill the relations.
    index :
        Passed to the ``applies`` checks of relations and constraints.

    Returns
    -------
    CalculatedMatrix
        The reduced labels and matrix.
    """
    # Relations have to be applied first; constraints act on their result.
    with_relations = apply_relations(matrix, model, parameters, index)
    return apply_constraints(with_relations, model, index)
def apply_constraints(
    matrix: CalculatedMatrix,
    model: Model,
    index: Any | None,
) -> CalculatedMatrix:
    """Drop the columns whose clp label is the target of an applying constraint.

    Parameters
    ----------
    matrix :
        The labelled matrix to reduce.
    model :
        Supplies ``clp_constraints``; each constraint exposes ``target`` and
        ``applies(index)``.
    index :
        Passed to ``applies`` of every constraint.

    Returns
    -------
    CalculatedMatrix
        ``matrix`` itself if the model has no constraints, otherwise a new
        instance holding only the remaining labels and columns.
    """
    if len(model.clp_constraints) == 0:
        return matrix

    all_labels = matrix.clp_labels
    constrained = {
        constraint.target
        for constraint in model.clp_constraints
        if constraint.target in all_labels and constraint.applies(index)
    }

    # Boolean column mask: True for every column that survives.
    keep = [label not in constrained for label in all_labels]
    kept_labels = [label for label, kept in zip(all_labels, keep) if kept]
    return CalculatedMatrix(kept_labels, matrix.matrix[:, keep])
def apply_relations(
    matrix: CalculatedMatrix,
    model: Model,
    parameters: ParameterGroup,
    index: Any | None,
) -> CalculatedMatrix:
    """Fold the columns of relation targets into their source columns.

    For every relation that applies at ``index`` and whose target and source
    are both labels of ``matrix``, the target column is multiplied by the
    relation parameter and added to the source column, and the target column
    is removed. This is done by a single multiplication with a relation
    matrix built from the identity.

    Parameters
    ----------
    matrix :
        The labelled matrix to reduce.
    model :
        Supplies ``clp_relations``.
    parameters :
        Used to fill each applying relation.
    index :
        Passed to ``applies`` of every relation.

    Returns
    -------
    CalculatedMatrix
        ``matrix`` itself if the model has no relations, otherwise a new
        instance with the reduced labels and matrix.
    """
    if len(model.clp_relations) == 0:
        return matrix

    labels = matrix.clp_labels
    relation_matrix = np.diagflat([1.0] * len(labels))

    removed_columns = []
    for relation in model.clp_relations:
        if relation.target not in labels or not relation.applies(index):
            continue
        if relation.source not in labels:
            continue

        filled = relation.fill(model, parameters)
        source_column = labels.index(filled.source)
        target_column = labels.index(filled.target)
        relation_matrix[target_column, source_column] = filled.parameter
        removed_columns.append(target_column)

    kept_labels = [
        label for column, label in enumerate(labels) if column not in removed_columns
    ]
    relation_matrix = np.delete(relation_matrix, removed_columns, axis=1)
    return CalculatedMatrix(kept_labels, matrix.matrix @ relation_matrix)
def retrieve_clps(
    model: Model,
    parameters: ParameterGroup,
    clp_labels: xr.DataArray,
    reduced_clp_labels: xr.DataArray,
    reduced_clps: xr.DataArray,
    index: Any | None,
) -> xr.DataArray:
    """Expand reduced clps back to the full set of clp labels.

    NOTE(review): despite the annotations, ``clp_labels`` must support
    ``.index`` (e.g. a list) and a NumPy array is returned whenever the
    model has relations or constraints - confirm the intended types.

    Parameters
    ----------
    model :
        Supplies ``clp_relations`` and ``clp_constraints``.
    parameters :
        Used to fill the relations.
    clp_labels :
        All labels, in the order of the returned values.
    reduced_clp_labels :
        Labels that remained after the reduction.
    reduced_clps :
        Values belonging to ``reduced_clp_labels``.
    index :
        Passed to ``applies`` of every relation.

    Returns
    -------
    ``reduced_clps`` unchanged if the model has neither relations nor
    constraints. Otherwise an array with one value per entry of
    ``clp_labels``: reduced values are copied over, relation targets are set
    to ``parameter * source`` and everything else stays zero.
    """
    if len(model.clp_relations) == 0 and len(model.clp_constraints) == 0:
        return reduced_clps

    full_clps = np.zeros(len(clp_labels))
    for position, label in enumerate(reduced_clp_labels):
        full_clps[clp_labels.index(label)] = reduced_clps[position]

    for relation in model.clp_relations:
        filled = relation.fill(model, parameters)
        if filled.target not in clp_labels:
            continue
        if not filled.applies(index):
            continue
        if filled.source not in clp_labels:
            continue
        source_position = clp_labels.index(filled.source)
        target_position = clp_labels.index(filled.target)
        full_clps[target_position] = filled.parameter * full_clps[source_position]

    return full_clps
def calculate_clp_penalties(
    model: Model,
    parameters: ParameterGroup,
    clp_labels: list[list[str]] | list[str],
    clps: list[np.ndarray],
    global_axis: np.ndarray,
    dataset_models: dict[str, DatasetModel],
) -> np.ndarray:
    """Calculate one penalty value per clp area penalty of the model.

    For each penalty the source and target areas are collected over all
    dataset models, and the penalty is
    ``abs(sum(source) - parameter * sum(target)) * weight``.

    Parameters
    ----------
    model :
        Supplies ``clp_area_penalties``.
    parameters :
        Used to fill each penalty.
    clp_labels :
        Either one list of labels or one list of labels per global index.
    clps :
        The clp values per global index.
    global_axis :
        The global axis the intervals of the penalties refer to.
    dataset_models :
        The dataset models; each provides its own axis via
        ``get_global_axis()``.

    Returns
    -------
    np.ndarray
        The penalties, in the order of ``model.clp_area_penalties``.
    """
    # TODO: make a decision on how to handle clp_penalties per dataset
    # 1. sum up contributions per dataset on each dataset_axis (v0.4.1)
    # 2. sum up contributions on the global_axis (future?)

    penalties = []
    for penalty in model.clp_area_penalties:
        penalty = penalty.fill(model, parameters)

        # Seeded with an empty array so concatenation also works when there
        # are no dataset models.
        source_parts = [np.array([])]
        target_parts = [np.array([])]
        for dataset_model in dataset_models.values():
            dataset_axis = dataset_model.get_global_axis()
            source_parts.append(
                _get_area(
                    penalty.source,
                    clp_labels,
                    clps,
                    penalty.source_intervals,
                    global_axis,
                    dataset_axis,
                )
            )
            target_parts.append(
                _get_area(
                    penalty.target,
                    clp_labels,
                    clps,
                    penalty.target_intervals,
                    global_axis,
                    dataset_axis,
                )
            )

        source_total = np.sum(np.concatenate(source_parts))
        target_total = np.sum(np.concatenate(target_parts))
        area_penalty = np.abs(source_total - penalty.parameter * target_total)
        penalties.append(area_penalty * penalty.weight)

    return np.asarray(penalties)
def _get_area(
    clp_label: str,
    clp_labels: list[list[str]],
    clps: list[np.ndarray],
    intervals: list[tuple[float, float]],
    global_axis: np.ndarray,
    dataset_axis: np.ndarray,
) -> np.ndarray:
    """Collect the values of one clp inside the given intervals.

    Each interval is clamped to the range of ``dataset_axis`` and translated
    into indices on ``global_axis``. For every index in that range (end
    included) the value of ``clp_label`` is collected, if the label is
    present at that index.

    Intervals lying completely outside ``global_axis`` contribute nothing.
    The check compares against the first and last entry of ``global_axis``,
    i.e. it assumes the axis is sorted in ascending order.

    Parameters
    ----------
    clp_label :
        The label whose values are collected.
    clp_labels :
        Either one list of labels per global index or a single list of
        labels shared by all indices.
    clps :
        The clp values per global index.
    intervals :
        ``(lower, upper)`` pairs on the global axis.
    global_axis :
        Axis on which the indices are looked up.
    dataset_axis :
        Axis whose minimum and maximum bound every interval.

    Returns
    -------
    np.ndarray
        The collected values, in the order they were visited.
    """
    area = []

    # Loop invariant: are the labels given per index or shared?
    labels_per_index = len(clp_labels) > 0 and isinstance(clp_labels[0], list)

    for interval in intervals:
        # Skip intervals entirely above OR entirely below the global axis.
        # Without the lower check such an interval would be mapped onto the
        # first index by the nearest-neighbour lookup and wrongly contribute.
        if interval[0] > global_axis[-1] or interval[1] < global_axis[0]:
            continue

        bounded_interval = (
            max(interval[0], np.min(dataset_axis)),
            min(interval[1], np.max(dataset_axis)),
        )
        start_idx, end_idx = get_idx_from_interval(bounded_interval, global_axis)

        for i in range(start_idx, end_idx + 1):
            index_clp_labels = clp_labels[i] if labels_per_index else clp_labels
            if clp_label in index_clp_labels:
                area.append(clps[i][index_clp_labels.index(clp_label)])

    return np.asarray(area)  # TODO: normalize for distance on global axis
def get_idx_from_interval(interval: tuple[float, float], axis: np.ndarray) -> tuple[int, int]:
    """Retrieves start and end index of an interval on some axis

    Each finite bound is mapped to the index of the closest axis entry; an
    infinite bound is mapped to the first (start) or last (end) index.

    Parameters
    ----------
    interval : A tuple of floats with begin and end of the interval
    axis : Array like object which can be cast to np.array

    Returns
    -------
    start, end : tuple of int
    """
    # Cast once so that plain lists/tuples work as documented; the arithmetic
    # and ``.size`` below need a real array.
    axis = np.asarray(axis)

    if np.isinf(interval[0]):
        start = 0
    else:
        start = np.abs(axis - interval[0]).argmin()

    if np.isinf(interval[1]):
        end = axis.size - 1
    else:
        end = np.abs(axis - interval[1]).argmin()

    return start, end