"""The model decorator."""
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
import numpy as np
import xarray as xr
import glotaran # TODO: refactor to postponed type annotation
from glotaran.parameter import ParameterGroup
from glotaran.parse.register import register_model
from .base_model import Model
from .dataset_descriptor import DatasetDescriptor
from .util import wrap_func_as_method
from .weight import Weight
MatrixFunction = Callable[[Type[DatasetDescriptor], xr.Dataset], Tuple[List[str], np.ndarray]]
"""A `MatrixFunction` calculates the matrix for a model."""

IndexDependentMatrixFunction = Callable[
    [Type[DatasetDescriptor], xr.Dataset, Any],
    Tuple[List[str], np.ndarray],
]
"""An `IndexDependentMatrixFunction` calculates the matrix for an index dependent model.

It takes the same arguments as a `MatrixFunction` plus one additional argument (presumably the
index the matrix is calculated for).
"""

GlobalMatrixFunction = Callable[
    [Type[DatasetDescriptor], np.ndarray], Tuple[List[str], np.ndarray]
]
"""A `GlobalMatrixFunction` calculates the global matrix for a model."""

ConstrainMatrixFunction = Callable[
    [Type[Model], ParameterGroup, List[str], np.ndarray, float],
    Tuple[List[str], np.ndarray],
]
"""A `ConstrainMatrixFunction` applies constraints on a matrix."""

RetrieveClpFunction = Callable[
    [
        Type[Model],
        ParameterGroup,
        Dict[str, Union[List[str], List[List[str]]]],
        Dict[str, Union[List[str], List[List[str]]]],
        Dict[str, List[np.ndarray]],
        Dict[str, xr.Dataset],
    ],
    Dict[str, List[np.ndarray]],
]
"""A `RetrieveClpFunction` retrieves the full set of clp from a reduced set."""

# NOTE(review): the TypeVar is only a named placeholder standing in for
# ``glotaran.analysis.problem.Problem`` (see the TODO on ``import glotaran``);
# it does not refer to the actual class.
FinalizeFunction = Callable[
    [TypeVar("glotaran.analysis.problem.Problem"), Dict[str, xr.Dataset]], None
]
"""A `FinalizeFunction` gets called after optimization."""

