# -*- coding: utf-8 -*-
# This code is part of Qiskit.
#
# (C) Copyright IBM 2020.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name,no-member,attribute-defined-outside-init
r"""
Solver functions.
"""
from typing import Optional, Union, Callable, Tuple, TypeVar
from warnings import warn
from scipy.integrate import OdeSolver
from scipy.integrate._ivp.ivp import OdeResult
from qiskit import QiskitError
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.models import (
BaseGeneratorModel,
GeneratorModel,
LindbladModel,
)
from qiskit_dynamics.models.hamiltonian_model import HamiltonianModel
from .solver_utils import is_lindblad_model_not_vectorized
from .fixed_step_solvers import (
RK4_solver,
jax_RK4_solver,
scipy_expm_solver,
lanczos_diag_solver,
jax_lanczos_diag_solver,
jax_expm_solver,
jax_RK4_parallel_solver,
jax_expm_parallel_solver,
)
from .scipy_solve_ivp import scipy_solve_ivp, SOLVE_IVP_METHODS
from .jax_odeint import jax_odeint
from .diffrax_solver import diffrax_solver
# Method names accepted by ``solve_ode`` (and forwarded there by ``solve_lmde``).
ODE_METHODS = [
    # scipy solve_ivp solvers
    "RK45",
    "RK23",
    "BDF",
    "DOP853",
    "Radau",
    "LSODA",
    # fixed step solvers
    "RK4",
    # jax solvers
    "jax_odeint",
    "jax_RK4",
]

# Method names that are specific to LMDEs and are only handled by ``solve_lmde``.
LMDE_METHODS = [
    "scipy_expm",
    "lanczos_diag",
    "jax_lanczos_diag",
    "jax_expm",
    "jax_expm_parallel",
    "jax_RK4_parallel",
]

# Placeholder standing in for ``diffrax.AbstractSolver`` in type annotations, since diffrax is
# only imported lazily (see the diffrax check below).
# pylint: disable=typevar-name-mismatch
DiffraxAbstractSolver = TypeVar("AbstractSolver")
def _is_jax_method(method: object) -> bool:
    """Check if method is a jax solver method.

    Args:
        method: A solver method specifier, e.g. a method name string, a scipy ``OdeSolver``
            subclass, or a diffrax solver instance.

    Returns:
        ``True`` if ``method`` is one of the named JAX-based methods, or a diffrax solver
        instance; ``False`` otherwise.
    """
    jax_method_names = (
        "jax_odeint",
        "jax_RK4",
        "jax_expm",
        "jax_expm_parallel",
        "jax_RK4_parallel",
        "jax_lanczos_diag",
    )
    if method in jax_method_names:
        return True

    # only other jax methods are diffrax methods
    return _is_diffrax_method(method)
def _is_diffrax_method(method: object) -> bool:
    """Check if method is a diffrax method.

    diffrax is an optional dependency and is imported lazily here. If it cannot be imported,
    no object can be a diffrax solver and ``False`` is returned.

    Args:
        method: A solver method specifier.

    Returns:
        ``True`` if ``method`` is an instance of ``diffrax.AbstractSolver``, ``False`` otherwise.
    """
    try:
        from diffrax import AbstractSolver
    except ImportError:
        return False

    return isinstance(method, AbstractSolver)
def _lanczos_validation(
    rhs: Union[Callable, BaseGeneratorModel],
    t_span: ArrayLike,
    y0: ArrayLike,
    k_dim: int,
):
    """Validation checks to run lanczos based solvers.

    Args:
        rhs: Generator, either a model or a callable returning the generator at time ``t``.
        t_span: Initial and final time. The generator is evaluated at ``t_span[0]`` to
            determine its dimension.
        y0: Initial state.
        k_dim: Krylov subspace dimension.

    Raises:
        QiskitError: If ``rhs`` is a model that is not a ``HamiltonianModel``, if ``k_dim`` is
            larger than the generator dimension, or if ``y0`` is not 1d or 2d.

    A ``Warning`` is emitted if ``rhs`` is a model whose array library is not sparse.
    """
    t_span = unp.asarray(t_span)
    y0 = unp.asarray(y0)

    if isinstance(rhs, BaseGeneratorModel):
        if not isinstance(rhs, HamiltonianModel):
            raise QiskitError(
                "Lanczos solver can only be used for HamiltonianModel or function-based\n"
                "anti-Hermitian generators."
            )

        # guard against a non-string (e.g. None) array_library, which would make the
        # ``in`` test raise a TypeError; an unspecified library is treated as non-sparse
        array_library = rhs.array_library or ""
        if "sparse" not in array_library:
            warn(
                "lanczos_diag should be used with a generator in sparse mode\n"
                "for better performance.",
                category=Warning,
                stacklevel=2,
            )

    dim = rhs(t_span[0]).shape[0]
    if k_dim > dim:
        raise QiskitError("k_dim can be no larger than the dimension of the generator.")

    if y0.ndim not in [1, 2]:
        raise QiskitError("y0 must be 1d or 2d.")
