"""Module containing the matrix provider classes."""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from dataclasses import replace
from typing import TYPE_CHECKING
from typing import Any
import numpy as np
import xarray as xr
from glotaran.model import DatasetGroup
from glotaran.model import DatasetModel
from glotaran.model.dataset_model import has_dataset_model_global_model
from glotaran.model.dataset_model import iterate_dataset_model_global_megacomplexes
from glotaran.model.dataset_model import iterate_dataset_model_megacomplexes
from glotaran.model.interval_item import IntervalItem
from glotaran.model.item import fill_item
from glotaran.optimization.data_provider import DataProvider
from glotaran.optimization.data_provider import DataProviderLinked
if TYPE_CHECKING:
from glotaran.typing.types import ArrayLike
[docs]
@dataclass
class MatrixContainer:
    """Pair a matrix with the clp labels that belong to it."""

    clp_labels: list[str]
    """The clp labels."""
    matrix: np.ndarray
    """The matrix."""

    @property
    def is_index_dependent(self) -> bool:
        """Tell whether the stored matrix is index dependent.

        A matrix counts as index dependent when it has exactly three dimensions.

        Returns
        -------
        bool
            ``True`` for a three dimensional matrix, ``False`` otherwise.
        """
        return self.matrix.ndim == 3

    @staticmethod
    def apply_weight(matrix: ArrayLike, weight: ArrayLike) -> ArrayLike:
        """Multiply a matrix with a weight.

        The weight is multiplied onto the transposed matrix and the product is
        transposed back, so broadcasting happens against the reversed axes.

        Parameters
        ----------
        matrix : ArrayLike
            The matrix to weight.
        weight : ArrayLike
            The weight to apply.

        Returns
        -------
        ArrayLike
            The weighted matrix, in the orientation of the input matrix.
        """
        transposed = matrix.T
        weighted = transposed * weight
        return weighted.T

    def create_weighted_matrix(self, weight: ArrayLike) -> MatrixContainer:
        """Build a new container holding the weighted matrix.

        The current container is left untouched.

        Parameters
        ----------
        weight : ArrayLike
            The weight to apply.

        Returns
        -------
        MatrixContainer
            A copy of this container with the weighted matrix.
        """
        weighted_matrix = self.apply_weight(self.matrix, weight)
        return replace(self, matrix=weighted_matrix)

    def create_scaled_matrix(self, scale: float) -> MatrixContainer:
        """Build a new container holding the matrix multiplied by a scale.

        The current container is left untouched.

        Parameters
        ----------
        scale : float
            The factor to multiply the matrix with.

        Returns
        -------
        MatrixContainer
            A copy of this container with the scaled matrix.
        """
        scaled_matrix = self.matrix * scale
        return replace(self, matrix=scaled_matrix)
