.. _how-to use pulse schedules for jax-jit:

How-to use pulse schedules generated by Qiskit Pulse with JAX transformations
=============================================================================

.. warning::

    This tutorial suppresses ``DeprecationWarning`` instances raised by Qiskit Pulse, which is
    deprecated as of Qiskit 1.3.

.. jupyter-execute::
    :hide-code:

    # silence a parallelism warning that JAX raises because of something outside of Dynamics
    import warnings
    warnings.filterwarnings("ignore", message="os.fork")

    # also silence deprecation warnings from Qiskit Pulse
    warnings.filterwarnings("ignore", category=DeprecationWarning)

Qiskit Pulse enables the specification of time dependence in quantum systems as pulse schedules,
which are built from sequences of instructions, including shaped pulses (see the
`Qiskit pulse API Reference `__ for details). Qiskit 0.40.0 added JAX support for the
:class:`~qiskit.pulse.library.ScalableSymbolicPulse` class. This user guide entry demonstrates
how to use this class within JAX-transformable functions.

.. note::

    At present, only the :class:`~qiskit.pulse.library.ScalableSymbolicPulse` class is supported
    by JAX, as the validation present in other pulse types, such as
    :class:`~qiskit.pulse.library.Gaussian`, is not JAX-compatible.

This guide addresses the following topics. See the :ref:`userguide on using JAX ` for a more
detailed explanation of how to work with JAX in Qiskit Dynamics.

1. Configure JAX.
2. Define a Gaussian pulse using :class:`~qiskit.pulse.library.ScalableSymbolicPulse`.
3. JAX-transform pulse-to-signal conversion involving
   :class:`~qiskit.pulse.library.ScalableSymbolicPulse`.

1. Configure JAX
----------------

First, configure JAX to run on CPU in 64-bit mode.

.. jupyter-execute::

    # configure JAX to use 64-bit mode
    import jax
    jax.config.update("jax_enable_x64", True)

    # tell JAX we are using CPU
    jax.config.update('jax_platform_name', 'cpu')

2. How to define a Gaussian pulse using :class:`~qiskit.pulse.library.ScalableSymbolicPulse`
--------------------------------------------------------------------------------------------

Because the standard :class:`~qiskit.pulse.library.Gaussian` pulse is not JAX-compatible, a
Gaussian pulse intended for optimization must be defined by instantiating a
:class:`~qiskit.pulse.library.ScalableSymbolicPulse` with a Gaussian parameterization. First,
define the symbolic representation in ``sympy``. The helper below returns a "lifted" Gaussian:
a Gaussian that is shifted and rescaled so that it equals 1 at ``center`` and 0 at ``t_zero``.

.. jupyter-execute::

    from qiskit import pulse
    from qiskit_dynamics.pulse import InstructionToSignals
    import sympy as sym

    dt = 0.222  # sample width passed to InstructionToSignals
    w = 5.  # carrier frequency for drive channel "d0"

    # Helper function that returns a lifted Gaussian symbolic expression: a Gaussian shifted
    # and rescaled to equal 1 at `center` and 0 at `t_zero`.
    def lifted_gaussian(
        t: sym.Symbol,
        center,
        t_zero,
        sigma,
    ) -> sym.Expr:
        t_shifted = (t - center).expand()
        t_offset = (t_zero - center).expand()

        gauss = sym.exp(-((t_shifted / sigma) ** 2) / 2)
        offset = sym.exp(-((t_offset / sigma) ** 2) / 2)

        return (gauss - offset) / (1 - offset)

Next, define the :class:`~qiskit.pulse.library.ScalableSymbolicPulse` using the above expression.

.. jupyter-execute::

    _t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
    _center = _duration / 2

    # complex envelope: amp * exp(i * angle) times a lifted Gaussian that vanishes at t = duration + 1
    envelope_expr = (
        _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
    )

    gaussian_pulse = pulse.ScalableSymbolicPulse(
        pulse_type="Gaussian",
        duration=160,
        amp=0.3,
        angle=0,
        parameters={"sigma": 40},
        envelope=envelope_expr,
        constraints=_sigma > 0,
        valid_amp_conditions=sym.Abs(_amp) <= 1.0,
    )

    gaussian_pulse.draw()

3. JAX-transforming pulse-to-signal conversion involving :class:`~qiskit.pulse.library.ScalableSymbolicPulse`
-------------------------------------------------------------------------------------------------------------

Using the Gaussian pulse as an example, we show that a function that constructs a
:class:`~qiskit.pulse.library.ScalableSymbolicPulse` and passes it through the pulse-to-signal
converter, :class:`~qiskit_dynamics.pulse.InstructionToSignals`, can be JAX-compiled (or, more
generally, JAX-transformed).

.. jupyter-execute::

    # use amplitude as the function argument
    def jit_func(amp):
        _t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle")
        _center = _duration / 2

        envelope_expr = (
            _amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
        )

        # disable validation, which is not JAX-compatible, so that the function can be JAX-jitted
        pulse.ScalableSymbolicPulse.disable_validation = True

        gaussian_pulse = pulse.ScalableSymbolicPulse(
            pulse_type="Gaussian",
            duration=160,
            amp=amp,
            angle=0,
            parameters={"sigma": 40},
            envelope=envelope_expr,
            constraints=_sigma > 0,
            valid_amp_conditions=sym.Abs(_amp) <= 1.0,
        )

        # build a pulse schedule that plays the pulse on drive channel 0
        with pulse.build() as schedule:
            pulse.play(gaussian_pulse, pulse.DriveChannel(0))

        # convert the pulse schedule to a list of signals and return the samples of the signal
        # for channel "d0", the only channel in the schedule
        converter = InstructionToSignals(dt, carriers={"d0": w})
        return converter.get_signals(schedule)[0].samples

    jax.jit(jit_func)(0.4)
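
To make the phrase "more generally, JAX-transformed" concrete, the following minimal sketch bypasses Qiskit Pulse entirely. It hand-transcribes ``envelope_expr`` into ``jax.numpy`` and applies both ``jax.jit`` and ``jax.grad`` to it. The names ``lifted_gaussian_jnp``, ``envelope``, and ``pulse_energy`` are hypothetical, as is the choice of integer sample times ``0, 1, ..., 159``. This is not the code path that Qiskit Pulse or Qiskit Dynamics uses to compute samples; it only illustrates the transformations on the same formula.

.. jupyter-execute::

    import jax
    import jax.numpy as jnp

    # match the 64-bit configuration used earlier in this guide
    jax.config.update("jax_enable_x64", True)


    def lifted_gaussian_jnp(t, center, t_zero, sigma):
        """Hypothetical jax.numpy transcription of the sympy ``lifted_gaussian`` helper."""
        gauss = jnp.exp(-(((t - center) / sigma) ** 2) / 2)
        offset = jnp.exp(-(((t_zero - center) / sigma) ** 2) / 2)
        return (gauss - offset) / (1 - offset)


    def envelope(t, amp, angle=0.0, duration=160.0, sigma=40.0):
        """amp * exp(i * angle) * lifted Gaussian, with the same center and t_zero as envelope_expr."""
        center = duration / 2
        return amp * jnp.exp(1j * angle) * lifted_gaussian_jnp(t, center, duration + 1, sigma)


    # hypothetical sample times: integer points across the 160-sample duration
    # (an assumption for illustration, not Qiskit's sampling convention)
    times = jnp.arange(160.0)


    def pulse_energy(amp):
        """Real-valued scalar objective: the sum of |envelope|^2 over the sample times."""
        samples = envelope(times, amp)
        return jnp.sum(jnp.real(samples * jnp.conj(samples)))


    # jit-compile the envelope and evaluate it at amp = 0.4
    samples = jax.jit(envelope)(times, 0.4)

    # differentiate the real-valued objective with respect to amp
    energy_grad = jax.grad(pulse_energy)(0.4)

Because ``pulse_energy`` returns a real scalar, ``jax.grad`` can differentiate it with respect to ``amp``. For this quadratic objective the result should equal ``2 * amp * sum(g**2)``, where ``g`` is the lifted Gaussian evaluated at the sample times.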