PenaltyFunction = Callable[
    [
        Type[Model],
        ParameterGroup,
        Dict[str, Union[List[str], List[List[str]]]],
        Dict[str, List[np.ndarray]],
        Dict[str, Union[np.ndarray, List[np.ndarray]]],
        Dict[str, xr.Dataset],
        float,
    ],
    np.ndarray,
]
"""A `PenaltyFunction` calculates additional penalties for the optimization."""
def model(
    model_type: str,
    attributes: Dict[str, Any] = None,
    dataset_type: Type[DatasetDescriptor] = DatasetDescriptor,
    megacomplex_type: Any = None,
    matrix: Union[MatrixFunction, IndexDependentMatrixFunction] = None,
    global_matrix: GlobalMatrixFunction = None,
    model_dimension: str = None,
    global_dimension: str = None,
    has_matrix_constraints_function: Callable[[Type[Model]], bool] = None,
    constrain_matrix_function: ConstrainMatrixFunction = None,
    retrieve_clp_function: RetrieveClpFunction = None,
    has_additional_penalty_function: Callable[[Type[Model]], bool] = None,
    additional_penalty_function: PenaltyFunction = None,
    finalize_data_function: FinalizeFunction = None,
    grouped: Union[bool, Callable[[Type[Model]], bool]] = False,
    index_dependent: Union[bool, Callable[[Type[Model]], bool]] = False,
) -> Callable:
    """The `@model` decorator is intended to be used on subclasses of :class:`glotaran.model.Model`.

    It attaches the given functions (e.g. ``matrix``) to the class, wrapped as methods, and stores
    the model and global dimensions on it. The decorator handles every attribute, i.e. the given
    ``attributes`` plus the standard ``dataset``, ``megacomplex`` and ``weights`` attributes.
    For each of these it creates a property and accessor methods:

    * ``get_<name>``/``set_<name>`` for labeled attributes,
    * ``add_<name>`` for unlabeled ones.

    Finally it installs an ``__init__`` that creates empty storage for every attribute and
    registers the class under ``model_type``.

    Parameters
    ----------
    model_type : str
        Human readable string used by the parser to identify the correct model.
    attributes : Dict[str, Any], optional
        A dictionary of attribute names and types. All types must be decorated with the
        :func:`glotaran.model.model_attribute` decorator, by default None (no extra attributes).
        The given dictionary is not modified.
    dataset_type : Type[DatasetDescriptor], optional
        A subclass of :class:`DatasetDescriptor`, by default DatasetDescriptor.
    megacomplex_type : Any, optional
        A class for the model megacomplexes. The class must be decorated with the
        :func:`glotaran.model.model_attribute` decorator. It defaults to None but must be given
        in practice: None has no ``_glotaran_has_label`` and decorating then fails with an
        ``AttributeError``.
    matrix : Union[MatrixFunction, IndexDependentMatrixFunction], optional
        A function to calculate the matrix for the model, attached as static method ``matrix``.
    global_matrix : GlobalMatrixFunction, optional
        A function to calculate the global matrix for the model, attached as static method
        ``global_matrix`` (set to None if not given).
    model_dimension : str
        The name of the model matrix row dimension. Required.
    global_dimension : str
        The name of the global matrix row dimension. Required.
    has_matrix_constraints_function : Callable[[Type[Model]], bool], optional
        A function of the model returning a bool (presumably whether matrix constraints apply).
        If given, ``constrain_matrix_function`` and ``retrieve_clp_function`` are required.
        By default None.
    constrain_matrix_function : ConstrainMatrixFunction, optional
        A function to constrain the matrix for the model, by default None.
    retrieve_clp_function : RetrieveClpFunction, optional
        A function to retrieve the full set of clp from the reduced set, by default None.
    has_additional_penalty_function : Callable[[Type[Model]], bool], optional
        A function of the model returning a bool (presumably whether additional penalties apply).
        If given, ``additional_penalty_function`` is required. By default None.
    additional_penalty_function : PenaltyFunction, optional
        A function to calculate additional penalties when optimizing the model, by default None.
    finalize_data_function : FinalizeFunction, optional
        A function to finalize data after optimization, by default None.
    grouped : Union[bool, Callable[[Type[Model]], bool]], optional
        Whether the model describes a grouped problem. Either a constant or a function of the
        model. By default False.
    index_dependent : Union[bool, Callable[[Type[Model]], bool]], optional
        Whether the model describes an index dependent problem. Either a constant or a function
        of the model. By default False.

    Returns
    -------
    Callable
        The class decorator. Applying it configures and registers the class and returns it.

    Raises
    ------
    ValueError
        If ``has_matrix_constraints_function`` is given without ``constrain_matrix_function``
        or without ``retrieve_clp_function``.
    ValueError
        If ``has_additional_penalty_function`` is given without ``additional_penalty_function``.
    ValueError
        If ``model_dimension`` or ``global_dimension`` is not specified.
    """

    def decorator(cls):
        # Validate everything before touching ``cls`` so a failing decorator does not
        # half-configure the class. The checks keep their original precedence.
        if has_matrix_constraints_function:
            if not constrain_matrix_function:
                raise ValueError(
                    "Model implements `has_matrix_constraints_function` "
                    "but not `constrain_matrix_function`"
                )
            if not retrieve_clp_function:
                raise ValueError(
                    "Model implements `has_matrix_constraints_function` "
                    "but not `retrieve_clp_function`"
                )
        if has_additional_penalty_function and not additional_penalty_function:
            raise ValueError(
                "Model implements `has_additional_penalty_function` "
                "but not `additional_penalty_function`"
            )
        if model_dimension is None:
            raise ValueError(f"Model dimension not specified for model {model_type}")
        if global_dimension is None:
            raise ValueError(f"Global dimension not specified for model {model_type}")

        setattr(cls, "_model_type", model_type)
        setattr(cls, "finalize_data", finalize_data_function)

        if has_matrix_constraints_function:
            has_c_mat = wrap_func_as_method(cls, name="has_matrix_constraints_function")(
                has_matrix_constraints_function
            )
            c_mat = wrap_func_as_method(cls, name="constrain_matrix_function")(
                constrain_matrix_function
            )
            r_clp = wrap_func_as_method(cls, name="retrieve_clp_function")(retrieve_clp_function)
            setattr(cls, "has_matrix_constraints_function", has_c_mat)
            setattr(cls, "constrain_matrix_function", c_mat)
            setattr(cls, "retrieve_clp_function", r_clp)
        else:
            setattr(cls, "has_matrix_constraints_function", None)
            setattr(cls, "constrain_matrix_function", None)
            setattr(cls, "retrieve_clp_function", None)

        if has_additional_penalty_function:
            has_pen = wrap_func_as_method(cls, name="has_additional_penalty_function")(
                has_additional_penalty_function
            )
            pen = wrap_func_as_method(cls, name="additional_penalty_function")(
                additional_penalty_function
            )
            setattr(cls, "additional_penalty_function", pen)
            setattr(cls, "has_additional_penalty_function", has_pen)
        else:
            setattr(cls, "has_additional_penalty_function", None)
            setattr(cls, "additional_penalty_function", None)

        # Plain booleans are turned into constant methods, so ``grouped`` and
        # ``index_dependent`` are always callable on the class.
        if not callable(grouped):

            def group_fun(model_instance):
                return grouped

        else:
            group_fun = grouped
        setattr(cls, "grouped", group_fun)

        if not callable(index_dependent):

            def index_dep_fun(model_instance):
                return index_dependent

        else:
            index_dep_fun = index_dependent
        setattr(cls, "index_dependent", index_dep_fun)

        # Matrix functions are static: they do not receive the model instance.
        mat = wrap_func_as_method(cls, name="matrix")(matrix)
        setattr(cls, "matrix", staticmethod(mat))
        setattr(cls, "model_dimension", model_dimension)

        if global_matrix:
            g_mat = wrap_func_as_method(cls, name="global_matrix")(global_matrix)
            setattr(cls, "global_matrix", staticmethod(g_mat))
        else:
            setattr(cls, "global_matrix", None)
        setattr(cls, "global_dimension", global_dimension)

        # ``getattr`` may find a registry inherited from a base class. Copy it so that
        # registering attributes on ``cls`` does not alter the base class's registry.
        setattr(
            cls,
            "_glotaran_model_attributes",
            getattr(cls, "_glotaran_model_attributes", {}).copy(),
        )

        # Work on a copy: ``attributes`` may be None, and the caller's dict must not be
        # mutated (it is also captured by the generated ``__init__``).
        model_attributes = dict(attributes) if attributes is not None else {}
        # The standard attributes every model has.
        model_attributes["dataset"] = dataset_type
        model_attributes["megacomplex"] = megacomplex_type
        model_attributes["weights"] = Weight

        for attr_name, attr_type in model_attributes.items():
            # Store the name for internal lookups.
            getattr(cls, "_glotaran_model_attributes")[attr_name] = None

            setattr(cls, attr_name, _create_property_for_attribute(cls, attr_name, attr_type))

            # Labeled attributes are stored in dicts and get get/set accessors;
            # unlabeled ones are stored in lists and get an add method.
            # Raises AttributeError for types not decorated with ``model_attribute``.
            if getattr(attr_type, "_glotaran_has_label"):
                get_item = _create_get_func(cls, attr_name, attr_type)
                setattr(cls, get_item.__name__, get_item)
                set_item = _create_set_func(cls, attr_name, attr_type)
                setattr(cls, set_item.__name__, set_item)
            else:
                add_item = _create_add_func(cls, attr_name, attr_type)
                setattr(cls, add_item.__name__, add_item)

        setattr(cls, "__init__", _create_init_func(cls, model_attributes))

        register_model(model_type, cls)
        return cls

    return decorator
