Gradient optimization of a pulse sequence

Here we walk through an example of optimizing a single-qubit gate using Qiskit Dynamics. This tutorial requires JAX; see the user guide entry How-to use JAX with qiskit-dynamics.

We will optimize an X-gate on a model of a qubit system using the following steps:

  1. Configure JAX.

  2. Set up a Solver instance with the model of the system.

  3. Define a pulse sequence parameterization to optimize over.

  4. Define a gate fidelity function.

  5. Define an objective function for optimization.

  6. Use JAX to differentiate the objective, then do the gradient optimization.

  7. Repeat the X-gate optimization, alternatively using pulse schedules to specify the control sequence.

1. Configure JAX

First, set JAX to operate in 64-bit mode and to run on CPU.

import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')

import jax.numpy as jnp

2. Set up the solver

Here we set up a Solver with a simple model of a qubit. The Hamiltonian is:

$$H(t) = 2\pi\nu\frac{Z}{2} + 2\pi r\,s(t)\frac{X}{2}$$

In the above:

  • ν is the qubit frequency,

  • r is the drive strength,

  • s(t) is the drive signal which we will optimize, and

  • X and Z are the Pauli X and Z operators.

We set up the problem in the rotating frame of the drift term.

import numpy as np
from qiskit.quantum_info import Operator
from qiskit_dynamics import Solver

v = 5.     # qubit frequency ν
r = 0.02   # drive strength

static_hamiltonian = 2 * np.pi * v * Operator.from_label('Z') / 2
drive_term = 2 * np.pi * r * Operator.from_label('X') / 2

ham_solver = Solver(
    hamiltonian_operators=[drive_term],
    static_hamiltonian=static_hamiltonian,
    rotating_frame=static_hamiltonian,
)
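For intuition about the rotating frame, the NumPy sketch below applies the standard frame transformation $e^{iH_0 t}(H(t) - H_0)e^{-iH_0 t}$, with $H_0$ the drift, to this model. We assume this matches the convention Solver uses internally; the helper names are hypothetical and the code does not call qiskit_dynamics. The drift cancels, and the drive operator X picks up oscillations at the qubit frequency, becoming cos(2πνt)X − sin(2πνt)Y.

```python
import numpy as np

# Pauli matrices and the model parameters used above (v = 5, r = 0.02)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
nu, r = 5.0, 0.02

def frame_unitary(t):
    # exp(i H0 t) for the diagonal drift H0 = 2*pi*nu*Z/2, exponentiated entrywise
    return np.diag(np.exp(1j * 2 * np.pi * nu * t * np.diag(Z) / 2))

def rotating_frame_hamiltonian(t, s):
    # exp(i H0 t) (H(t) - H0) exp(-i H0 t): the drift cancels, leaving the
    # drive term 2*pi*r*s*X/2 conjugated by the frame unitary
    U = frame_unitary(t)
    drive = 2 * np.pi * r * s * X / 2
    return U @ drive @ U.conj().T

t = 0.037
phi = 2 * np.pi * nu * t
# in the frame, X rotates into cos(2*pi*nu*t) X - sin(2*pi*nu*t) Y
expected = 2 * np.pi * r / 2 * (np.cos(phi) * X - np.sin(phi) * Y)
print(np.allclose(rotating_frame_hamiltonian(t, 1.0), expected))  # True
```

Multiplying by an on-resonance signal s(t) ∝ cos(2πνt) therefore leaves a slowly varying X component plus terms oscillating at 2ν, which is why the carrier is later set to the qubit frequency.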

3. Define a pulse sequence parameterization to optimize over

We will optimize over signals that are:

  • On resonance with piecewise constant envelopes,

  • Envelopes bounded between [−1, 1],

  • Envelopes are smooth, in the sense that the change between adjacent samples is small, and

  • Envelope starts and ends at 0.

In setting up our parameterization, we need to keep in mind that we will use the BFGS optimization routine, and hence:

  • Optimization parameters must be unconstrained.

  • Parameterization must be JAX-differentiable.

We implement a parameterization as follows:

  • Input: Array x of real values.

  • “Normalize” x by applying a JAX-differentiable function from ℝ into [−1, 1].

  • Prepend a 0 to the normalized x.

  • “Smoothen” the above via convolution.

  • Construct the signal using the above as the samples for a piecewise-constant envelope, with carrier frequency on resonance.

We remark that many other parameterizations can achieve the same ends, and some may enforce a value of 0 at the beginning and end of the pulse more efficiently. This one is meant only to illustrate why such an approach is needed, with one simple example.

from qiskit_dynamics import DiscreteSignal
from qiskit_dynamics.signals import Convolution

import jax.numpy as jnp

# define convolution filter
def gaus(t):
    sigma = 15
    _dt = 0.1
    return 2.*_dt/np.sqrt(2.*np.pi*sigma**2)*np.exp(-t**2/(2*sigma**2))

convolution = Convolution(gaus)

# define function mapping parameters to signals
def signal_mapping(params):

    # map samples into [-1, 1]
    bounded_samples = jnp.arctan(params) / (np.pi / 2)

    # pad with 0 at beginning
    padded_samples = jnp.append(jnp.array([0], dtype=complex), bounded_samples)

    # apply filter
    output_signal = convolution(DiscreteSignal(dt=1., samples=padded_samples))

    # set carrier frequency to v
    output_signal.carrier_freq = v

    return output_signal
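The same pipeline (bound, pad, smooth) can be illustrated in plain NumPy. This is an analog, not the qiskit_dynamics Convolution class: the helpers are hypothetical, and the unit-sum Gaussian kernel (σ = 3 samples) and NumPy's 'full' convolution mode are choices made only for this sketch.

```python
import numpy as np

def bounded_samples(params):
    # arctan maps the real line onto (-pi/2, pi/2), so the samples land in (-1, 1)
    return np.arctan(params) / (np.pi / 2)

def smoothed_envelope(params, sigma=3.0):
    # prepend a 0 so that the envelope starts at rest
    padded = np.concatenate([[0.0], bounded_samples(params)])
    # Gaussian kernel sampled at integer offsets and normalized to unit sum
    half = int(4 * sigma)
    k = np.arange(-half, half + 1)
    kernel = np.exp(-k**2 / (2 * sigma**2))
    kernel /= kernel.sum()
    # 'full' mode extends the output on both sides, so it ramps up from the
    # leading 0 and decays back towards 0 after the last sample
    return np.convolve(padded, kernel, mode="full")

params = np.ones(80) * 1e8
env = smoothed_envelope(params)
```

Even with parameters as large as 10⁸, the samples stay strictly inside (−1, 1), the envelope starts at exactly 0, and the jump from 0 to nearly 1 is spread over several samples by the kernel.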

Observe, for example, the signal generated when all parameters are 10⁸:

signal = signal_mapping(np.ones(80) * 1e8)
signal.draw(t0=0., tf=signal.duration * signal.dt, n=1000, function='envelope')
[Figure: envelope of the signal generated when all parameters are 10⁸]

4. Define gate fidelity

We will optimize an X gate, and define the fidelity of the unitary U implemented by the pulse via the standard fidelity measure:

$$f(U) = \frac{|\mathrm{Tr}(XU)|^2}{4}$$

X_op = Operator.from_label('X').data

def fidelity(U):
    # the sum of the elementwise product equals Tr(X^T U) = Tr(XU), since X is symmetric
    return jnp.abs(jnp.sum(X_op * U))**2 / 4.
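A quick NumPy check of this fidelity (NumPy stands in for jax.numpy here, and `rx` is a hypothetical helper for x-rotations): it equals 1 for X, it ignores global phase, so R_x(π) = −iX also scores 1, and it gives sin²(θ/2) for an x-rotation by θ.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)

def fidelity(U):
    # sum of the elementwise product equals Tr(X^T U) = Tr(X U), since X is symmetric
    return np.abs(np.sum(X * U)) ** 2 / 4.0

def rx(theta):
    # x-rotation exp(-i theta X / 2) = cos(theta/2) I - i sin(theta/2) X
    return np.cos(theta / 2) * np.eye(2) - 1j * np.sin(theta / 2) * X

print(fidelity(X), fidelity(np.eye(2)), fidelity(rx(np.pi)))  # 1.0 0.0 1.0
```

Phase insensitivity matters here: the simulated unitary only needs to match X up to a global phase.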

5. Define the objective function

The function we want to optimize consists of:

  • Taking a list of input samples and applying the signal mapping.

  • Simulating the Schrödinger equation over the length of the pulse sequence.

  • Computing and returning the infidelity (we minimize 1 − f(U)).

def objective(params):

    # apply signal mapping and set signals
    signal = signal_mapping(params)

    # Simulate
    results = ham_solver.solve(
        y0=np.eye(2, dtype=complex),
        t_span=[0, signal.duration * signal.dt],
        signals=[signal],
        method='jax_odeint',
        atol=1e-8,
        rtol=1e-8
    )
    U = results.y[-1]

    # compute and return infidelity
    fid = fidelity(U)
    return 1. - fid

6. Perform JAX transformations and optimize

Finally, we gradient optimize the objective:

  • Use jax.value_and_grad to transform the objective into a function that computes both the objective and the gradient.

  • Use jax.jit to just-in-time compile the function into optimized XLA code. Compilation has a one-time cost, after which each call is faster, which speeds up the optimization.

  • Call scipy.optimize.minimize with the above, with method='BFGS' and jac=True to indicate that the passed objective also computes the gradient.

from jax import jit, value_and_grad
from scipy.optimize import minimize

jit_grad_obj = jit(value_and_grad(objective))

initial_guess = np.random.rand(80) - 0.5

opt_results = minimize(fun=jit_grad_obj, x0=initial_guess, jac=True, method='BFGS')
print(opt_results.message)
print('Number of function evaluations: ' + str(opt_results.nfev))
print('Function value: ' + str(opt_results.fun))
Optimization terminated successfully.
Number of function evaluations: 12
Function value: -6.07233219263037e-08

The optimized pulse implements an X gate to within the numerical accuracy of the solver; the slightly negative infidelity reported above is a numerical artifact of that accuracy.
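The pattern used above, a single function returning (value, gradient) passed to minimize with jac=True and method='BFGS', can be tried without JAX or an ODE solver on a hypothetical toy objective. Here the simulation is replaced by the rotating wave approximation (RWA) closed form, in which the pulse is an x-rotation by θ = πr Σ f_k dt, and the gradient is written out by hand via the chain rule; padding and smoothing are omitted. This is a sketch under those assumptions, not the tutorial's objective.

```python
import numpy as np
from scipy.optimize import minimize

r, dt = 0.02, 1.0  # drive strength and sample width used in the sections above

def rwa_value_and_grad(x):
    # hypothetical toy objective: the ODE solve is replaced by the RWA closed form
    f = np.arctan(x) / (np.pi / 2)          # bounded envelope samples, as in signal_mapping
    theta = np.pi * r * np.sum(f) * dt      # RWA rotation angle about x
    value = np.cos(theta / 2) ** 2          # infidelity 1 - sin^2(theta/2)
    # chain rule: d value / d theta = -sin(theta)/2, d theta / d x_k = 2 r dt / (1 + x_k^2)
    grad = -np.sin(theta) / 2 * (2 * r * dt / (1 + x**2))
    return value, grad

# jac=True tells minimize that the function returns (value, gradient)
res = minimize(rwa_value_and_grad, x0=np.full(80, 0.5), jac=True,
               method="BFGS", options={"gtol": 1e-8})
opt_samples = np.arctan(res.x) / (np.pi / 2)
print(res.fun, opt_samples.sum())
```

In the tutorial, jax.value_and_grad produces the gradient automatically by differentiating through the simulation, so no hand-derived chain rule is needed.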

We can draw the optimized signal, which is retrieved by applying the signal_mapping to the optimized parameters.

opt_signal = signal_mapping(opt_results.x)

opt_signal.draw(
    t0=0,
    tf=opt_signal.duration * opt_signal.dt,
    n=1000,
    function='envelope',
    title='Optimized envelope'
)
[Figure: Optimized envelope]

Summing the signal samples yields approximately ±50 (the sign depends on the run), as expected from a rotating wave approximation (RWA) analysis.

opt_signal.samples.sum()
Array(50.00040529, dtype=float64)
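The ±50 value can be checked with a short calculation. Assume, as in the usual qiskit_dynamics signal convention, that the drive is $s(t) = \mathrm{Re}[f(t)e^{i2\pi\nu t}]$ with a real envelope $f$ on resonance. In the rotating frame of the drift, $X$ becomes $\cos(2\pi\nu t)X - \sin(2\pi\nu t)Y$, and the RWA drops the terms oscillating at $2\nu$:

$$
\begin{aligned}
\tilde{H}(t) &= 2\pi r\,f(t)\cos(2\pi\nu t)\,\frac{\cos(2\pi\nu t)X - \sin(2\pi\nu t)Y}{2} \\
&= \frac{\pi r f(t)}{2}\Big[X + \cos(4\pi\nu t)X - \sin(4\pi\nu t)Y\Big] \approx \frac{\pi r f(t)}{2}X, \\
U &\approx \exp\!\Big(-\frac{i}{2}\,\pi r\!\int_0^T f(t)\,dt\;X\Big) = e^{-i\theta X/2}, \qquad \theta = \pi r\int_0^T f(t)\,dt.
\end{aligned}
$$

Since every term of the RWA Hamiltonian is proportional to $X$, the exponential of the integral is exact within the approximation. The rotation $e^{-i\theta X/2}$ equals $X$ up to a global phase when $\theta = \pm\pi$, so $\int_0^T f(t)\,dt = \pm 1/r = \pm 50$ for $r = 0.02$; with unit sample width, this integral is the sum of the samples.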

7. Repeat the X-gate optimization, alternatively using pulse schedules to specify the control sequence

Here we perform the optimization again, but now we specify the parameterized control sequence as a pulse schedule.

We construct a Gaussian square pulse as a ScalableSymbolicPulse instance, parameterized by sigma and width. Although Qiskit Pulse provides a GaussianSquare class, it is not JAX-compatible. See the user guide entry on JAX-compatible pulse schedules.

import sympy as sym
from qiskit import pulse

def lifted_gaussian(
    t: sym.Symbol,
    center,
    t_zero,
    sigma,
) -> sym.Expr:
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()

    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)

    return (gauss - offset) / (1 - offset)

def gaussian_square_generated_by_pulse(params):

    sigma, width = params
    _t, _duration, _amp, _sigma, _width, _angle = sym.symbols(
        "t, duration, amp, sigma, width, angle"
    )
    _center = _duration / 2

    _sq_t0 = _center - _width / 2
    _sq_t1 = _center + _width / 2

    _gaussian_ledge = lifted_gaussian(_t, _sq_t0, -1, _sigma)
    _gaussian_redge = lifted_gaussian(_t, _sq_t1, _duration + 1, _sigma)

    envelope_expr = (
        _amp
        * sym.exp(sym.I * _angle)
        * sym.Piecewise(
            (_gaussian_ledge, _t <= _sq_t0), (_gaussian_redge, _t >= _sq_t1), (1, True)
        )
    )

    # we need to set disable_validation True to enable jax-jitting.
    pulse.ScalableSymbolicPulse.disable_validation = True

    return pulse.ScalableSymbolicPulse(
            pulse_type="GaussianSquare",
            duration=230,
            amp=1,
            angle=0,
            parameters={"sigma": sigma, "width": width},
            envelope=envelope_expr,
            constraints=sym.And(_sigma > 0, _width >= 0, _duration >= _width),
            valid_amp_conditions=sym.Abs(_amp) <= 1.0,
        )
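To see what the symbolic envelope evaluates to, here is a NumPy transcription of the same piecewise formula. The helper names are hypothetical, and the sketch evaluates the expression directly rather than going through Qiskit's sampling. The lifted Gaussian edges are shifted and rescaled so they reach exactly 0 at t = −1 and t = duration + 1, and exactly 1 where they meet the flat top.

```python
import numpy as np

def lifted_gaussian_np(t, center, t_zero, sigma):
    # Gaussian shifted down so it vanishes at t_zero and rescaled to equal 1 at center
    gauss = np.exp(-(((t - center) / sigma) ** 2) / 2)
    offset = np.exp(-(((t_zero - center) / sigma) ** 2) / 2)
    return (gauss - offset) / (1 - offset)

def gaussian_square_np(t, duration, sigma, width, amp=1.0, angle=0.0):
    t = np.asarray(t, dtype=float)
    center = duration / 2
    sq_t0, sq_t1 = center - width / 2, center + width / 2
    left = lifted_gaussian_np(t, sq_t0, -1.0, sigma)
    right = lifted_gaussian_np(t, sq_t1, duration + 1.0, sigma)
    # same branch order as the sympy Piecewise: left edge, right edge, flat top
    shape = np.where(t <= sq_t0, left, np.where(t >= sq_t1, right, 1.0))
    return amp * np.exp(1j * angle) * shape

# the initial parameters used below: sigma = 10, width = 10, duration = 230
t = np.linspace(0, 230, 2301)
initial = gaussian_square_np(t, duration=230, sigma=10.0, width=10.0)
```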

Next, we construct a pulse schedule from the parameterized Gaussian square pulse above, convert it to a signal, and simulate the Schrödinger equation over the length of the pulse sequence.

from qiskit_dynamics.pulse import InstructionToSignals

dt = 0.222
w = 5.

def objective(params):

    instance = gaussian_square_generated_by_pulse(params)

    with pulse.build() as Xp:
        pulse.play(instance, pulse.DriveChannel(0))

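    # convert the schedule into signals, with carrier frequency w on channel d0
    converter = InstructionToSignals(dt, carriers={"d0": w})
    signal = converter.get_signals(Xp)

    # get_signals returns a list of signals; wrapping it in another list asks solve
    # for a batch containing a single simulation, so solve returns a list of results
    result = ham_solver.solve(
        y0=np.eye(2, dtype=complex),
        t_span=[0, instance.duration * dt],
        signals=[signal],
        method='jax_odeint',
        atol=1e-8,
        rtol=1e-8
    )
    return 1. - fidelity(result[0].y[-1])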
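As a rough mental model of what InstructionToSignals produces, one can picture each played sample as held constant for one dt and multiplied by the carrier. The NumPy helper below is a hypothetical sketch of that picture, assuming the convention s(t) = Re[f(t)e^{i2πwt}] and a zero envelope outside the schedule; it is not the converter's implementation. With 230 samples and dt = 0.222, the pulse spans 230 × 0.222 = 51.06 time units, which is the t_span passed to solve.

```python
import numpy as np

def sampled_signal(samples, dt, carrier_freq, t):
    # hypothetical helper: sample k is held constant on [k*dt, (k+1)*dt)
    samples = np.asarray(samples, dtype=complex)
    t = np.asarray(t, dtype=float)
    idx = np.floor(t / dt).astype(int)
    inside = (idx >= 0) & (idx < len(samples))
    # zero envelope outside the schedule (an assumption of this sketch)
    envelope = np.where(inside, samples[np.clip(idx, 0, len(samples) - 1)], 0.0)
    # assumed convention: s(t) = Re[f(t) * exp(i 2 pi w t)]
    return np.real(envelope * np.exp(1j * 2 * np.pi * carrier_freq * t))

# two samples of width 0.222 on a carrier of 5, evaluated at a few times
values = sampled_signal([1.0, 0.5], 0.222, 5.0, np.array([0.0, 0.3, 0.5]))
```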

We set the initial values of sigma and width for the optimization as initial_params = np.array([10, 10]).

initial_params = np.array([10, 10])
gaussian_square_generated_by_pulse(initial_params).draw()
[Figure: initial Gaussian square pulse (sigma = 10, width = 10)]
from jax import jit, value_and_grad
from scipy.optimize import minimize

jit_grad_obj = jit(value_and_grad(objective))

opt_results = minimize(fun=jit_grad_obj, x0=initial_params, jac=True, method='BFGS')

print(opt_results.message)
print(f"Optimized Sigma is {opt_results.x[0]} and Width is {opt_results.x[1]}")
print('Number of function evaluations: ' + str(opt_results.nfev))
print('Function value: ' + str(opt_results.fun))
Optimization terminated successfully.
Optimized Sigma is 516.3449186228395 and Width is 212.18189863023863
Number of function evaluations: 14
Function value: 1.7537254681787573e-07

We can draw the optimized pulse, whose parameters are stored in opt_results.x.

gaussian_square_generated_by_pulse(opt_results.x).draw()
[Figure: optimized Gaussian square pulse]