How-to use pulse schedules generated by qiskit-pulse with JAX transformations
Qiskit-pulse enables the specification of time dependence in quantum systems as pulse schedules. These schedules are built from sequences of instructions of various kinds, including the specification of shaped pulses (see the Qiskit pulse API Reference for detailed API information). As of qiskit-terra 0.23.0, JAX support was added for the ScalableSymbolicPulse class. This user guide entry demonstrates the technical elements of using this class within JAX-transformable functions.
Note
At present, only the ScalableSymbolicPulse class is supported by JAX, because the validation present in other pulse types, such as Gaussian, is not JAX-compatible.
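The Note does not say exactly which part of the validation breaks. A common way eager checks fail under `jax.jit` is Python control flow that depends on a traced value. The sketch below illustrates that general JAX behaviour with a hypothetical `validated_scale` function. It is not qiskit code and only needs `jax`.

```python
import jax
import jax.numpy as jnp


def validated_scale(amp):
    # Hypothetical eager check, similar in spirit to a pulse amplitude validation.
    # A Python `if` needs a concrete bool, which a traced value under jit cannot give.
    if jnp.abs(amp) > 1.0:
        raise ValueError("amplitude must satisfy |amp| <= 1")
    return amp * jnp.ones(3)


def unvalidated_scale(amp):
    # Same computation without the data-dependent Python branch.
    return amp * jnp.ones(3)


eager_result = validated_scale(0.4)  # works: amp is a concrete float here

jit_failed = False
try:
    jax.jit(validated_scale)(0.4)
except jax.errors.ConcretizationTypeError:
    # Raised (possibly as a subclass) when a tracer is converted to a Python bool.
    jit_failed = True

jit_result = jax.jit(unvalidated_scale)(0.4)
print(eager_result, jit_failed, jit_result)
```

Removing the branch is what makes the function traceable. Setting `disable_validation` in section 3 plays a comparable role for the pulse object.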
This guide addresses the following topics. See the user guide on using JAX for a more detailed explanation of how to work with JAX in Qiskit Dynamics.
1. Configure to use JAX.
2. How to define a Gaussian pulse using ScalableSymbolicPulse.
3. JAX transforming Pulse to Signal conversion involving ScalableSymbolicPulse.
1. Configure to use JAX
First, configure Dynamics to use JAX.
# configure jax to use 64 bit mode
import jax
jax.config.update("jax_enable_x64", True)
# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')
# import Array and set default backend
from qiskit_dynamics.array import Array
Array.set_default_backend('jax')
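One way to confirm that the 64-bit setting has taken effect is to inspect array dtypes. The sketch below is a minimal check that uses only `jax` and does not touch `qiskit_dynamics`. Per the JAX documentation, the flag is best set at startup, before any arrays are created.

```python
import jax
import jax.numpy as jnp

# Same setting as above; set it before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.array(0.4)
z = x * jnp.exp(1j * x)

print(x.dtype)  # float64 (it would be float32 with x64 left disabled)
print(z.dtype)  # complex128, matching the dtype of the signal samples later on
```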
2. How to define a Gaussian pulse using ScalableSymbolicPulse
The standard Gaussian pulse is not JAX-compatible. To define a Gaussian pulse for use in optimization, we therefore instantiate a ScalableSymbolicPulse with a Gaussian parameterization. First, define the symbolic representation in sympy.
from qiskit import pulse
from qiskit_dynamics.pulse import InstructionToSignals
import sympy as sym

dt = 0.222  # sample width passed to the pulse-to-signal converter
w = 5.      # carrier frequency for drive channel "d0"

# Helper function that returns a lifted Gaussian symbolic equation.
def lifted_gaussian(
    t: sym.Symbol,
    center,
    t_zero,
    sigma,
) -> sym.Expr:
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()
    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)
    # shift and rescale so the result is 1 at `center` and 0 at `t_zero`
    return (gauss - offset) / (1 - offset)
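Written out with c = center and t0 = t_zero, the helper returns f(t) below. Subtracting g(t0) and dividing by 1 - g(t0) shifts and rescales the Gaussian. Because g depends only on (t - c)^2, the result keeps its peak value of 1 and is exactly 0 at t0 and at its mirror point 2c - t0.

```latex
g(t) = \exp\!\left(-\frac{1}{2}\left(\frac{t - c}{\sigma}\right)^{2}\right),
\qquad
f(t) = \frac{g(t) - g(t_0)}{1 - g(t_0)},
\qquad
f(c) = 1,
\qquad
f(t_0) = f(2c - t_0) = 0.
```

With c = duration / 2 and t0 = duration + 1, as chosen when the pulse is defined next, the two zeros fall at t = -1 and t = duration + 1. That is one time unit outside each end of the pulse.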
Next, define the ScalableSymbolicPulse using the above expression.
_t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
_center = _duration / 2
envelope_expr = (
    _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
)
gaussian_pulse = pulse.ScalableSymbolicPulse(
    pulse_type="Gaussian",
    duration=160,
    amp=0.3,
    angle=0,
    parameters={"sigma": 40},
    envelope=envelope_expr,
    constraints=_sigma > 0,
    valid_amp_conditions=sym.Abs(_amp) <= 1.0,
)
gaussian_pulse.draw()
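A quick sanity check on `envelope_expr` is to substitute the parameter values used for `gaussian_pulse` and evaluate the expression at a few times. The sketch below repeats the helper so that it runs on its own with just `sympy`. The evaluation times are chosen for illustration: the center, `duration + 1`, and its mirror point `-1`.

```python
import sympy as sym


def lifted_gaussian(t, center, t_zero, sigma):
    # Same helper as in the guide: a Gaussian rescaled to be 0 at t_zero.
    t_shifted = (t - center).expand()
    t_offset = (t_zero - center).expand()
    gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
    offset = sym.exp(-((t_offset / sigma) ** 2) / 2)
    return (gauss - offset) / (1 - offset)


_t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
envelope_expr = _amp * sym.exp(sym.I * _angle) * lifted_gaussian(
    _t, _duration / 2, _duration + 1, _sigma
)

# Parameter values matching the gaussian_pulse instance above.
env = envelope_expr.subs({_duration: 160, _amp: 0.3, _sigma: 40, _angle: 0})

peak = complex(env.subs(_t, 80).evalf())         # center: full amplitude
right_zero = complex(env.subs(_t, 161).evalf())  # t = duration + 1
left_zero = complex(env.subs(_t, -1).evalf())    # mirror point of duration + 1

print(peak, right_zero, left_zero)
```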
3. JAX transforming Pulse to Signal conversion involving ScalableSymbolicPulse
Using a Gaussian pulse as an example, we show that a function involving ScalableSymbolicPulse and the pulse-to-signal converter can be JAX-compiled (or, more generally, JAX-transformed).
# use amplitude as the function argument
def jit_func(amp):
    _t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
    _center = _duration / 2
    envelope_expr = (
        _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
    )

    # validation must be disabled to enable JAX jitting; this sets a class
    # attribute, so it stays in effect for all later ScalableSymbolicPulse instances
    pulse.ScalableSymbolicPulse.disable_validation = True
    gaussian_pulse = pulse.ScalableSymbolicPulse(
        pulse_type="Gaussian",
        duration=160,
        amp=amp,
        angle=0,
        parameters={"sigma": 40},
        envelope=envelope_expr,
        constraints=_sigma > 0,
        valid_amp_conditions=sym.Abs(_amp) <= 1.0,
    )

    # build a pulse schedule
    with pulse.build() as schedule:
        pulse.play(gaussian_pulse, pulse.DriveChannel(0))

    # convert from a pulse schedule to a list of signals;
    # `.samples` is a qiskit_dynamics Array and `.data` is the underlying JAX array
    converter = InstructionToSignals(dt, carriers={"d0": w})
    return converter.get_signals(schedule)[0].samples.data

jax.jit(jit_func)(0.4)
Array([0.00461643+0.j, 0.00784044+0.j, 0.01118371+0.j, 0.0146479 +0.j,
0.01823455+0.j, 0.02194501+0.j, 0.02578049+0.j, 0.029742 +0.j,
0.03383034+0.j, 0.03804615+0.j, 0.0423898 +0.j, 0.04686147+0.j,
0.05146109+0.j, 0.05618834+0.j, 0.06104264+0.j, 0.06602316+0.j,
0.07112877+0.j, 0.07635807+0.j, 0.08170936+0.j, 0.08718063+0.j,
0.0927696 +0.j, 0.09847362+0.j, 0.10428977+0.j, 0.11021477+0.j,
0.11624505+0.j, 0.12237668+0.j, 0.12860541+0.j, 0.13492665+0.j,
0.14133549+0.j, 0.14782668+0.j, 0.15439464+0.j, 0.16103348+0.j,
0.16773697+0.j, 0.17449859+0.j, 0.18131147+0.j, 0.1881685 +0.j,
0.19506222+0.j, 0.20198494+0.j, 0.20892866+0.j, 0.21588517+0.j,
0.22284598+0.j, 0.22980239+0.j, 0.2367455 +0.j, 0.24366621+0.j,
0.25055524+0.j, 0.25740317+0.j, 0.26420043+0.j, 0.27093735+0.j,
0.27760417+0.j, 0.28419106+0.j, 0.29068813+0.j, 0.29708551+0.j,
0.30337328+0.j, 0.3095416 +0.j, 0.31558066+0.j, 0.32148073+0.j,
0.32723219+0.j, 0.33282555+0.j, 0.33825149+0.j, 0.34350085+0.j,
0.34856471+0.j, 0.35343437+0.j, 0.35810137+0.j, 0.36255757+0.j,
0.36679511+0.j, 0.37080648+0.j, 0.3745845 +0.j, 0.37812239+0.j,
0.38141374+0.j, 0.38445258+0.j, 0.38723335+0.j, 0.38975094+0.j,
0.39200072+0.j, 0.39397853+0.j, 0.39568069+0.j, 0.39710405+0.j,
0.39824594+0.j, 0.39910423+0.j, 0.39967732+0.j, 0.39996414+0.j,
0.39996414+0.j, 0.39967732+0.j, 0.39910423+0.j, 0.39824594+0.j,
0.39710405+0.j, 0.39568069+0.j, 0.39397853+0.j, 0.39200072+0.j,
0.38975094+0.j, 0.38723335+0.j, 0.38445258+0.j, 0.38141374+0.j,
0.37812239+0.j, 0.3745845 +0.j, 0.37080648+0.j, 0.36679511+0.j,
0.36255757+0.j, 0.35810137+0.j, 0.35343437+0.j, 0.34856471+0.j,
0.34350085+0.j, 0.33825149+0.j, 0.33282555+0.j, 0.32723219+0.j,
0.32148073+0.j, 0.31558066+0.j, 0.3095416 +0.j, 0.30337328+0.j,
0.29708551+0.j, 0.29068813+0.j, 0.28419106+0.j, 0.27760417+0.j,
0.27093735+0.j, 0.26420043+0.j, 0.25740317+0.j, 0.25055524+0.j,
0.24366621+0.j, 0.2367455 +0.j, 0.22980239+0.j, 0.22284598+0.j,
0.21588517+0.j, 0.20892866+0.j, 0.20198494+0.j, 0.19506222+0.j,
0.1881685 +0.j, 0.18131147+0.j, 0.17449859+0.j, 0.16773697+0.j,
0.16103348+0.j, 0.15439464+0.j, 0.14782668+0.j, 0.14133549+0.j,
0.13492665+0.j, 0.12860541+0.j, 0.12237668+0.j, 0.11624505+0.j,
0.11021477+0.j, 0.10428977+0.j, 0.09847362+0.j, 0.0927696 +0.j,
0.08718063+0.j, 0.08170936+0.j, 0.07635807+0.j, 0.07112877+0.j,
0.06602316+0.j, 0.06104264+0.j, 0.05618834+0.j, 0.05146109+0.j,
0.04686147+0.j, 0.0423898 +0.j, 0.03804615+0.j, 0.03383034+0.j,
0.029742 +0.j, 0.02578049+0.j, 0.02194501+0.j, 0.01823455+0.j,
0.0146479 +0.j, 0.01118371+0.j, 0.00784044+0.j, 0.00461643+0.j], dtype=complex128)
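To show where these numbers come from, and that the same envelope can be differentiated and batched as well as jitted, the sketch below re-implements the lifted-Gaussian envelope directly in `jax.numpy`. `lifted_gaussian_samples` is a hypothetical stand-in, not a qiskit or Qiskit Dynamics function. It assumes the envelope is sampled at the midpoints t = k + 1/2 of each sample slot. Under that assumption it reproduces the output above: the first sample is about 0.00461643, the peak is about 0.39996414, and the samples are symmetric about the center.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def lifted_gaussian_samples(amp, duration=160, sigma=40.0):
    """Hypothetical stand-in for the sampled Gaussian envelope used above.

    Assumes samples at slot midpoints t = k + 1/2, with center = duration / 2
    and t_zero = duration + 1 as in the sympy expression.
    """
    t = jnp.arange(duration) + 0.5
    center = duration / 2
    t_zero = duration + 1
    gauss = jnp.exp(-0.5 * ((t - center) / sigma) ** 2)
    offset = jnp.exp(-0.5 * ((t_zero - center) / sigma) ** 2)
    return amp * (gauss - offset) / (1 - offset)


# jit: compile the sampler with a traced amplitude
samples = jax.jit(lifted_gaussian_samples)(0.4)

# grad: derivative of the summed envelope with respect to the amplitude
d_sum_d_amp = jax.grad(lambda a: lifted_gaussian_samples(a).sum())(0.4)

# vmap: evaluate several amplitudes in one batched call
batch = jax.vmap(lifted_gaussian_samples)(jnp.array([0.1, 0.4]))

print(samples[0], samples.max(), d_sum_d_amp, batch.shape)
```

The shape-determining quantities (`duration`, `sigma`) stay plain Python constants, so only `amp` is traced. `jnp.arange` needs a concrete length under `jit`. Because the envelope is linear in `amp`, the gradient equals the sum of the unit-amplitude samples.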