Source code for cinnabar.estimators

"""
Estimators
==========

Estimators compute absolute free energy values from a set of relative
measurements stored in an FEMap.

"""

import abc
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import networkx as nx
import numpy as np
from openff.units import Quantity

from cinnabar import stats
from cinnabar.measurements import Measurement, ReferenceState

if TYPE_CHECKING:
    from cinnabar.femap import FEMap  # pragma: no cover


@dataclass
class EstimatorResult:
    """Base container for estimator output metadata and provenance.

    Every estimator result derives from this class. It is the place to keep
    extra data that does not belong on the returned measurements themselves
    but may be useful for downstream analysis (for example the covariance
    matrix of the MLE estimator). Subclasses should declare explicitly typed
    fields for such data so that IDEs can help.

    Both fields below are excluded from ``__init__`` and default to the empty
    string, which lets subclasses add required fields of their own.

    Attributes
    ----------
    estimator : str
        The class name of the estimator that produced this result, e.g.
        ``"MLEEstimator"``. Set automatically by ``Estimator.estimate``.
    source : str
        The composed source label stamped on the output measurements, e.g.
        ``"MLE"`` for a single-source map or ``"MLE(openff-sage)"`` when
        multiple input sources are present. Set automatically by
        ``Estimator.estimate``.
    """

    # Provenance: filled in after construction, never passed to __init__.
    estimator: str = field(init=False, default="")
    source: str = field(init=False, default="")
@dataclass
class MLEEstimatorResult(EstimatorResult):
    """Extra data returned alongside the measurements by `MLEEstimator`.

    Attributes
    ----------
    covariance_matrix : np.ndarray, shape (N, N)
        The full covariance matrix from the MLE. Element ``[i, j]`` holds the
        covariance between the free energy estimates of ligands ``i`` and
        ``j``.
    ligand_order : list
        Ligand labels in the order used for the rows and columns of
        ``covariance_matrix``; the position of a label in this list is its
        matrix index.
    """

    # Both fields are required constructor arguments.
    covariance_matrix: np.ndarray
    ligand_order: list
class Estimator(abc.ABC):
    """Abstract interface shared by all free-energy estimators.

    A concrete estimator has to provide ``_estimate`` and a ``source``
    attribute. ``source`` is the label written to the ``source`` field of the
    returned ``Measurement`` objects and is also the key under which the
    ``EstimatorResult`` is stored on the FEMap.
    """

    source: str

    @staticmethod
    def _check_weakly_connected(measurements: list[Measurement]) -> bool:
        """Return True when the relative computational edges form one connected graph.

        Only computational measurements whose ``labelA`` is not a
        ``ReferenceState`` contribute edges. When no such edge exists the
        graph is empty and the result is ``False``.
        """
        graph = nx.MultiGraph()
        graph.add_edges_from(
            (m.labelA, m.labelB)
            for m in measurements
            if m.computational and not isinstance(m.labelA, ReferenceState)
        )
        try:
            connected = nx.is_connected(graph)
        except nx.NetworkXPointlessConcept:
            # networkx refuses to answer for a graph without nodes
            connected = False
        return connected

    def estimate(self, femap: "FEMap") -> dict[str, tuple[list[Measurement], EstimatorResult]]:
        """Apply the estimator once per computational source found in the FEMap.

        Parameters
        ----------
        femap : FEMap
            The map to estimate from.

        Returns
        -------
        dict[str, tuple[list[Measurement], EstimatorResult]]
            Maps each *composed source label* to its ``(measurements, result)``
            pair. With more than one computational source in the FEMap the
            label is ``"{estimator.source}({input_source})"`` (e.g.
            ``"MLE(openfe)"``); with exactly one it is simply
            ``"{estimator.source}"``, so single-source users never need to
            know the input source name.

        Raises
        ------
        ValueError
            If the computational measurements of any source do not form a
            connected graph.

        Notes
        -----
        * Connectivity is verified for every source before ``_estimate`` runs.
        * All experimental measurements are passed along with every source so
          the estimator can use them to center predictions.
        * ``estimator`` and ``source`` are written onto each returned result
          here, so subclasses do not have to set them.
        """
        by_source: dict[str, list[Measurement]] = defaultdict(list)
        experimental: list[Measurement] = []
        for m in femap:
            bucket = by_source[m.source] if m.computational else experimental
            bucket.append(m)

        needs_disambiguation = len(by_source) > 1
        estimator_name = type(self).__name__

        output = {}
        for input_source, computed in by_source.items():
            if not self._check_weakly_connected(computed):
                raise ValueError(f"Computational results for source '{input_source}' are not fully connected")

            # The input source is appended only when several sources must be
            # told apart; a single-source user can look results up by the
            # plain estimator label (e.g. get_estimator_metadata("MLE")).
            if needs_disambiguation:
                label = f"{self.source}({input_source})"
            else:
                label = self.source

            estimates, result = self._estimate([*computed, *experimental], source=label)

            # provenance is stamped here rather than in each subclass
            result.estimator = estimator_name
            result.source = label
            output[label] = (estimates, result)

        return output

    @abc.abstractmethod
    def _estimate(
        self,
        measurements: list[Measurement],
        source: str,
    ) -> tuple[list[Measurement], EstimatorResult]:
        """Turn a list of measurements into absolute free energy estimates.

        The input may mix computational and experimental values, relative and
        absolute. Where absolute values are available they should be used to
        center the results.

        Parameters
        ----------
        measurements : list[Measurement]
            Absolute and relative free energy measurements to estimate from;
            both computational and experimental values may be present.
        source : str
            The composed source label to stamp on returned measurements and
            use as the key for storing the result on the FEMap.

        Returns
        -------
        measurements : list[Measurement]
            Absolute free energy estimates to be added to the FEMap.
        result : EstimatorResult
            Estimator-specific intermediate data that cannot be reconstructed
            from the measurements alone.

        Raises
        ------
        ValueError
            If the estimator cannot be applied (e.g. the graph is not
            connected, or there are duplicate edges).
        """
        ...  # pragma: no cover