[docs]
def solve_ode(
    rhs: Union[Callable, BaseGeneratorModel],
    t_span: ArrayLike,
    y0: ArrayLike,
    method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853",
    t_eval: Optional[ArrayLike] = None,
    **kwargs,
) -> OdeResult:
    r"""General interface for solving Ordinary Differential Equations (ODEs).

    ODEs are differential equations of the form

    .. math::

        \dot{y}(t) = f(t, y(t)),

    where :math:`f` is a callable function and the state :math:`y(t)` is an
    arbitrarily-shaped complex array.

    The ``method`` argument exposes a variety of underlying ODE solvers. Optional
    arguments for any of the solver routines can be passed via ``kwargs``.
    Available methods are:

    - ``scipy.integrate.solve_ivp`` - supports methods
      ``['RK45', 'RK23', 'BDF', 'DOP853', 'Radau', 'LSODA']`` or by passing a valid
      ``scipy`` :class:`OdeSolver` subclass (the class itself, not an instance).
    - ``'RK4'``: A fixed-step 4th order Runge-Kutta solver.
      Requires additional kwarg ``max_dt``, indicating the maximum step
      size to take. This solver will break integration periods into even
      sub-intervals no larger than ``max_dt``, and step over each sub-interval
      using the standard 4th order Runge-Kutta integration rule.
    - ``'jax_RK4'``: JAX backend implementation of ``'RK4'`` method.
    - ``'jax_odeint'``: Calls ``jax.experimental.ode.odeint`` variable step solver.
    - ``diffrax.diffeqsolve`` - a JAX solver function, called by passing ``method``
      as a valid ``diffrax.AbstractSolver`` instance. Requires the ``diffrax`` library.

    Results are returned as a :class:`OdeResult` object.

    If ``rhs`` is a :class:`~qiskit_dynamics.models.BaseGeneratorModel`, it is temporarily set
    to operate in its frame basis while solving. Its original ``in_frame_basis`` setting is
    restored before this function returns, including when the solver raises an exception.

    Args:
        rhs: RHS function :math:`f(t, y)`.
        t_span: ``Tuple`` or ``list`` of initial and final time.
        y0: State at initial time.
        method: Solving method to use.
        t_eval: Times at which to return the solution. Must lie within ``t_span``. If unspecified,
            the solution will be returned at the points in ``t_span``.
        **kwargs: Additional arguments to pass to the solver.

    Returns:
        OdeResult: Results object.

    Raises:
        QiskitError: If specified method does not exist.
    """
    if method not in ODE_METHODS and not (
        (isinstance(method, type) and issubclass(method, OdeSolver)) or _is_diffrax_method(method)
    ):
        raise QiskitError("Method " + str(method) + " not supported by solve_ode.")

    y0 = unp.asarray(y0)

    is_model = isinstance(rhs, BaseGeneratorModel)
    if is_model:
        # this switches ``rhs`` to operate in the frame basis; it is restored in ``finally``
        _, solver_rhs, y0, model_in_frame_basis = setup_generator_model_rhs_y0_in_frame_basis(
            rhs, y0
        )
    else:
        solver_rhs = rhs

    try:
        # solve the problem using specified method
        if method in SOLVE_IVP_METHODS or (
            isinstance(method, type) and issubclass(method, OdeSolver)
        ):
            results = scipy_solve_ivp(solver_rhs, t_span, y0, method, t_eval=t_eval, **kwargs)
        elif isinstance(method, str) and method == "RK4":
            results = RK4_solver(solver_rhs, t_span, y0, t_eval=t_eval, **kwargs)
        elif isinstance(method, str) and method == "jax_RK4":
            results = jax_RK4_solver(solver_rhs, t_span, y0, t_eval=t_eval, **kwargs)
        elif isinstance(method, str) and method == "jax_odeint":
            results = jax_odeint(solver_rhs, t_span, y0, t_eval=t_eval, **kwargs)
        elif _is_diffrax_method(method):
            results = diffrax_solver(
                solver_rhs, t_span, y0, method=method, t_eval=t_eval, **kwargs
            )

        # convert results out of frame basis if necessary
        if is_model and not model_in_frame_basis:
            results.y = results_y_out_of_frame_basis(rhs, results.y, y0.ndim)
    finally:
        # convert model back to original basis, even if solving failed, so that a failed
        # call does not leave the caller's model mutated
        if is_model:
            rhs.in_frame_basis = model_in_frame_basis

    return results
