# This code is part of Qiskit.
#
# (C) Copyright IBM 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=invalid-name, inconsistent-return-statements
"""Functions for performing the Rotating Wave Approximation
on Model classes."""
from typing import List, Optional, Union
import numpy as np
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics.arraylias.alias import ArrayLike, _to_dense, _to_dense_list
from qiskit_dynamics.models import (
    BaseGeneratorModel,
    GeneratorModel,
    HamiltonianModel,
    LindbladModel,
    RotatingFrame,
)
from qiskit_dynamics.signals import SignalSum, Signal, SignalList
[docs]
def rotating_wave_approximation(
    model: BaseGeneratorModel, cutoff_freq: float, return_signal_map: Optional[bool] = False
) -> BaseGeneratorModel:
    r"""Construct a new model by performing the rotating wave approximation with a given cutoff
    frequency. The outputs of this function can be used in JAX-transformable functions, however this
    function itself cannot (see below).
    Performs elementwise rotating wave approximation (RWA) with cutoff frequency ``cutoff_freq`` on
    each operator in a model, returning a new model. The new model contains a modified list of
    signal coefficients, and setting the optional argument ``return_signal_map=True`` results in
    the additional return of the function ``f`` which maps the signals of the input model to those
    of the output RWA model, such that the code blocks:
    .. code-block:: python
        model.signals = new_signals
        rwa_model = rotating_wave_approximation(model, cutoff_freq)
    and
    .. code-block:: python
        rwa_model, f = rotating_wave_approximation(model, cutoff_freq, return_signal_map=True)
        rwa_model.signals = f(new_signals)
    result in an ``rwa_model`` with the same updated signals.
    .. note::
        The ``rotating_wave_approximation`` function itself cannot be included in a function
        to-be JAX-transformed, however the resulting model and ``signal_map`` can. For example,
        the following function is **not** JAX-transformable:
        .. code-block:: python
            def function_with_rwa(t):
                operators = ...
                signals = ...
                model = GeneratorModel(operators=operators, signals=signals)
                rwa_model = rotating_wave_approximation(model, cutoff_freq)
                return rwa_model(t)
        Whereas, defining:
        .. code-block:: python
            rwa_model, signal_map = rotating_wave_approximation(
                model,
                cutoff_freq,
                return_signal_map=True
            )
        The following function **is** JAX-transformable:
        .. code-block:: python
            def jax_transformable_func(t):
                rwa_model.signals = signal_map(new_signals)
                return rwa_model(t)
        In this way, the outputs of ``rotating_wave_approximation`` can be used in JAX-transformable
        functions, however ``rotating_wave_approximation`` itself cannot.
    We now describe the formalism. When considering :math:`s_i(t) e^{-tF}G_ie^{tF}`, in the basis in
    which :math:`F` is diagonal, the :math:`(j, k)` element of :math:`G_i` has effective frequency
    :math:`\tilde\nu_{ijk}^\pm = \pm\nu_i + Im[-d_j+d_k]/2\pi`, where the :math:`\pm\nu_i` comes
    from expressing :math:`s_i(t) = Re[a_i(t)e^{2\pi i\nu_i t}] = a_i(t)e^{i(2\pi\nu_i t+\phi_i)}/2
    + c.c.` and the other term comes from the rotating frame. Define :math:`G_i^\pm` the matrix
    whose entries :math:`(G_i^\pm)_{jk}` are the entries of :math:`G_i` s.t.
    :math:`|\nu_{ijk}^\pm|<\nu_*` for some cutoff frequency :math:`\nu_*`. Then, after the RWA, we
    may write
    .. math::
        s_i(t)G_i \to G_i^+ a_ie^{i(2\pi \nu_i t+\phi_i)}/2
                        + G_i^- \overline{a_i}e^{-i(2\pi \nu_i t+\phi_i)}/2.
    When we regroup these to use only the real components of the signal, we find that
    .. math::
        s_i(t)G_i \to  s_i(t)(G_i^+ + G_i^-)/2 + s_i'(t)(iG_i^+-iG_i^-)
    where :math:`s_i'(t)` is a signal with the same frequency and amplitude as :math:`s_i`, but with
    a phase shift of :math:`\phi_i - \pi/2`.
    Args:
        model: The model to approximate.
        cutoff_freq: The cutoff frequency for the approximation.
        return_signal_map: Whether to also return the function for mapping pre-RWA signals to
            post-RWA signals.
    Returns:
        :class:`GeneratorModel` with twice as many terms, and, if ``return_signal_map``,
        also the function ``f``.
    Raises:
        ValueError: If the model has no signals.
    """
    n = model.dim
    frame_freqs = None
    if model.rotating_frame is None or model.rotating_frame.frame_diag is None:
        frame_freqs = np.zeros((n, n), dtype=complex)
    else:
        diag = model.rotating_frame.frame_diag
        diff_matrix = np.broadcast_to(diag, (n, n)) - np.broadcast_to(diag, (n, n)).T
        frame_freqs = diff_matrix.imag / (2 * np.pi)
    if model.rotating_frame.frame_diag is not None:
        frame_shift = np.diag(model.rotating_frame.frame_diag)
        if isinstance(model, (HamiltonianModel, LindbladModel)):
            frame_shift = 1j * frame_shift
    else:
        frame_shift = np.zeros((n, n), dtype=complex)
    if isinstance(model, GeneratorModel):
        if model.signals is None and model.operators is not None:
            raise ValueError("Model must have nontrivial signals to perform the RWA.")
        cur_drift = _to_dense(model._operator_collection.static_operator)
        if isinstance(model, HamiltonianModel) and cur_drift is not None:
            cur_drift = 1j * cur_drift
        if cur_drift is not None:
            cur_drift = cur_drift + frame_shift
            rwa_drift = cur_drift * (abs(frame_freqs) < cutoff_freq).astype(int)
            rwa_drift = model.rotating_frame.operator_out_of_frame_basis(rwa_drift)
        else:
            rwa_drift = None
        operators = _to_dense_list(model._operator_collection.operators)
        if isinstance(model, HamiltonianModel) and operators is not None:
            operators = 1j * operators
        rwa_operators = get_rwa_operators(
            operators,
            model.signals,
            model.rotating_frame,
            frame_freqs,
            cutoff_freq,
        )
        rwa_signals = get_rwa_signals(model.signals)
        # works for both GeneratorModel and HamiltonianModel
        rwa_model = model.__class__(
            static_operator=rwa_drift,
            operators=rwa_operators,
            signals=rwa_signals,
            rotating_frame=model.rotating_frame,
            in_frame_basis=model.in_frame_basis,
            array_library=model.array_library,
        )
        if return_signal_map:
            return rwa_model, get_rwa_signals
        return rwa_model
    elif isinstance(model, LindbladModel):
        if model.signals[0] is None and model.hamiltonian_operators is not None:
            raise ValueError("Model must have nontrivial Hamiltonian signals to perform the RWA.")
        if model.signals[1] is None and model.dissipator_operators is not None:
            raise ValueError("Model must have nontrivial dissipator signals to perform the RWA.")
        # static hamiltonian part
        cur_drift = _to_dense(model._operator_collection.static_hamiltonian) + frame_shift
        rwa_drift = cur_drift * (abs(frame_freqs) < cutoff_freq).astype(int)
        rwa_drift = model.rotating_frame.operator_out_of_frame_basis(rwa_drift)
        # static dissipator part
        cur_static_dis = _to_dense_list(model._operator_collection.static_dissipators)
        rwa_static_dis = None
        if cur_static_dis is not None:
            rwa_static_dis = []
            for op in cur_static_dis:
                op = unp.asarray(op)
                rwa_op = op * (abs(frame_freqs) < cutoff_freq).astype(int)
                rwa_op = model.rotating_frame.operator_out_of_frame_basis(rwa_op)
                rwa_static_dis.append(rwa_op)
        cur_ham_ops = _to_dense_list(model._operator_collection.hamiltonian_operators)
        cur_dis_ops = _to_dense_list(model._operator_collection.dissipator_operators)
        cur_ham_sig, cur_dis_sig = model.signals
        rwa_ham_ops = get_rwa_operators(
            cur_ham_ops, cur_ham_sig, model.rotating_frame, frame_freqs, cutoff_freq
        )
        rwa_ham_sig = get_rwa_signals(cur_ham_sig)
        rwa_dis_ops = get_rwa_operators(
            cur_dis_ops, cur_dis_sig, model.rotating_frame, frame_freqs, cutoff_freq
        )
        rwa_dis_sig = get_rwa_signals(cur_dis_sig)
        rwa_model = LindbladModel(
            static_hamiltonian=rwa_drift,
            hamiltonian_operators=rwa_ham_ops,
            hamiltonian_signals=rwa_ham_sig,
            static_dissipators=rwa_static_dis,
            dissipator_operators=rwa_dis_ops,
            dissipator_signals=rwa_dis_sig,
            rotating_frame=model.rotating_frame,
            in_frame_basis=model.in_frame_basis,
            array_library=model.array_library,
            vectorized=model.vectorized,
        )
        if return_signal_map:
            return rwa_model, lambda a: (get_rwa_signals(a[0]), get_rwa_signals(a[1]))
        return rwa_model 