class MLEEstimator(Estimator):
    """Maximum-likelihood estimator (MLE) for absolute free energies.

    Uses the MLE solver from :mod:`cinnabar.stats` to compute the most
    probable set of absolute free energies consistent with the relative
    measurements stored in the map.

    Parameters
    ----------
    source : str, default "MLE"
        Label attached to the returned measurements and used as the storage
        key on the FEMap. Defaults to MLE.

    Notes
    -----
    * Requires the computational sub-graph to be weakly connected.
    * Cannot handle multiple edges between the same pair of nodes; combine
      replicates into a single estimate before calling this estimator.
    """

    def __init__(self, source: str = "MLE"):
        self.source = source

    def _estimate(
        self,
        measurements: list[Measurement],
        source: str,
    ) -> tuple[list[Measurement], MLEEstimatorResult]:
        """Run MLE on the measurements and return the estimated DG values.

        Parameters
        ----------
        measurements : list[Measurement]
            Relative computational edges plus any experimental or
            computational absolute measurements for a single source.
        source : str
            The composed source label to stamp on returned measurements and
            use as the key for storing the result on the FEMap.

        Returns
        -------
        measurements : list[Measurement]
            One absolute-DG ``Measurement`` per ligand, expressed relative to
            a ``ReferenceState`` labelled with ``source``, plus a zero-offset,
            zero-uncertainty anchor connecting that reference state to the
            global ``ReferenceState``.
        result : MLEEstimatorResult
            Contains :attr:`~MLEEstimatorResult.covariance_matrix` and
            :attr:`~MLEEstimatorResult.ligand_order`.

        Raises
        ------
        ValueError
            Propagated from ``_build_graph_from_measurements`` when the input
            is empty, has mixed units, or contains duplicate edges.
        """
        # TODO: replace stats.mle call with a self-contained implementation
        g, u = _build_graph_from_measurements(measurements)
        f_i_calc, C_calc = stats.mle(g, factor="calc_DDG")
        # The per-ligand uncertainty is the standard deviation, i.e. the
        # square root of the diagonal of the covariance matrix.
        std_devs = np.diagonal(C_calc) ** 0.5

        ref = ReferenceState(label=source)
        # NOTE(review): assumes stats.mle returns values in g.nodes order.
        ligand_order = list(g.nodes)

        out_measurements: list[Measurement] = [
            Measurement(
                labelA=ref,
                labelB=ligand,
                DG=f_i * u,
                uncertainty=df_i * u,
                computational=True,
                source=source,
            )
            for ligand, f_i, df_i in zip(ligand_order, f_i_calc, std_devs)
        ]

        # Anchor the estimator reference state to the global reference state.
        # The anchor must be an exact zero offset: any non-zero DG here would
        # shift every absolute estimate reached through the global reference.
        out_measurements.append(
            Measurement(
                labelA=ReferenceState(),
                labelB=ref,
                DG=Quantity(0.0, units=u),
                uncertainty=Quantity(0.0, units=u),
                computational=True,
                source=source,
            )
        )

        return out_measurements, MLEEstimatorResult(
            covariance_matrix=C_calc,
            ligand_order=ligand_order,
        )