def _create_init_func(cls, attributes):
    """Build the ``__init__`` method installed on a decorated model class.

    The generated initializer creates one empty storage container per model attribute, named
    ``_<attribute name>``: a dict for labeled attribute types and a list otherwise. It then
    delegates to the parent class's ``__init__``.

    Parameters
    ----------
    cls :
        The model class the initializer is created for.
    attributes : Dict[str, Any]
        Mapping of attribute names to attribute types.

    Returns
    -------
    Callable
        The wrapped ``__init__`` function.
    """

    @wrap_func_as_method(cls)
    def __init__(self):
        for label, item_type in attributes.items():
            is_labeled = getattr(item_type, "_glotaran_has_label")
            storage = {} if is_labeled else []
            setattr(self, f"_{label}", storage)
        # Explicit two-argument super: this function is not defined inside a class body.
        super(cls, self).__init__()

    return __init__
def _create_add_func(cls, name, type):
    """Create the ``add_<name>`` method for an unlabeled (list-stored) model attribute.

    Parameters
    ----------
    cls :
        The model class the method is created for.
    name : str
        The attribute name; items are appended to ``self._<name>``.
    type :
        The accepted item type.

    Returns
    -------
    Callable
        The wrapped ``add_<name>`` function.
    """

    @wrap_func_as_method(cls, name=f"add_{name}")
    def add_item(self, item: type):
        # Accept instances of ``type`` itself or, for typed model attributes, of one of
        # the types in ``type._glotaran_model_attribute_types``.
        # NOTE(review): presence is tested via ``_glotaran_model_attribute_typed`` but the
        # types are read from ``_glotaran_model_attribute_types``. Presumably both are set
        # together by the typed-attribute decorator — confirm.
        if not isinstance(item, type) and (
            not hasattr(type, "_glotaran_model_attribute_typed")
            or not isinstance(item, tuple(type._glotaran_model_attribute_types.values()))
        ):
            raise TypeError(
                f"Cannot add item of type '{item.__class__.__name__}' to '{name}', "
                f"expected '{type.__name__}'."
            )
        getattr(self, f"_{name}").append(item)

    # An f-string as the first statement of a function is not a docstring,
    # so the documentation is attached explicitly.
    add_item.__doc__ = (
        f"Adds a `{type.__name__}` object.\n\n"
        "Parameters\n"
        "----------\n"
        "item :\n"
        f"    The `{type.__name__}` item.\n\n"
        "Raises\n"
        "------\n"
        "TypeError\n"
        "    If the item has an unsupported type.\n"
    )
    return add_item