[docs]
class MatrixProvider:
    """A class to provide matrix calculations for optimization."""

    def __init__(self, dataset_group: DatasetGroup):
        """Initialize a matrix provider for a dataset group.

        Parameters
        ----------
        dataset_group : DatasetGroup
            The dataset group.
        """
        self._group = dataset_group
        self._matrix_containers: dict[str, MatrixContainer] = {}
        self._global_matrix_containers: dict[str, MatrixContainer] = {}
        # Declared only; this class never assigns it. Subclasses are expected
        # to set it before any method reading ``self._data_provider`` is called.
        self._data_provider: DataProvider

    @property
    def group(self) -> DatasetGroup:
        """Get the dataset group.

        Returns
        -------
        DatasetGroup
            The dataset group.
        """
        return self._group

    def get_matrix_container(self, dataset_label: str) -> MatrixContainer:
        """Get the calculated matrix container of a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the dataset.

        Returns
        -------
        MatrixContainer
            The matrix container.

        Raises
        ------
        KeyError
            If no matrix has been calculated for ``dataset_label``.
        """
        return self._matrix_containers[dataset_label]

    def calculate_dataset_matrices(self):
        """Calculate the matrices of the datasets in the dataset group."""
        for label, dataset_model in self.group.dataset_models.items():
            model_axis = self._data_provider.get_model_axis(label)
            global_axis = self._data_provider.get_global_axis(label)
            self._matrix_containers[label] = self.calculate_dataset_matrix(
                dataset_model, global_axis, model_axis
            )

    @staticmethod
    def calculate_dataset_matrix(
        dataset_model: DatasetModel,
        global_axis: ArrayLike,
        model_axis: ArrayLike,
        global_matrix: bool = False,
    ) -> MatrixContainer:
        """Calculate the combined matrix of all megacomplexes of a dataset.

        Parameters
        ----------
        dataset_model : DatasetModel
            The dataset model.
        global_axis: ArrayLike
            The global axis.
        model_axis: ArrayLike
            The model axis.
        global_matrix: bool
            Calculate the global megacomplexes if `True`. In that case the
            roles of the global and the model axis are exchanged.

        Returns
        -------
        MatrixContainer
            The resulting matrix container.
        """
        clp_labels: list[str] = []
        matrix = None

        megacomplex_iterator = iterate_dataset_model_megacomplexes(dataset_model)
        if global_matrix:
            megacomplex_iterator = iterate_dataset_model_global_megacomplexes(dataset_model)
            # The global megacomplexes are evaluated along the global axis.
            model_axis, global_axis = global_axis, model_axis

        for scale, megacomplex in megacomplex_iterator:
            this_clp_labels, this_matrix = megacomplex.calculate_matrix(  # type:ignore[union-attr]
                dataset_model, global_axis, model_axis
            )
            if scale is not None:
                # In-place: the array returned by the megacomplex is modified.
                this_matrix *= scale
            if matrix is None:
                clp_labels = this_clp_labels
                matrix = this_matrix
            else:
                clp_labels, matrix = MatrixProvider.combine_megacomplex_matrices(
                    matrix, this_matrix, clp_labels, this_clp_labels
                )
        return MatrixContainer(clp_labels, matrix)  # type:ignore[arg-type]

    @staticmethod
    def combine_megacomplex_matrices(
        matrix_left: ArrayLike,
        matrix_right: ArrayLike,
        clp_labels_left: list[str],
        clp_labels_right: list[str],
    ) -> tuple[list[str], ArrayLike]:
        """Combine the matrices of two megacomplexes into one matrix.

        Columns which carry the same clp label in both matrices are summed up.
        If exactly one of the matrices is index dependent (3 dimensional), the
        other one is broadcast along the leading axis.

        Parameters
        ----------
        matrix_left: ArrayLike
            The left matrix.
        matrix_right: ArrayLike
            The right matrix.
        clp_labels_left: list[str]
            The left clp labels.
        clp_labels_right: list[str]
            The right clp labels.

        Returns
        -------
        tuple[list[str], ArrayLike]:
            The combined clp labels and matrix. The labels of the left matrix
            come first, followed by the labels only present in the right one.
        """
        # Fix the label order before any swapping so it does not depend on
        # which of the matrices is index dependent.
        result_clp_labels = clp_labels_left + [
            c for c in clp_labels_right if c not in clp_labels_left
        ]

        # Make sure the matrix with more dimensions is on the left. The labels
        # must be swapped along with the matrices, otherwise the columns are
        # looked up with the labels of the other matrix.
        if matrix_left.ndim < matrix_right.ndim:
            matrix_left, matrix_right = matrix_right, matrix_left
            clp_labels_left, clp_labels_right = clp_labels_right, clp_labels_left

        result_shape = (*matrix_left.shape[:-1], len(result_clp_labels))
        result_matrix = np.zeros(result_shape, dtype=np.float64)

        # ``...`` addresses the clp axis (the last one) for 2 and 3 dimensional
        # matrices alike; a 2 dimensional right column is broadcast along the
        # leading axis of a 3 dimensional result.
        for idx, label in enumerate(result_clp_labels):
            if label in clp_labels_left:
                result_matrix[..., idx] += matrix_left[..., clp_labels_left.index(label)]
            if label in clp_labels_right:
                result_matrix[..., idx] += matrix_right[..., clp_labels_right.index(label)]
        return result_clp_labels, result_matrix

    @staticmethod
    def does_interval_item_apply(prop: IntervalItem, index: int | None) -> bool:
        """Check if an interval item applies on an index.

        Parameters
        ----------
        prop : IntervalItem
            The interval property.
        index: int | None
            The index to check.

        Returns
        -------
        bool
            Whether the property applies. An item with an interval is treated
            as applying (with a warning) when ``index`` is ``None``.
        """
        if prop.has_interval() and index is None:
            warnings.warn(
                f"Interval property '{prop}' applies on a matrix which is "
                f"not index dependent. This will be an error in 0.9.0. Set "
                "'index_dependent: true' on the dataset model to fix the issue."
            )
            return True
        return prop.applies(index)

    def reduce_matrix(
        self,
        matrix: MatrixContainer,
        global_axis: ArrayLike,
    ) -> list[MatrixContainer]:
        """Reduce a matrix.

        Applies relations first and constraints afterwards.

        Parameters
        ----------
        matrix : MatrixContainer
            The matrix.
        global_axis: ArrayLike
            The global axis.

        Returns
        -------
        list[MatrixContainer]
            One reduced matrix container per entry of the global axis.
        """
        result = (
            [
                MatrixContainer(matrix.clp_labels, matrix.matrix[i, :, :])
                for i in range(global_axis.size)
            ]
            if matrix.is_index_dependent
            # The same container is referenced for every index; the reduction
            # steps replace list entries instead of mutating the container.
            else [matrix] * global_axis.size
        )
        result = self.apply_relations(result, global_axis)
        result = self.apply_constraints(result, global_axis)
        return result

    def apply_constraints(
        self,
        matrices: list[MatrixContainer],
        global_axis: ArrayLike,
    ) -> list[MatrixContainer]:
        """Apply constraints on a list of matrices.

        Columns whose clp label is the target of an applying constraint are
        removed. The passed list is modified in place and returned.

        Parameters
        ----------
        matrices: list[MatrixContainer],
            The matrices, one per entry of the global axis.
        global_axis: ArrayLike
            The global axis.

        Returns
        -------
        list[MatrixContainer]
            The list of (possibly reduced) matrix containers.
        """
        model = self.group.model
        if len(model.clp_constraints) == 0:
            return matrices

        for i, index in enumerate(global_axis):
            matrix = matrices[i]
            clp_labels = matrix.clp_labels
            removed_clp_labels = [
                c.target
                for c in model.clp_constraints
                if c.target in clp_labels and self.does_interval_item_apply(c, index)
            ]
            if len(removed_clp_labels) == 0:
                continue
            reduced_clp_labels = [c for c in clp_labels if c not in removed_clp_labels]
            # Boolean column mask selecting the clps which are kept.
            mask = [label in reduced_clp_labels for label in clp_labels]
            reduced_matrix = matrix.matrix[:, mask]
            matrices[i] = MatrixContainer(reduced_clp_labels, reduced_matrix)
        return matrices

    def apply_relations(
        self,
        matrices: list[MatrixContainer],
        global_axis: ArrayLike,
    ) -> list[MatrixContainer]:
        """Apply relations on a list of matrices.

        For every applying relation the target column is multiplied by the
        relation parameter, added to the source column and then removed. The
        passed list is modified in place and returned.

        Parameters
        ----------
        matrices: list[MatrixContainer],
            The matrices, one per entry of the global axis.
        global_axis: ArrayLike
            The global axis.

        Returns
        -------
        list[MatrixContainer]
            The list of (possibly reduced) matrix containers.
        """
        model = self.group.model
        parameters = self.group.parameters

        if len(model.clp_relations) == 0:
            return matrices

        for i, index in enumerate(global_axis):
            matrix = matrices[i]
            clp_labels = matrix.clp_labels
            # Start from the identity and add one off-diagonal entry per relation.
            relation_matrix = np.diagflat([1.0 for _ in clp_labels])

            idx_to_delete = []
            for relation in model.clp_relations:
                if relation.target in clp_labels and self.does_interval_item_apply(
                    relation, index
                ):
                    if relation.source not in clp_labels:
                        continue
                    relation = fill_item(relation, model, parameters)  # type:ignore[arg-type]
                    source_idx = clp_labels.index(relation.source)
                    target_idx = clp_labels.index(relation.target)
                    relation_matrix[target_idx, source_idx] = relation.parameter
                    idx_to_delete.append(target_idx)

            if len(idx_to_delete) == 0:
                continue

            reduced_clp_labels = [
                label for i, label in enumerate(clp_labels) if i not in idx_to_delete
            ]
            # Dropping the target columns removes the related clps from the result.
            relation_matrix = np.delete(relation_matrix, idx_to_delete, axis=1)
            reduced_matrix = matrix.matrix @ relation_matrix
            matrices[i] = MatrixContainer(reduced_clp_labels, reduced_matrix)
        return matrices

    def get_result(self) -> tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]:
        """Get the results of the matrix calculations.

        Returns
        -------
        tuple[dict[str, xr.DataArray], dict[str, xr.DataArray]]
            A tuple ``(global_matrices, matrices)``: the global matrices come
            first, the dataset matrices second. Both are keyed by dataset label.

        .. # noqa: DAR202
        .. # noqa: DAR401
        """
        matrices = {}
        global_matrices = {}
        for label, matrix_container in self._matrix_containers.items():
            model_dimension = self._data_provider.get_model_dimension(label)
            model_axis = self._data_provider.get_model_axis(label)
            matrix_coords: tuple[tuple[str, Any], tuple[str, Any], tuple[str, list[str]]] | tuple[
                tuple[str, Any], tuple[str, list[str]]
            ] = (
                (model_dimension, model_axis),
                ("clp_label", matrix_container.clp_labels),
            )
            if matrix_container.is_index_dependent:
                # Index dependent matrices carry the global axis as leading dimension.
                global_dimension = self._data_provider.get_global_dimension(label)
                global_axis = self._data_provider.get_global_axis(label)
                matrix_coords = (
                    (global_dimension, global_axis),
                    matrix_coords[0],
                    matrix_coords[1],
                )
            matrices[label] = xr.DataArray(matrix_container.matrix, coords=matrix_coords)

        for label, matrix_container in self._global_matrix_containers.items():
            global_dimension = self._data_provider.get_global_dimension(label)
            global_axis = self._data_provider.get_global_axis(label)
            global_matrices[label] = xr.DataArray(
                matrix_container.matrix,
                coords=(
                    (global_dimension, global_axis),
                    ("global_clp_label", matrix_container.clp_labels),
                ),
            )

        return global_matrices, matrices

    def calculate(self):
        """Calculate the matrices for optimization.

        Raises
        ------
        NotImplementedError
            Always; this method needs to be implemented by subclasses.

        .. # noqa: DAR401
        """
        raise NotImplementedError

    @property
    def number_of_clps(self) -> int:
        """Return number of conditionally linear parameters.

        Raises
        ------
        NotImplementedError
            This property needs to be implemented by subclasses.

        See Also
        --------
        MatrixProviderUnlinked
        MatrixProviderLinked
        """
        raise NotImplementedError