def _build_graph_from_measurements(
    measurements: list[Measurement],
) -> tuple[nx.DiGraph, object]:
    """Build a legacy graph from the list of measurements for use in the MLE method.

    This is copied over from the ``to_legacy_graph`` method of FEMap.

    Relative computational measurements become directed edges carrying
    ``calc_DDG`` / ``calc_dDDG``. Experimental absolute measurements (those
    whose ``labelA`` is a ``ReferenceState``) annotate already-present nodes
    with ``exp_DG`` / ``exp_dDG`` / ``name``. Edges whose two endpoints both
    carry experimental data additionally receive ``exp_DDG`` / ``exp_dDDG``.

    Parameters
    ----------
    measurements : list[Measurement]
        Mix of relative computational and absolute experimental measurements.

    Returns
    -------
    g : nx.DiGraph
        Input graph ready for stats.mle
    u : unit
        The unit shared by all measurements (validated to be consistent).

    Raises
    ------
    ValueError
        If no measurements are given, if measurements have mixed units, or if
        duplicate computational edges exist between the same pair of nodes.
    """
    if not measurements:
        raise ValueError("No measurements provided")

    units = {m.DG.u for m in measurements}
    if len(units) > 1:
        raise ValueError(f"All measurements must share the same units before running an estimator. Found: {units}")
    u = next(iter(units))

    g = nx.DiGraph()
    # A set gives O(1) duplicate detection; the keys are tuples of strings.
    edges_seen: set[tuple[str, ...]] = set()
    for m in measurements:
        # only relative computational measurements become edges
        if not m.computational or isinstance(m.labelA, ReferenceState):
            continue
        # cast to string as hashable does not support < > comparisons;
        # sorting makes the key independent of edge direction
        edge_name = tuple(sorted([str(m.labelA), str(m.labelB)]))
        if edge_name in edges_seen:
            raise ValueError(
                f"Multiple edges detected between nodes {m.labelA} and {m.labelB}. "
                "MLE cannot be performed on graphs with multiple edges between the "
                "same nodes. The results should be combined into a single estimate "
                "and uncertainty before performing MLE. "
                "See https://cinnabar.openfree.energy/en/latest/concepts/estimators.html"
                "#limitations for more details."
            )
        g.add_edge(
            m.labelA,
            m.labelB,
            calc_DDG=m.DG.magnitude,
            calc_dDDG=m.uncertainty.magnitude,
        )
        edges_seen.add(edge_name)

    # annotate nodes with experimental absolute values
    for m in measurements:
        if m.computational or not isinstance(m.labelA, ReferenceState):
            continue
        node = m.labelB
        # experimental data for ligands absent from the computed graph is ignored
        if node not in g.nodes:
            continue
        g.nodes[node]["exp_DG"] = m.DG.magnitude
        g.nodes[node]["exp_dDG"] = m.uncertainty.magnitude
        g.nodes[node]["name"] = node

    # infer experimental DDG for edges where both endpoints have absolute data
    for A, B, d in g.edges(data=True):
        try:
            DG_A = g.nodes[A]["exp_DG"]
            dDG_A = g.nodes[A]["exp_dDG"]
            DG_B = g.nodes[B]["exp_DG"]
            dDG_B = g.nodes[B]["exp_dDG"]
        except KeyError:
            # at least one endpoint lacks experimental data; skip this edge
            continue
        d["exp_DDG"] = DG_B - DG_A
        # uncertainties of the two endpoints are combined in quadrature
        d["exp_dDDG"] = (dDG_A**2 + dDG_B**2) ** 0.5

    return g, u