[docs]
def solve_lmde(
    generator: Union[Callable, BaseGeneratorModel],
    t_span: ArrayLike,
    y0: ArrayLike,
    method: Optional[Union[str, OdeSolver, DiffraxAbstractSolver]] = "DOP853",
    t_eval: Optional[ArrayLike] = None,
    **kwargs,
) -> OdeResult:
    r"""General interface for solving Linear Matrix Differential Equations (LMDEs)
    in standard form.

    LMDEs in standard form are differential equations of the form:

    .. math::

        \dot{y}(t) = G(t)y(t).

    where :math:`G(t)` is a square matrix valued-function called the *generator*, and :math:`y(t)`
    is an array of appropriate shape.

    This function accepts :math:`G(t)` as a ``qiskit_dynamics`` model class, or as an arbitrary
    callable.

    .. note::

        Not all model classes are by-default in standard form. E.g.
        :class:`~qiskit_dynamics.models.LindbladModel` represents an LMDE which is not typically
        written in standard form. As such, using LMDE-specific methods with this generator requires
        the equation to be vectorized.

    The ``method`` argument exposes solvers specialized to both LMDEs, as well as general ODE
    solvers. If the method is not specific to LMDEs, the problem will be passed to
    :meth:`~qiskit_dynamics.solve_ode` by automatically setting up the RHS function :math:`f(t, y) =
    G(t)y`.

    Optional arguments for any of the solver routines can be passed via ``kwargs``. Available
    LMDE-specific methods are:

    - ``'scipy_expm'``: A fixed-step matrix-exponential solver using ``scipy.linalg.expm``. Requires
      additional kwarg ``max_dt`` indicating the maximum step size to take. This solver will break
      integration periods into even sub-intervals no larger than ``max_dt`` and solve over each
      sub-interval. The optional kwarg ``magnus_order`` controls the integration rule: if
      ``magnus_order==1``, the generator is sampled at the interval midpoint and exponentiated, and
      if ``magnus_order==2`` or ``magnus_order==3``, higher-order exponentiation rules are adopted
      from :footcite:`blanes_magnus_2009`. The ``magnus_order`` parameter defaults to ``1``.
    - ``'lanczos_diag'``: A fixed-step matrix-exponential solver, similar to ``'scipy_expm'`` but
      restricted to anti-Hermitian generators. The matrix exponential is performed by diagonalizing
      an approximate projection of the generator to a small subspace (the Krylov Subspace), obtained
      via the Lanczos algorithm, and then exponentiating the eigenvalues. Requires additional kwargs
      ``max_dt`` and ``k_dim`` indicating the maximum step size to take and Krylov subspace
      dimension, respectively. ``k_dim`` acts as an adjustable accuracy parameter and can be no
      larger than the dimension of the generator. The method is recommended for sparse systems with
      large dimension.
    - ``'jax_lanczos_diag'``: JAX implementation of ``'lanczos_diag'``, with the same arguments and
      behaviour. Note that this method contains calls to ``jax.numpy.eigh``, which may have limited
      validity when automatically differentiated.
    - ``'jax_expm'``: JAX-implemented version of ``'scipy_expm'``, with the same arguments and
      behaviour. Note that this method cannot be used for a model using a sparse array library.
    - ``'jax_expm_parallel'``: Same as ``'jax_expm'``, however all loops are implemented using
      parallel operations. I.e. all matrix-exponentials for taking a single step are computed in
      parallel using ``jax.vmap``, and are subsequently multiplied together in parallel using
      ``jax.lax.associative_scan``. This method is only recommended for use with GPU execution. Note
      that this method cannot be used for a model using a sparse array library.
    - ``'jax_RK4_parallel'``: 4th order Runge-Kutta fixed step solver. Under the assumption of the
      structure of an LMDE, utilizes the same parallelization approach as ``'jax_expm_parallel'``,
      however the single step rule is the standard 4th order Runge-Kutta rule, rather than
      matrix-exponentiation. Requires and utilizes the ``max_dt`` kwarg in the same manner as
      ``method='scipy_expm'``. This method is only recommended for use with GPU execution.

    Results are returned as a :class:`OdeResult` object.

    If ``generator`` is a :class:`~qiskit_dynamics.models.BaseGeneratorModel`, it is temporarily
    set to operate in its frame basis while solving. Its original ``in_frame_basis`` setting is
    restored before this function returns, including when validation or the solver raises.

    Args:
        generator: Representation of generator function :math:`G(t)`.
        t_span: ``Tuple`` or ``list`` of initial and final time.
        y0: State at initial time.
        method: Solving method to use.
        t_eval: Times at which to return the solution. Must lie within ``t_span``. If unspecified,
            the solution will be returned at the points in ``t_span``.
        **kwargs: Additional arguments to pass to the solver.

    Returns:
        OdeResult: Results object.

    Raises:
        QiskitError: If specified method does not exist,
            if dimension of ``y0`` is incompatible with generator dimension, or if an
            LMDE-specific method is passed with a LindbladModel.

    Additional Information:
        While all :class:`~qiskit_dynamics.models.BaseGeneratorModel` subclasses represent LMDEs,
        they are not all in standard form by default. Using LMDE-specific methods with models like
        :class:`~qiskit_dynamics.models.LindbladModel` requires first setting the model to be
        vectorized.
    """
    # delegate to solve_ode if necessary
    if (
        method in ODE_METHODS
        or (isinstance(method, type) and issubclass(method, OdeSolver))
        or _is_diffrax_method(method)
    ):
        if isinstance(generator, BaseGeneratorModel):
            rhs = generator
        else:
            # treat generator as a function
            def rhs(t, y):
                return generator(t) @ y

        return solve_ode(rhs, t_span, y0, method=method, t_eval=t_eval, **kwargs)

    # raise error if neither an ODE_METHOD or an LMDE_METHOD
    if method not in LMDE_METHODS:
        raise QiskitError(f"Method {method} not supported by solve_lmde.")

    # lmde-specific methods can't be used with LindbladModel unless vectorized
    if is_lindblad_model_not_vectorized(generator):
        raise QiskitError(
            "LMDE-specific methods with LindbladModel requires setting a vectorized=True."
        )

    y0 = unp.asarray(y0)

    # setup generator and rhs functions to pass to numerical methods; for a model this switches
    # it to operate in the frame basis, which is undone in the ``finally`` clause below
    is_model = isinstance(generator, BaseGeneratorModel)
    if is_model:
        solver_generator, _, y0, model_in_frame_basis = setup_generator_model_rhs_y0_in_frame_basis(
            generator, y0
        )
    else:
        solver_generator = generator

    try:
        if method == "scipy_expm":
            results = scipy_expm_solver(solver_generator, t_span, y0, t_eval=t_eval, **kwargs)
        elif "lanczos_diag" in method:
            _lanczos_validation(generator, t_span, y0, kwargs["k_dim"])
            if method == "lanczos_diag":
                results = lanczos_diag_solver(
                    solver_generator, t_span, y0, t_eval=t_eval, **kwargs
                )
            elif method == "jax_lanczos_diag":
                results = jax_lanczos_diag_solver(
                    solver_generator, t_span, y0, t_eval=t_eval, **kwargs
                )
        elif method == "jax_expm":
            # guard against a non-string (e.g. None) array_library, which would make the
            # ``in`` test raise a TypeError; an unspecified library is treated as non-sparse
            if is_model and "sparse" in (generator.array_library or ""):
                raise QiskitError("jax_expm cannot be used with a generator in sparse mode.")
            results = jax_expm_solver(solver_generator, t_span, y0, t_eval=t_eval, **kwargs)
        elif method == "jax_expm_parallel":
            results = jax_expm_parallel_solver(
                solver_generator, t_span, y0, t_eval=t_eval, **kwargs
            )
        elif method == "jax_RK4_parallel":
            results = jax_RK4_parallel_solver(
                solver_generator, t_span, y0, t_eval=t_eval, **kwargs
            )

        # convert results to correct basis if necessary
        if is_model and not model_in_frame_basis:
            results.y = results_y_out_of_frame_basis(generator, results.y, y0.ndim)
    finally:
        # restore the model's original basis setting even when validation or solving raised
        if is_model:
            generator.in_frame_basis = model_in_frame_basis

    return results