def get_rwa_operators(
    current_ops: ArrayLike,
    current_sigs: SignalList,
    rotating_frame: RotatingFrame,
    frame_freqs: ArrayLike,
    cutoff_freq: float,
) -> ArrayLike:
    r"""Given a set of operators as a ``(k,n,n)`` array, a set of frequencies
    :math:`\operatorname{frame_freqs}_{jk} = \operatorname{Im}[-d_j+d_k]` where :math:`d_i` the
    :math:`i^{th}` eigenlvalue of the frame operator :math:`F`, the current signals of a model, and
    a cutoff frequency, returns the new operators and signals that should be passed to create a new
    Model class after the RWA.
    Args:
        current_ops: The current operator list, a ``(k,n,n)`` array.
        current_sigs: ``(k,)`` length :class:`SignalList`.
        rotating_frame: The current :class:`~RotatingFrame` object of the pre-RWA model
        frame_freqs: The effective frequencies of different matrix elements due to the conjugation
            by :math:`e^{\pm Ft}` in the rotating frame.
        cutoff_freq: The maximum frequency allowed under the RWA.
    Returns:
        SignaLList: ``(2k,n,n)`` array of new operators post RWA.
    """
    if current_ops is None:
        return None
    current_sigs = current_sigs.flatten()
    carrier_freqs = np.zeros(len(current_ops))
    for i, sig_sum in enumerate(current_sigs.components):
        sig = sig_sum.components[0]
        carrier_freqs[i] = sig.carrier_freq
    num_components = len(carrier_freqs)
    n = current_ops[0].shape[-1]
    frame_freqs = np.broadcast_to(frame_freqs, (num_components, n, n))
    carrier_freqs = carrier_freqs.reshape((num_components, 1, 1))
    pos_pass = np.abs(carrier_freqs + frame_freqs) < cutoff_freq
    pos_terms = current_ops * pos_pass.astype(int)  # G_i^+
    neg_pass = np.abs(-carrier_freqs + frame_freqs) < cutoff_freq
    neg_terms = current_ops * neg_pass.astype(int)  # G_i^-
    real_component = pos_terms / 2 + neg_terms / 2
    imag_component = 1j * pos_terms / 2 - 1j * neg_terms / 2
    return rotating_frame.operator_out_of_frame_basis(
        np.append(real_component, imag_component, axis=0)
    )
def get_rwa_signals(curr_signal_list: Union[List[Signal], SignalList]) -> SignalList:
    """Helper function that converts pre-RWA signals to post-RWA signals.
    Args:
        curr_signal_list: The pre-RWA signals.
    Returns:
        The post-RWA signals.
    """
    if curr_signal_list is None:
        return curr_signal_list
    real_signal_components = []
    imag_signal_components = []
    if not isinstance(curr_signal_list, SignalList):
        curr_signal_list = SignalList(curr_signal_list)
    curr_signal_list = curr_signal_list.flatten()
    for sig_sum in curr_signal_list.components:
        sig = sig_sum.components[0]
        real_signal_components.append(sig)
        imag_signal_components.append(
            SignalSum(Signal(sig.envelope, sig.carrier_freq, sig.phase - np.pi / 2))
        )
    return SignalList(real_signal_components + imag_signal_components)