[docs]
class MatrixProviderUnlinked(MatrixProvider):
    """Matrix provider for the optimization of unlinked dataset groups."""

    def __init__(self, group: DatasetGroup, data_provider: DataProvider):
        """Set up the provider for an unlinked dataset group.

        Parameters
        ----------
        group : DatasetGroup
            The dataset group.
        data_provider : DataProvider
            The data provider.
        """
        super().__init__(group)
        self._data_provider = data_provider
        self._prepared_matrix_container: dict[str, list[MatrixContainer]] = {}
        self._full_matrices: dict[str, ArrayLike] = {}

    def get_global_matrix_container(self, dataset_label: str) -> MatrixContainer:
        """Look up the global matrix container of a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the dataset.

        Returns
        -------
        MatrixContainer
            The global matrix container stored for the dataset.
        """
        return self._global_matrix_containers[dataset_label]

    def get_prepared_matrix_container(
        self, dataset_label: str, global_index: int
    ) -> MatrixContainer:
        """Look up the prepared matrix container of a dataset for one global index.

        Parameters
        ----------
        dataset_label : str
            The label of the dataset.
        global_index : int
            The position on the global axis.

        Returns
        -------
        MatrixContainer
            The prepared matrix container at that position.
        """
        prepared_for_dataset = self._prepared_matrix_container[dataset_label]
        return prepared_for_dataset[global_index]

    def get_full_matrix(self, dataset_label: str) -> ArrayLike:
        """Look up the full matrix of a dataset.

        Parameters
        ----------
        dataset_label : str
            The label of the dataset.

        Returns
        -------
        ArrayLike
            The full matrix stored for the dataset.
        """
        return self._full_matrices[dataset_label]

    def calculate(self):
        """Run all matrix calculations needed for the optimization."""
        self.calculate_dataset_matrices()
        self.calculate_global_matrices()
        self.calculate_prepared_matrices()
        self.calculate_full_matrices()

    def calculate_global_matrices(self):
        """Calculate the global matrices for all datasets having a global model."""
        for label, dataset_model in self.group.dataset_models.items():
            if not has_dataset_model_global_model(dataset_model):
                continue
            model_axis = self._data_provider.get_model_axis(label)
            global_axis = self._data_provider.get_global_axis(label)
            self._global_matrix_containers[label] = self.calculate_dataset_matrix(
                dataset_model, global_axis, model_axis, global_matrix=True
            )

    def calculate_prepared_matrices(self):
        """Calculate the prepared matrices for all datasets without a global model.

        A prepared matrix is the scaled dataset matrix after reduction, with
        the weight of the dataset applied if there is one.
        """
        for label, dataset_model in self.group.dataset_models.items():
            if has_dataset_model_global_model(dataset_model):
                continue

            # A missing (or otherwise falsy) scale counts as 1.
            scale = float(dataset_model.scale or 1)
            weight = self._data_provider.get_weight(label)

            scaled_container = self.get_matrix_container(label).create_scaled_matrix(scale)
            prepared = self.reduce_matrix(
                scaled_container, self._data_provider.get_global_axis(label)
            )
            self._prepared_matrix_container[label] = prepared

            if weight is None:
                continue
            # Column ``position`` of the weight belongs to the matrix at ``position``.
            self._prepared_matrix_container[label] = [
                container.create_weighted_matrix(weight[:, position])
                for position, container in enumerate(prepared)
            ]

    def calculate_full_matrices(self):
        """Calculate the full matrices for all datasets having a global model.

        The full matrix is the Kronecker product of the global matrix and the
        dataset matrix, with the flattened weight applied if there is one.
        """
        for label, dataset_model in self.group.dataset_models.items():
            if not has_dataset_model_global_model(dataset_model):
                continue

            global_matrix = self.get_global_matrix_container(label).matrix
            matrix_container = self.get_matrix_container(label)
            matrix = matrix_container.matrix

            if matrix_container.is_index_dependent:
                # One Kronecker product per leading index, stacked afterwards.
                blocks = [
                    np.kron(global_matrix[i, :], matrix[i, :, :])
                    for i in range(matrix.shape[0])
                ]
                full_matrix = np.concatenate(blocks)
            else:
                full_matrix = np.kron(global_matrix, matrix)

            weight = self._data_provider.get_flattened_weight(label)
            if weight is not None:
                full_matrix = MatrixContainer.apply_weight(full_matrix, weight)

            self._full_matrices[label] = full_matrix

    @property
    def number_of_clps(self) -> int:
        """Return number of conditionally linear parameters.

        Datasets with a global model contribute the product of the number of
        model and global clp labels, all others the sum of the clp labels of
        their prepared matrices over the global axis.

        Returns
        -------
        int
        """
        total = 0
        for dataset_label, dataset_model in self.group.dataset_models.items():
            if has_dataset_model_global_model(dataset_model):
                model_clp_labels = self.get_matrix_container(dataset_label).clp_labels
                global_clp_labels = self.get_global_matrix_container(dataset_label).clp_labels
                total += len(model_clp_labels) * len(global_clp_labels)
                continue

            global_axis_size = len(self._data_provider.get_global_axis(dataset_label))
            total += sum(
                len(self.get_prepared_matrix_container(dataset_label, position).clp_labels)
                for position in range(global_axis_size)
            )
        return total