def setup_generator_model_rhs_y0_in_frame_basis(
    generator_model: BaseGeneratorModel, y0: ArrayLike
) -> Tuple[Callable, Callable, ArrayLike, bool]:
    """Helper function for setting up a subclass of
    :class:`~qiskit_dynamics.models.BaseGeneratorModel` to be solved in the frame basis.

    Note: this function modifies ``generator_model`` to function in the frame basis, by setting
    its ``in_frame_basis`` attribute to ``True``. Callers are responsible for restoring the
    original value, which is returned as the last element of the output tuple.

    Args:
        generator_model: Subclass of :class:`~qiskit_dynamics.models.BaseGeneratorModel`.
        y0: Initial state.

    Returns:
        Callable for generator in frame basis, callable for RHS in frame basis, y0
        in frame basis, and boolean indicating whether model was already specified in frame basis.
    """
    model_in_frame_basis = generator_model.in_frame_basis

    # if model not specified in frame basis, transform initial state into frame basis
    if not model_in_frame_basis:
        if isinstance(generator_model, LindbladModel) and generator_model.vectorized:
            # a vectorized state is transformed with the vectorized basis change; if the frame
            # has no basis, the state is left unchanged
            if generator_model.rotating_frame.frame_basis is not None:
                y0 = generator_model.rotating_frame.vectorized_frame_basis_adjoint @ y0
        elif isinstance(generator_model, LindbladModel):
            # non-vectorized Lindblad states are operators
            y0 = generator_model.rotating_frame.operator_into_frame_basis(y0)
        elif isinstance(generator_model, GeneratorModel):
            y0 = generator_model.rotating_frame.state_into_frame_basis(y0)

    # set model to operator in frame basis
    generator_model.in_frame_basis = True

    # define rhs functions in frame basis
    def generator(t):
        return generator_model(t)

    def rhs(t, y):
        return generator_model(t, y)

    return generator, rhs, y0, model_in_frame_basis
