How to use pulse schedules generated by qiskit-pulse with JAX transformations

Qiskit-pulse lets you specify the time dependence of a quantum system as a pulse schedule, built from a sequence of instructions that can include shaped pulses (see the Qiskit pulse API Reference for details). Starting with qiskit-terra 0.23.0, the ScalableSymbolicPulse class supports JAX. This user guide entry demonstrates how to use this class within JAX-transformable functions.

Note

At present, only the ScalableSymbolicPulse class supports JAX, because the parameter validation performed by other pulse types, such as Gaussian, is not JAX-compatible.
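To see why validation gets in the way, consider a minimal sketch (plain JAX, not Qiskit code). Under `jax.jit`, function arguments are replaced by abstract tracers, so a Python `if` that inspects a parameter's value cannot be evaluated. The functions `validated_scale` and `unvalidated_scale` are hypothetical stand-ins for a pulse whose constructor does, or does not, check its amplitude.

```python
import jax
import jax.numpy as jnp


def validated_scale(amp, samples):
    """Hypothetical stand-in for a pulse constructor that validates `amp`."""
    # Python-level branching on the *value* of amp, similar in spirit to
    # an amplitude condition such as |amp| <= 1.
    if jnp.abs(amp) > 1.0:
        raise ValueError("|amp| must not exceed 1")
    return amp * samples


def unvalidated_scale(amp, samples):
    """Same computation with the value-dependent check removed."""
    return amp * samples


samples = jnp.ones(4)

# Eager execution works: amp is a concrete number here.
print(validated_scale(0.4, samples))

# Under jit, amp is a tracer, so `if` cannot convert it to a Python bool.
try:
    jax.jit(validated_scale)(0.4, samples)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed with", type(err).__name__)

# Without the branch, the function traces and compiles normally.
print(jax.jit(unvalidated_scale)(0.4, samples))
```

This is the same reason section 3 disables validation before constructing the pulse inside a jitted function.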

This guide addresses the following topics. See the user guide on using JAX for a more detailed explanation of how to work with JAX in Qiskit Dynamics.

  1. Configure Dynamics to use JAX.

  2. How to define a Gaussian pulse using ScalableSymbolicPulse.

  3. JAX-transforming the pulse-to-signal conversion involving ScalableSymbolicPulse.

1. Configure Dynamics to use JAX

First, configure Dynamics to use JAX.

# configure jax to use 64 bit mode
import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')

# import Array and set default backend
from qiskit_dynamics.array import Array
Array.set_default_backend('jax')
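A quick way to confirm that 64-bit mode took effect is to inspect default dtypes. This sketch uses only standard `jax.numpy` calls: with `jax_enable_x64` set, real arrays default to `float64` and complex arrays to `complex128`, which is also the dtype of the signal samples returned in section 3.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 5)
z = jnp.exp(1j * x)

print(x.dtype)  # float64
print(z.dtype)  # complex128
```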

2. How to define a Gaussian pulse using ScalableSymbolicPulse

Because the standard Gaussian pulse is not JAX-compatible, defining a Gaussian pulse for use in optimization requires instantiating a ScalableSymbolicPulse with a Gaussian parameterization. First, define the symbolic representation in sympy.

from qiskit import pulse
from qiskit_dynamics.pulse import InstructionToSignals
import sympy as sym

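dt = 0.222  # sample width (duration of one time step), used in section 3
w = 5.      # carrier frequency for drive channel d0, used in section 3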

# Helper that returns a symbolic lifted Gaussian: a Gaussian shifted down so it
# is zero at t_zero, then rescaled so its peak value is 1.
def lifted_gaussian(
    t: sym.Symbol,
    center,
    t_zero,
    sigma,
) -> sym.Expr:
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()

    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)

    return (gauss - offset) / (1 - offset)
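Written out, the helper computes the following, where c is `center`, t_0 is `t_zero`, and sigma is `sigma`:

```latex
g(t) = \exp\!\left(-\frac{(t - c)^2}{2\sigma^2}\right), \qquad
f(t) = \frac{g(t) - g(t_0)}{1 - g(t_0)}
```

Subtracting g(t_0) makes f(t_0) = 0, and dividing by 1 - g(t_0) restores f(c) = 1. With c = duration/2 and t_0 = duration + 1, as in the envelope defined next, f also vanishes at the mirror point t = -1, so the pulse is zero just outside both of its ends. The envelope then scales f by amp * exp(i * angle).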

Next, define the ScalableSymbolicPulse using the above expression.

_t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
_center = _duration / 2

envelope_expr = (
    _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
)

gaussian_pulse = pulse.ScalableSymbolicPulse(
    pulse_type="Gaussian",
    duration=160,
    amp=0.3,
    angle=0,
    parameters={"sigma": 40},
    envelope=envelope_expr,
    constraints=_sigma > 0,
    valid_amp_conditions=sym.Abs(_amp) <= 1.0,
)

gaussian_pulse.draw()
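To connect the symbolic expression to concrete samples, the envelope can be evaluated numerically with `sympy.lambdify`. This is an illustrative sketch, not how Qiskit samples pulses internally: it assumes one sample per `dt`, taken at the midpoint t = k + 1/2 of each step, which is consistent with the values printed in section 3 (there with amp = 0.4). The name `envelope_fn` is hypothetical.

```python
import numpy as np
import sympy as sym


def lifted_gaussian(t, center, t_zero, sigma):
    # Same helper as above: a Gaussian lifted so it vanishes at t_zero.
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()
    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)
    return (gauss - offset) / (1 - offset)


_t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
envelope_expr = (
    _amp * sym.exp(sym.I * _angle)
    * lifted_gaussian(_t, _duration / 2, _duration + 1, _sigma)
)

# Compile the symbolic envelope into a NumPy-vectorized function.
envelope_fn = sym.lambdify(
    (_t, _duration, _amp, _sigma, _angle), envelope_expr, modules="numpy"
)

duration = 160
# Assumption: sample at the midpoint of each time step.
times = np.arange(duration) + 0.5
samples = envelope_fn(times, duration, 0.3, 40, 0)

print(samples.shape)                        # (160,)
print(np.allclose(samples, samples[::-1]))  # True: symmetric about duration / 2
```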

3. JAX-transforming the pulse-to-signal conversion involving ScalableSymbolicPulse

Using a Gaussian pulse as an example, we show that a function involving ScalableSymbolicPulse and the pulse-to-signal converter can be JAX-compiled (or, more generally, JAX-transformed).

# use amplitude as the function argument
def jit_func(amp):
    _t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
    _center = _duration / 2

    envelope_expr = (
        _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
    )

    # Disable validation so the pulse can be built inside a JAX-jitted function.
    # This is a class attribute, so it stays in effect for every
    # ScalableSymbolicPulse created afterwards.
    pulse.ScalableSymbolicPulse.disable_validation = True

    gaussian_pulse = pulse.ScalableSymbolicPulse(
        pulse_type="Gaussian",
        duration=160,
        amp=amp,
        angle=0,
        parameters={"sigma": 40},
        envelope=envelope_expr,
        constraints=_sigma > 0,
        valid_amp_conditions=sym.Abs(_amp) <= 1.0,
    )

    # build a pulse schedule
    with pulse.build() as schedule:
        pulse.play(gaussian_pulse, pulse.DriveChannel(0))

    # convert from a pulse schedule to a list of signals
    converter = InstructionToSignals(dt, carriers={"d0": w})

    return converter.get_signals(schedule)[0].samples.data

jax.jit(jit_func)(0.4)
Array([0.00461643+0.j, 0.00784044+0.j, 0.01118371+0.j, 0.0146479 +0.j,
       0.01823455+0.j, 0.02194501+0.j, 0.02578049+0.j, 0.029742  +0.j,
       0.03383034+0.j, 0.03804615+0.j, 0.0423898 +0.j, 0.04686147+0.j,
       0.05146109+0.j, 0.05618834+0.j, 0.06104264+0.j, 0.06602316+0.j,
       0.07112877+0.j, 0.07635807+0.j, 0.08170936+0.j, 0.08718063+0.j,
       0.0927696 +0.j, 0.09847362+0.j, 0.10428977+0.j, 0.11021477+0.j,
       0.11624505+0.j, 0.12237668+0.j, 0.12860541+0.j, 0.13492665+0.j,
       0.14133549+0.j, 0.14782668+0.j, 0.15439464+0.j, 0.16103348+0.j,
       0.16773697+0.j, 0.17449859+0.j, 0.18131147+0.j, 0.1881685 +0.j,
       0.19506222+0.j, 0.20198494+0.j, 0.20892866+0.j, 0.21588517+0.j,
       0.22284598+0.j, 0.22980239+0.j, 0.2367455 +0.j, 0.24366621+0.j,
       0.25055524+0.j, 0.25740317+0.j, 0.26420043+0.j, 0.27093735+0.j,
       0.27760417+0.j, 0.28419106+0.j, 0.29068813+0.j, 0.29708551+0.j,
       0.30337328+0.j, 0.3095416 +0.j, 0.31558066+0.j, 0.32148073+0.j,
       0.32723219+0.j, 0.33282555+0.j, 0.33825149+0.j, 0.34350085+0.j,
       0.34856471+0.j, 0.35343437+0.j, 0.35810137+0.j, 0.36255757+0.j,
       0.36679511+0.j, 0.37080648+0.j, 0.3745845 +0.j, 0.37812239+0.j,
       0.38141374+0.j, 0.38445258+0.j, 0.38723335+0.j, 0.38975094+0.j,
       0.39200072+0.j, 0.39397853+0.j, 0.39568069+0.j, 0.39710405+0.j,
       0.39824594+0.j, 0.39910423+0.j, 0.39967732+0.j, 0.39996414+0.j,
       0.39996414+0.j, 0.39967732+0.j, 0.39910423+0.j, 0.39824594+0.j,
       0.39710405+0.j, 0.39568069+0.j, 0.39397853+0.j, 0.39200072+0.j,
       0.38975094+0.j, 0.38723335+0.j, 0.38445258+0.j, 0.38141374+0.j,
       0.37812239+0.j, 0.3745845 +0.j, 0.37080648+0.j, 0.36679511+0.j,
       0.36255757+0.j, 0.35810137+0.j, 0.35343437+0.j, 0.34856471+0.j,
       0.34350085+0.j, 0.33825149+0.j, 0.33282555+0.j, 0.32723219+0.j,
       0.32148073+0.j, 0.31558066+0.j, 0.3095416 +0.j, 0.30337328+0.j,
       0.29708551+0.j, 0.29068813+0.j, 0.28419106+0.j, 0.27760417+0.j,
       0.27093735+0.j, 0.26420043+0.j, 0.25740317+0.j, 0.25055524+0.j,
       0.24366621+0.j, 0.2367455 +0.j, 0.22980239+0.j, 0.22284598+0.j,
       0.21588517+0.j, 0.20892866+0.j, 0.20198494+0.j, 0.19506222+0.j,
       0.1881685 +0.j, 0.18131147+0.j, 0.17449859+0.j, 0.16773697+0.j,
       0.16103348+0.j, 0.15439464+0.j, 0.14782668+0.j, 0.14133549+0.j,
       0.13492665+0.j, 0.12860541+0.j, 0.12237668+0.j, 0.11624505+0.j,
       0.11021477+0.j, 0.10428977+0.j, 0.09847362+0.j, 0.0927696 +0.j,
       0.08718063+0.j, 0.08170936+0.j, 0.07635807+0.j, 0.07112877+0.j,
       0.06602316+0.j, 0.06104264+0.j, 0.05618834+0.j, 0.05146109+0.j,
       0.04686147+0.j, 0.0423898 +0.j, 0.03804615+0.j, 0.03383034+0.j,
       0.029742  +0.j, 0.02578049+0.j, 0.02194501+0.j, 0.01823455+0.j,
       0.0146479 +0.j, 0.01118371+0.j, 0.00784044+0.j, 0.00461643+0.j],      dtype=complex128)
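The same envelope can also be written directly in `jax.numpy`, which makes it easy to cross-check the output above and to apply other transformations such as `jax.grad` and `jax.vmap`. The function `gaussian_samples` is a hypothetical re-implementation for illustration only, not part of Qiskit or Qiskit Dynamics, and it assumes midpoint sampling at t = k + 1/2, which is consistent with the printed samples.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def gaussian_samples(amp, duration=160, sigma=40.0, angle=0.0):
    """Hypothetical pure-JAX version of the lifted Gaussian envelope."""
    times = jnp.arange(duration) + 0.5  # assumed midpoint sampling
    center = duration / 2
    t_zero = duration + 1
    gauss = jnp.exp(-0.5 * ((times - center) / sigma) ** 2)
    offset = jnp.exp(-0.5 * ((t_zero - center) / sigma) ** 2)
    lifted = (gauss - offset) / (1 - offset)
    return amp * jnp.exp(1j * angle) * lifted


# jit: compile for a traced amplitude, as jit_func does above.
samples = jax.jit(gaussian_samples)(0.4)
print(samples[0])  # close to the first entry printed above


# grad needs a real scalar output: here, the summed real part of the samples.
def area(amp):
    return jnp.sum(gaussian_samples(amp).real)


print(jax.grad(area)(0.4))

# vmap: evaluate several amplitudes in one call.
batch = jax.vmap(gaussian_samples)(jnp.array([0.1, 0.2, 0.4]))
print(batch.shape)  # (3, 160)
```

Because the envelope is linear in amp, the gradient of `area` at any amplitude equals `area(1.0)`, which gives a simple sanity check on the differentiation.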