[docs]
class MatrixProviderLinked(MatrixProvider):
    """Matrix provider for the optimization of linked dataset groups."""

    def __init__(self, group: DatasetGroup, data_provider: DataProviderLinked):
        """Set up the provider for a linked dataset group.

        Parameters
        ----------
        group : DatasetGroup
            The dataset group.
        data_provider : DataProviderLinked
            The data provider.
        """
        super().__init__(group)
        self._data_provider: DataProviderLinked = data_provider
        aligned_size = self._data_provider.aligned_global_axis.size
        # Placeholders, filled per index by ``calculate_aligned_matrices``.
        self._aligned_full_clp_labels: list[list[str]] = [
            None  # type:ignore[list-item]
        ] * aligned_size
        self._aligned_matrices: list[MatrixContainer] = [
            None  # type:ignore[list-item]
        ] * aligned_size

    @property
    def aligned_full_clp_labels(self) -> list[list[str]]:
        """Get the aligned full clp labels.

        Returns
        -------
        list[list[str]]
            The unreduced clp labels for every index of the aligned global axis.
        """
        return self._aligned_full_clp_labels

    def get_aligned_matrix_container(self, global_index: int) -> MatrixContainer:
        """Look up the aligned matrix container for one aligned global index.

        Parameters
        ----------
        global_index : int
            The position on the aligned global axis.

        Returns
        -------
        MatrixContainer
            The aligned matrix container at that position.
        """
        return self._aligned_matrices[global_index]

    def calculate(self):
        """Run all matrix calculations needed for the optimization."""
        self.calculate_dataset_matrices()
        self.calculate_aligned_matrices()

    def calculate_aligned_matrices(self):
        """Calculate the aligned matrix for every index of the aligned global axis."""
        provider = self._data_provider
        labels_per_group = self.align_full_clp_labels()

        for i, global_index_value in enumerate(provider.aligned_global_axis):
            group_label = provider.get_aligned_group_label(i)
            dataset_labels = provider.group_definitions[group_label]

            # Collect the (sliced, if index dependent) matrix of every dataset.
            containers = []
            for label, index in zip(dataset_labels, provider.get_aligned_dataset_indices(i)):
                container = self._matrix_containers[label]
                if container.is_index_dependent:
                    container = MatrixContainer(
                        clp_labels=container.clp_labels,
                        matrix=container.matrix[index],
                    )
                containers.append(container)

            # A dataset without a scale is treated as scaled by 1.
            scales = []
            for label in dataset_labels:
                dataset_scale = self.group.dataset_models[label].scale
                scales.append(1 if dataset_scale is None else dataset_scale)

            group_matrix = self.align_matrices(
                containers, scales  # type:ignore[arg-type]
            )

            self._aligned_full_clp_labels[i] = labels_per_group[group_label]

            # Reduce on a one element axis and take the single result.
            reduced = self.reduce_matrix(group_matrix, np.array([global_index_value]))[0]

            weight = provider.get_aligned_weight(i)
            if weight is not None:
                reduced = reduced.create_weighted_matrix(weight)

            self._aligned_matrices[i] = reduced

    def align_full_clp_labels(self) -> dict[str, list[str]]:
        """Collect the unreduced clp labels of every group.

        Returns
        -------
        dict[str, list[str]]
            For every group label the clp labels of its datasets, without
            duplicates and in order of first appearance.
        """
        labels_per_group: dict[str, list[str]] = {}
        for group_label, dataset_labels in self._data_provider.group_definitions.items():
            collected: list[str] = []
            labels_per_group[group_label] = collected
            for dataset_label in dataset_labels:
                new_labels = [
                    label
                    for label in self.get_matrix_container(dataset_label).clp_labels
                    if label not in collected
                ]
                collected += new_labels
        return labels_per_group

    @staticmethod
    def align_matrices(matrices: list[MatrixContainer], scales: list[float]) -> MatrixContainer:
        """Stack matrices along the first axis and merge their clp labels.

        Every matrix is multiplied by its scale and written into its own block
        of rows; columns are assigned by clp label.

        Parameters
        ----------
        matrices : list[MatrixContainer]
            The matrices to align.
        scales : list[float]
            The scales of the matrices.

        Returns
        -------
        MatrixContainer
            The aligned matrix container.
        """
        # NOTE(review): a single matrix is returned as is, i.e. without its
        # scale being applied - confirm that this is intended.
        if len(matrices) == 1:
            return matrices[0]

        merged_labels: list[str] = []
        column_indices = []
        row_counts = []
        for container in matrices:
            labels = container.clp_labels
            row_counts.append(container.matrix.shape[0])
            if len(merged_labels) == 0:
                merged_labels = list(labels)
                column_indices.append(list(range(len(labels))))
                continue
            columns = []
            for label in labels:
                if label not in merged_labels:
                    merged_labels.append(label)
                columns.append(merged_labels.index(label))
            column_indices.append(columns)

        aligned = np.zeros((sum(row_counts), len(merged_labels)), dtype=np.float64)
        row_start = 0
        for position, container in enumerate(matrices):
            row_end = row_start + row_counts[position]
            aligned[row_start:row_end, column_indices[position]] = (
                container.matrix * scales[position]
            )
            row_start = row_end

        return MatrixContainer(merged_labels, aligned)

    @property
    def number_of_clps(self) -> int:
        """Return number of conditionally linear parameters.

        This is the sum of the clp labels of all aligned matrices.

        Returns
        -------
        int
        """
        aligned_size = len(self._data_provider.aligned_global_axis)
        total = 0
        for position in range(aligned_size):
            total += len(self.get_aligned_matrix_container(position).clp_labels)
        return total