def results_y_out_of_frame_basis(
    generator_model: BaseGeneratorModel, results_y: ArrayLike, y0_ndim: int
) -> ArrayLike:
    """Convert the results of a simulation for :class:`~qiskit_dynamics.models.BaseGeneratorModel`
    out of the frame basis.

    This is the inverse of the initial-state transformation performed by
    ``setup_generator_model_rhs_y0_in_frame_basis``. The same model types are transformed, and
    models of any other type are returned unchanged.

    Args:
        generator_model: Subclass of :class:`~qiskit_dynamics.models.BaseGeneratorModel`.
        results_y: Array whose first index corresponds to the evaluation points of the state
            for the results of ``solve_lmde`` or ``solve_ode``.
        y0_ndim: Number of dimensions of initial state.

    Returns:
        ``results_y`` transformed out of the frame basis of ``generator_model``'s rotating frame.
    """
    # for left multiplication cases, if number of input dimensions is 1
    # vectorized basis transformation requires transposing before and after
    if y0_ndim == 1:
        results_y = results_y.T

    if isinstance(generator_model, LindbladModel) and generator_model.vectorized:
        # if the frame has no basis, the vectorized states are left unchanged
        if generator_model.rotating_frame.frame_basis is not None:
            results_y = generator_model.rotating_frame.vectorized_frame_basis @ results_y
    elif isinstance(generator_model, LindbladModel):
        results_y = generator_model.rotating_frame.operator_out_of_frame_basis(results_y)
    elif isinstance(generator_model, GeneratorModel):
        # matches the ``GeneratorModel`` branch used when transforming y0 into the frame basis
        results_y = generator_model.rotating_frame.state_out_of_frame_basis(results_y)

    if y0_ndim == 1:
        results_y = results_y.T

    return results_y