# Source code for ffsim.protocols.trace_protocol
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Trace protocol."""
from __future__ import annotations
from typing import Any, Protocol
import numpy as np
[docs]
class SupportsTrace(Protocol):
    """A linear operator whose trace can be computed."""

    def _trace_(self, norb: int, nelec: int | tuple[int, int]) -> float:
        """Return the trace of the linear operator.

        Args:
            norb: The number of spatial orbitals.
            nelec: The number of alpha and beta electrons.

        Returns:
            The trace of the linear operator.
        """
[docs]
def trace(obj: Any, norb: int, nelec: int | tuple[int, int]) -> float:
    """Return the trace of the linear operator.

    The object's ``_trace_`` method is used when it is available. Otherwise,
    the trace is obtained by summing the entries returned by the object's
    ``_diag_`` method.

    Args:
        obj: The object whose trace to compute.
        norb: The number of spatial orbitals.
        nelec: The number of alpha and beta electrons.

    Returns:
        The trace of the linear operator.

    Raises:
        TypeError: The object provides neither a ``_trace_`` method nor a
            ``_diag_`` method.
    """
    # Prefer the dedicated trace implementation when the object provides one.
    method = getattr(obj, "_trace_", None)
    if method is not None:
        return method(norb=norb, nelec=nelec)

    # Fall back to summing the diagonal entries.
    method = getattr(obj, "_diag_", None)
    if method is not None:
        return np.sum(method(norb=norb, nelec=nelec))

    raise TypeError(
        f"Could not compute trace of object of type {type(obj)}.\n"
        "The object did not have a _trace_ method that returned the trace, or "
        "a _diag_ method that returned its diagonal entries."
    )