def _create_get_func(cls, name, type):
    """Create the ``get_<name>`` method for a labeled (dict-stored) model attribute.

    Parameters
    ----------
    cls :
        The model class the method is created for.
    name : str
        The attribute name; items are looked up in ``self._<name>``.
    type :
        The item type, used for the return annotation and the documentation.

    Returns
    -------
    Callable
        The wrapped ``get_<name>`` function.
    """

    @wrap_func_as_method(cls, name=f"get_{name}")
    def get_item(self, label: str) -> type:
        return getattr(self, f"_{name}")[label]

    # An f-string as the first statement of a function is not a docstring,
    # so the documentation is attached explicitly.
    get_item.__doc__ = (
        f"Returns the `{type.__name__}` object with the given label.\n\n"
        "Parameters\n"
        "----------\n"
        "label :\n"
        f"    The label of the `{type.__name__}` object.\n"
    )
    return get_item
def _create_set_func(cls, name, type):
    """Create the ``set_<name>`` method for a labeled (dict-stored) model attribute.

    Parameters
    ----------
    cls :
        The model class the method is created for.
    name : str
        The attribute name; items are stored in ``self._<name>`` under their label.
    type :
        The accepted item type.

    Returns
    -------
    Callable
        The wrapped ``set_<name>`` function.
    """

    @wrap_func_as_method(cls, name=f"set_{name}")
    def set_item(self, label: str, item: type):
        # Accept instances of ``type`` itself or, for typed model attributes, of one of
        # the types in ``type._glotaran_model_attribute_types``.
        # NOTE(review): presence is tested via ``_glotaran_model_attribute_typed`` but the
        # types are read from ``_glotaran_model_attribute_types``. Presumably both are set
        # together by the typed-attribute decorator — confirm.
        if not isinstance(item, type) and (
            not hasattr(type, "_glotaran_model_attribute_typed")
            or not isinstance(item, tuple(type._glotaran_model_attribute_types.values()))
        ):
            raise TypeError(
                f"Cannot set item '{label}' of '{name}' to type "
                f"'{item.__class__.__name__}', expected '{type.__name__}'."
            )
        getattr(self, f"_{name}")[label] = item

    # An f-string as the first statement of a function is not a docstring,
    # so the documentation is attached explicitly.
    set_item.__doc__ = (
        f"Sets the `{type.__name__}` object with the given label to the item.\n\n"
        "Parameters\n"
        "----------\n"
        "label :\n"
        f"    The label of the `{type.__name__}` object.\n"
        "item :\n"
        f"    The `{type.__name__}` item.\n\n"
        "Raises\n"
        "------\n"
        "TypeError\n"
        "    If the item has an unsupported type.\n"
    )
    return set_item
def _create_property_for_attribute(cls, name, type):
    """Create the read-only property exposing the storage of model attribute ``name``.

    Parameters
    ----------
    cls :
        The model class the property is created for.
    name : str
        The attribute name; the property returns ``self._<name>``.
    type :
        The item type of the attribute.

    Returns
    -------
    property
        A property whose getter returns the attribute's storage container.
    """
    # Use the *value* of ``_glotaran_has_label``, not merely its presence. This matches
    # the checks that pick dict vs. list storage and get/set vs. add accessors. Otherwise
    # a type with ``_glotaran_has_label = False`` would be annotated as a dictionary
    # while actually stored as a list.
    has_label = getattr(type, "_glotaran_has_label", False)
    return_type = Dict[str, type] if has_label else List[type]
    doc_type = "dictionary" if has_label else "list"

    @wrap_func_as_method(cls, name=f"{name}")
    def attribute(self) -> return_type:
        return getattr(self, f"_{name}")

    # An f-string is never a docstring, so the documentation is attached explicitly
    # and handed to the property.
    doc = f"A {doc_type} containing {type.__name__}"
    attribute.__doc__ = doc
    return property(attribute, doc=doc)