Source code for glotaran.analysis.scheme

from __future__ import annotations

import functools
import pathlib
import warnings
from typing import Literal

import numpy as np
import xarray as xr
import yaml

import glotaran
from glotaran.model import Model
from glotaran.parameter import ParameterGroup


def _not_none(f):
    @functools.wraps(f)
    def wrapper(self, value):
        if value is None:
            raise ValueError(f"{f.__name__} cannot be None")
        f(self, value)

    return wrapper


class Scheme:
    def __init__(
        self,
        model: Model = None,
        parameters: ParameterGroup = None,
        data: dict[str, xr.DataArray | xr.Dataset] = None,
        group_tolerance: float = 0.0,
        non_negative_least_squares: bool = False,
        maximum_number_function_evaluations: int = None,
        ftol: float = 1e-8,
        gtol: float = 1e-8,
        xtol: float = 1e-8,
        optimization_method: Literal[
            "TrustRegionReflection",
            "Dogbox",
            "Levenberg-Marquardt",
        ] = "TrustRegionReflection",
    ):
        self._model = model
        self._parameters = parameters
        self._group_tolerance = group_tolerance
        self._non_negative_least_squares = non_negative_least_squares
        self._maximum_number_function_evaluations = maximum_number_function_evaluations
        self._ftol = ftol
        self._gtol = gtol
        self._xtol = xtol
        self._optimization_method = optimization_method
        self._prepare_data(data)
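
    # Example usage (illustrative sketch, not part of the module; ``model``,
    # ``parameters`` and ``dataset`` stand for objects created elsewhere, e.g.
    # via glotaran.read_model_from_yaml_file, ParameterGroup.from_file and
    # glotaran.io.read_data_file):
    #
    #     scheme = Scheme(
    #         model=model,
    #         parameters=parameters,
    #         data={"dataset1": dataset},  # xr.DataArray or xr.Dataset per label
    #         maximum_number_function_evaluations=10,
    #     )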

    @classmethod
    def from_yaml_file(cls, filename: str) -> Scheme:
        # read and parse in separate try blocks, so a parse error is not
        # misreported as an error opening the file
        try:
            with open(filename) as f:
                scheme_file_content = f.read()
        except Exception as e:
            raise OSError(f"Error opening scheme: {e}")

        try:
            scheme = yaml.safe_load(scheme_file_content)
        except Exception as e:
            raise ValueError(f"Error parsing scheme: {e}")

        if "model" not in scheme:
            raise ValueError("Model file not specified.")

        try:
            model = glotaran.read_model_from_yaml_file(scheme["model"])
        except Exception as e:
            raise ValueError(f"Error loading model: {e}")

        if "parameters" not in scheme:
            raise ValueError("Parameters file not specified.")

        path = scheme["parameters"]
        fmt = scheme.get("parameter_format", None)
        try:
            parameters = glotaran.parameter.ParameterGroup.from_file(path, fmt)
        except Exception as e:
            raise ValueError(f"Error loading parameters: {e}")

        if "data" not in scheme:
            raise ValueError("No data specified.")

        data = {}
        for label, path in scheme["data"].items():
            path = pathlib.Path(path)
            fmt = path.suffix[1:] if path.suffix != "" else "nc"
            if "data_format" in scheme:
                fmt = scheme["data_format"]
            try:
                data[label] = glotaran.io.read_data_file(path, fmt=fmt)
            except Exception as e:
                raise ValueError(f"Error loading dataset '{label}': {e}")

        optimization_method = scheme.get("optimization_method", "TrustRegionReflection")
        nnls = scheme.get("non-negative-least-squares", False)
        nfev = scheme.get("maximum-number-function-evaluations", None)
        ftol = scheme.get("ftol", 1e-8)
        gtol = scheme.get("gtol", 1e-8)
        xtol = scheme.get("xtol", 1e-8)
        group_tolerance = scheme.get("group_tolerance", 0.0)

        return cls(
            model=model,
            parameters=parameters,
            data=data,
            non_negative_least_squares=nnls,
            maximum_number_function_evaluations=nfev,
            ftol=ftol,
            gtol=gtol,
            xtol=xtol,
            group_tolerance=group_tolerance,
            optimization_method=optimization_method,
        )
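
    # Example scheme file (illustrative; file names and values are hypothetical).
    # Only "model", "parameters" and "data" are required; the remaining keys fall
    # back to the defaults read above:
    #
    #     model: model.yml
    #     parameters: parameters.yml
    #     data:
    #         dataset1: dataset1.nc
    #     non-negative-least-squares: true
    #     maximum-number-function-evaluations: 100
    #     group_tolerance: 0.1
    #
    # Loading then becomes:
    #
    #     scheme = Scheme.from_yaml_file("scheme.yml")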

    @property
    def model(self) -> Model:
        return self._model

    @property
    def parameters(self) -> ParameterGroup:
        return self._parameters

    @property
    def data(self) -> dict[str, xr.DataArray | xr.Dataset]:
        return self._data

    @property
    def non_negative_least_squares(self) -> bool:
        return self._non_negative_least_squares

    @property
    def maximum_number_function_evaluations(self) -> int:
        return self._maximum_number_function_evaluations

    @property
    def group_tolerance(self) -> float:
        return self._group_tolerance

    @property
    def ftol(self) -> float:
        return self._ftol

    @property
    def gtol(self) -> float:
        return self._gtol

    @property
    def xtol(self) -> float:
        return self._xtol

    @property
    def optimization_method(self) -> str:
        return self._optimization_method

    def problem_list(self) -> list[str]:
        """Returns a list with all problems in the model and missing parameters."""
        return self.model.problem_list(self.parameters)

    def validate(self) -> str:
        """Returns a string listing all problems in the model and missing parameters."""
        return self.model.validate(self.parameters)

    def valid(self, parameters: ParameterGroup = None) -> bool:
        """Returns `True` if there are no problems with the model or the parameters,
        else `False`."""
        return self.model.valid(parameters)

    def _transpose_dataset(self, dataset):
        new_dims = [self.model.model_dimension, self.model.global_dimension]
        new_dims += [
            dim
            for dim in dataset.dims
            if dim not in [self.model.model_dimension, self.model.global_dimension]
        ]
        return dataset.transpose(*new_dims)

    def _prepare_data(self, data: dict[str, xr.DataArray | xr.Dataset]):
        self._data = {}
        for label, dataset in data.items():
            if self.model.model_dimension not in dataset.dims:
                raise ValueError(
                    "Missing coordinates for dimension "
                    f"'{self.model.model_dimension}' in data for dataset "
                    f"'{label}'"
                )
            if self.model.global_dimension not in dataset.dims:
                raise ValueError(
                    "Missing coordinates for dimension "
                    f"'{self.model.global_dimension}' in data for dataset "
                    f"'{label}'"
                )
            if isinstance(dataset, xr.DataArray):
                dataset = dataset.to_dataset(name="data")
            dataset = self._transpose_dataset(dataset)
            self._add_weight(label, dataset)
            # TODO: avoid computation if not requested
            l, s, r = np.linalg.svd(dataset.data, full_matrices=False)
            dataset["data_left_singular_vectors"] = (("time", "left_singular_value_index"), l)
            dataset["data_singular_values"] = (("singular_value_index"), s)
            dataset["data_right_singular_vectors"] = (
                ("right_singular_value_index", "spectral"),
                r,
            )
            self._data[label] = dataset
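
    # Example (illustrative, assuming a model with model_dimension "time" and
    # global_dimension "spectral"): data passed to the Scheme must carry both
    # dimensions; any remaining dimensions are transposed to the back:
    #
    #     import numpy as np
    #     import xarray as xr
    #
    #     data = xr.DataArray(
    #         np.random.rand(3, 4),
    #         coords=[
    #             ("time", [0.0, 1.0, 2.0]),
    #             ("spectral", [400.0, 500.0, 600.0, 700.0]),
    #         ],
    #     )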

    def markdown(self):
        s = self.model.markdown(parameters=self.parameters)
        s += "\n\n"
        s += "__Scheme__\n\n"
        s += f"* *nnls*: {self.non_negative_least_squares}\n"
        s += f"* *nfev*: {self.maximum_number_function_evaluations}\n"
        s += f"* *group_tolerance*: {self.group_tolerance}\n"
        return s

    def _add_weight(self, label, dataset):
        # if the user supplies a weight we ignore modeled weights
        if "weight" in dataset:
            if any(label in weight.datasets for weight in self.model.weights):
                warnings.warn(
                    f"Ignoring model weight for dataset '{label}'"
                    " because weight is already supplied by dataset."
                )
            return
        global_axis = dataset.coords[self.model.global_dimension]
        model_axis = dataset.coords[self.model.model_dimension]
        for weight in self.model.weights:
            if label in weight.datasets:
                if "weight" not in dataset:
                    dataset["weight"] = xr.DataArray(
                        np.ones_like(dataset.data), coords=dataset.data.coords
                    )
                idx = {}
                if weight.global_interval is not None:
                    idx[self.model.global_dimension] = _get_min_max_from_interval(
                        weight.global_interval, global_axis
                    )
                if weight.model_interval is not None:
                    idx[self.model.model_dimension] = _get_min_max_from_interval(
                        weight.model_interval, model_axis
                    )
                dataset.weight[idx] *= weight.value
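
    # Example (illustrative): a dataset that already carries a "weight" variable
    # wins over model weights; any model weight targeting that dataset is
    # skipped with a warning:
    #
    #     dataset["weight"] = xr.ones_like(dataset["data"]) * 0.5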

def _get_min_max_from_interval(interval, axis):
    minimum = np.abs(axis.values - interval[0]).argmin() if not np.isinf(interval[0]) else 0
    maximum = (
        np.abs(axis.values - interval[1]).argmin() + 1 if not np.isinf(interval[1]) else axis.size
    )
    return slice(minimum, maximum)
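
# Example (illustrative): each interval bound snaps to the nearest axis point,
# so this is not a strict containment test. With axis values [0, 10, 20, 30]:
#
#     axis = xr.DataArray([0.0, 10.0, 20.0, 30.0])
#     _get_min_max_from_interval((8, 22), axis)        # -> slice(1, 3)
#     _get_min_max_from_interval((-np.inf, 22), axis)  # -> slice(0, 3)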