How-to use Dyson and Magnus based solvers#
Warning
This is an advanced topic — utilizing perturbation-theory based solvers requires detailed
knowledge of the structure of the differential equations involved, as well as manual tuning of
the solver parameters. See the DysonSolver
and MagnusSolver
documentation
for API details. Also, see [1] for a detailed explanation of
the solvers, which vary from and build on the core idea introduced in
[2].
Note
The circumstances under which perturbative solvers outperform traditional solvers, and which parameter sets to use, are nuanced. Perturbative solvers executed with JAX are set up to use more parallelization within a single solver run than typical solvers, so whether the trade-off between the speed of a single run and resource consumption is advantageous is circumstance-specific. Due to this parallelized nature, the execution-time comparisons demonstrated in this userguide are highly hardware-dependent.
In this tutorial we walk through how to use perturbation-theory based solvers. For information on
how these solvers work, see the DysonSolver
and MagnusSolver
class
documentation, as well as the perturbative expansion background information provided in
Time-dependent perturbation theory and multi-variable series expansions review.
We use a simple transmon model:
\[H(t) = 2 \pi \nu N + \pi \alpha N (N - I) + s(t) \times 2 \pi r (a + a^\dagger),\]
where:
\(N\), \(a\), and \(a^\dagger\) are, respectively, the number, annihilation, and creation operators.
\(\nu\) is the qubit frequency, \(\alpha\) is the anharmonicity, and \(r\) is the drive strength.
\(s(t)\) is the drive signal, which we will take to be on resonance with envelope \(f(t) = A \frac{4t (T - t)}{T^2}\) for a given amplitude \(A\) and total time \(T\).
We will walk through the following steps:
Configure JAX.
Construct the model.
How-to construct and simulate using the Dyson-based perturbative solver.
Simulate using a traditional ODE solver, comparing speed.
How-to construct and simulate using the Magnus-based perturbative solver.
1. Configure JAX#
First, configure JAX to run on CPU in 64 bit mode. See the userguide on using JAX for a more detailed explanation of how to work with JAX in Qiskit Dynamics.
# configure jax to use 64 bit mode
import jax
jax.config.update("jax_enable_x64", True)
# tell JAX we are using CPU if using a system without a GPU
jax.config.update('jax_platform_name', 'cpu')
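As a quick, optional sanity check that the configuration took effect, we can confirm that arrays default to 64 bit precision and that the CPU backend is in use:
# optional check: arrays should default to float64 and the backend should be 'cpu'
import jax.numpy as jnp
print(jnp.ones(1).dtype)
print(jax.default_backend())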
2. Construct the model#
First, we construct the model described in the introduction. We use a relatively high dimension for the oscillator state space to accentuate the speed difference between the perturbative solvers and the traditional ODE solver. The higher dimensionality introduces higher frequencies into the model, which slows down both the ODE solver and the initial construction of the perturbative solvers. However, after the initial construction, the higher frequencies in the model have no impact on the perturbative solver speed.
import numpy as np
dim = 10 # Oscillator dimension
v = 5. # Transmon frequency in GHz
anharm = -0.33 # Transmon anharmonicity in GHz
r = 0.02 # Transmon drive coupling in GHz
# Construct cavity operators
a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = np.diag(np.sqrt(np.arange(1, dim)), -1)
N = np.diag(np.arange(dim))
# Static part of Hamiltonian
static_hamiltonian = 2 * np.pi * v * N + np.pi * anharm * N * (N - np.eye(dim))
# Drive term of Hamiltonian
drive_hamiltonian = 2 * np.pi * r * (a + adag)
# total simulation time
T = 1. / r
# Drive envelope function
envelope_func = lambda t: t * (T - t) / (T**2 / 4)
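As an optional check on the construction above, the ladder operators should reproduce the number operator on the truncated space, and the unit-amplitude envelope should peak at 1 at the pulse midpoint:
# adag @ a equals the number operator N on the truncated space
print(np.allclose(adag @ a, N))
# the unit-amplitude envelope peaks at 1 at t = T / 2
print(envelope_func(T / 2))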
3. How-to construct and simulate using the Dyson-based perturbative solver#
Setting up a DysonSolver
requires more configuration than the standard Solver
, as the
user must specify several parameters along with the structure of the differential
equation:
The DysonSolver requires direct specification of the LMDE to the solver. If we are simulating the Schrodinger equation, we need to multiply the Hamiltonian terms by -1j when describing the LMDE operators.
The DysonSolver is a fixed-step solver, with the step size being fixed at instantiation. This step size must be chosen in conjunction with the expansion_order to ensure that a suitable accuracy is attained.
Over each fixed time step, the DysonSolver solves by computing a truncated perturbative expansion.
To compute the truncated perturbative expansion, the signal envelopes are approximated as a linear combination of Chebyshev polynomials. The order of the Chebyshev approximations, along with central carrier frequencies for defining the “envelope” of each Signal, must be provided at instantiation (a small illustrative sketch of such an approximation is shown below).
See the DysonSolver
API docs for more details.
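To build intuition for the envelope approximation mentioned above, the following is a purely illustrative sketch (not part of the DysonSolver API) that fits a degree-1 Chebyshev polynomial to the pulse envelope over a single step of length 0.1, the step size we use below; the names ts and cheb_fit are introduced here only for illustration:
# purely illustrative: approximate the envelope over one fixed step by a
# degree-1 Chebyshev polynomial, mirroring the kind of per-step approximation
# the perturbative solvers perform internally
dt = 0.1  # step size used for the solvers below
ts = np.linspace(0., dt, 10)  # sample points within the first step
cheb_fit = np.polynomial.Chebyshev.fit(ts, envelope_func(ts), deg=1)
print(np.max(np.abs(cheb_fit(ts) - envelope_func(ts))))  # approximation error on this step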
For our example Hamiltonian we configure the DysonSolver
as follows:
%%time
from qiskit_dynamics import DysonSolver
dt = 0.1
dyson_solver = DysonSolver(
operators=[-1j * drive_hamiltonian],
rotating_frame=-1j * static_hamiltonian,
dt=dt,
carrier_freqs=[v],
chebyshev_orders=[1],
expansion_order=7,
integration_method='jax_odeint',
atol=1e-12,
rtol=1e-12
)
CPU times: user 2.95 s, sys: 470 ms, total: 3.42 s
Wall time: 3.02 s
The above parameters are chosen so that the DysonSolver
is fast and produces high-accuracy
solutions (measured and confirmed after the fact). The relatively large step size dt = 0.1
is
chosen for speed: the larger the step size, the fewer steps required. To ensure high accuracy given
the large step size, we choose a high expansion order, and utilize a linear envelope approximation
scheme by setting chebyshev_orders
to [1]
for the single drive signal.
Similar to the Solver
interface, the DysonSolver.solve()
method can be called to
simulate the system for a given list of signals, initial state, start time, and number of time steps
of length dt
.
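For example, a minimal direct (non-compiled) call, using the objects defined above with a unit-amplitude on-resonance drive, looks like the following; the same call appears inside the compiled function further below:
from qiskit_dynamics import Signal

# single direct call to DysonSolver.solve with a unit-amplitude on-resonance drive
drive_signal = Signal(lambda t: envelope_func(t), carrier_freq=v)
sol = dyson_solver.solve(
    signals=[drive_signal],
    y0=np.eye(dim, dtype=complex),
    t0=0.,
    n_steps=int(T // dt)
)
yf_direct = sol.y[-1]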
To properly compare the speed of DysonSolver
to a traditional ODE solver, we write
JAX-compilable functions wrapping each that, given an amplitude value, return the final unitary
over the interval [0, (T // dt) * dt]
for an on-resonance drive with envelope shape given by
envelope_func
above. Running compiled versions of these functions gives a sense of the speeds
attainable by these solvers.
from qiskit_dynamics import Signal
from jax import jit
# Jit the function to improve performance for repeated calls
@jit
def dyson_sim(amp):
"""For a given envelope amplitude, simulate the final unitary using the
Dyson solver.
"""
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
return dyson_solver.solve(
signals=[drive_signal],
y0=np.eye(dim, dtype=complex),
t0=0.,
n_steps=int(T // dt)
).y[-1]
First run includes compile time.
%time yf_dyson = dyson_sim(1.).block_until_ready()
CPU times: user 647 ms, sys: 29.2 ms, total: 676 ms
Wall time: 660 ms
Once JIT compilation has been performed, we can benchmark the performance of the compiled solver:
%time yf_dyson = dyson_sim(1.).block_until_ready()
CPU times: user 12.3 ms, sys: 0 ns, total: 12.3 ms
Wall time: 4.62 ms
4. Comparison to traditional ODE solver#
We now construct the same simulation using a standard solver to compare accuracy and simulation speed.
from qiskit_dynamics import Solver
solver = Solver(
static_hamiltonian=static_hamiltonian,
hamiltonian_operators=[drive_hamiltonian],
rotating_frame=static_hamiltonian
)
# specify tolerance as an argument to run the simulation at different tolerances
def ode_sim(amp, tol):
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
res = solver.solve(
t_span=[0., int(T // dt) * dt],
y0=np.eye(dim, dtype=complex),
signals=[drive_signal],
method='jax_odeint',
atol=tol,
rtol=tol
)
return res.y[-1]
Simulate at a very low tolerance to obtain a high-accuracy reference solution, and compare it to the Dyson result.
yf_low_tol = ode_sim(1., 1e-13)
np.linalg.norm(yf_low_tol - yf_dyson)
6.529550206930476e-07
For speed comparison, compile at a tolerance with similar accuracy.
jit_ode_sim = jit(lambda amp: ode_sim(amp, 1e-8))
%time yf_ode = jit_ode_sim(1.).block_until_ready()
CPU times: user 447 ms, sys: 16.1 ms, total: 463 ms
Wall time: 457 ms
Measure compiled time.
%time yf_ode = jit_ode_sim(1.).block_until_ready()
CPU times: user 46.4 ms, sys: 0 ns, total: 46.4 ms
Wall time: 46.2 ms
Confirm similar accuracy solution.
np.linalg.norm(yf_low_tol - yf_ode)
8.67211035081537e-07
Here we see that, once compiled, the Dyson-based solver has a significant speed advantage over the traditional solver, at the expense of the initial solver construction and compilation time, and the more technical setup required to use it.
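Once compiled, dyson_sim can also be reused across parameter values, so the construction and compilation cost is amortized over a sweep; a small sketch (the amplitude values here are arbitrary):
# reuse the compiled simulation across a sweep of drive amplitudes; compilation
# happened once above, so each call here runs at the compiled speed
amps = np.linspace(0.5, 1.5, 5)
final_unitaries = [dyson_sim(amp) for amp in amps]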
5. How-to construct and simulate using the Magnus-based perturbative solver#
Next, we repeat our example using the Magnus-based perturbative solver. Setup of the
MagnusSolver
is similar to the DysonSolver
, but it uses the Magnus expansion and
matrix exponentiation to simulate over each fixed time step.
%%time
from qiskit_dynamics import MagnusSolver
dt = 0.1
magnus_solver = MagnusSolver(
operators=[-1j * drive_hamiltonian],
rotating_frame=-1j * static_hamiltonian,
dt=dt,
carrier_freqs=[v],
chebyshev_orders=[1],
expansion_order=3,
integration_method='jax_odeint',
atol=1e-12,
rtol=1e-12
)
CPU times: user 1.53 s, sys: 45 ms, total: 1.58 s
Wall time: 1.58 s
Set up the simulation function.
@jit
def magnus_sim(amp):
drive_signal = Signal(lambda t: amp * envelope_func(t), carrier_freq=v)
return magnus_solver.solve(
signals=[drive_signal],
y0=np.eye(dim, dtype=complex),
t0=0.,
n_steps=int(T // dt)
).y[-1]
First run includes compile time.
%time yf_magnus = magnus_sim(1.).block_until_ready()
CPU times: user 1.28 s, sys: 63.7 ms, total: 1.34 s
Wall time: 1.33 s
Second run demonstrates speed of the simulation.
%time yf_magnus = magnus_sim(1.).block_until_ready()
CPU times: user 24.8 ms, sys: 0 ns, total: 24.8 ms
Wall time: 20 ms
np.linalg.norm(yf_magnus - yf_low_tol)
6.678901371612617e-07
Observe comparable accuracy at a lower order in the expansion, albeit with a modest reduction in speed as compared to the Dyson-based solver.
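As a final cross-check using the arrays already computed above, the two perturbative solutions can also be compared to each other directly:
# direct comparison of the Magnus- and Dyson-based final unitaries
print(np.linalg.norm(yf_magnus - yf_dyson))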
References