# Source code for qiskit_qec.utils.stim_tools

# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=invalid-name, disable=no-name-in-module, disable=unused-argument

"""Tools to use functionality from Stim."""
from typing import Union, List, Dict, Callable, Tuple
from math import log as loga

from stim import Circuit as StimCircuit
from stim import DetectorErrorModel as StimDetectorErrorModel
from stim import DemInstruction as StimDemInstruction
from stim import DemRepeatBlock as StimDemRepeatBlock
from stim import DemTarget as StimDemTarget
from stim import target_rec as StimTarget_rec

import numpy as np
import rustworkx as rx

from qiskit import QuantumCircuit
from qiskit_aer.noise.errors.base_quantum_error import QuantumChannelInstruction
from qiskit_aer.noise import pauli_error
from qiskit_qec.utils.decoding_graph_attributes import (
    DecodingGraphNode,
    DecodingGraphEdge,
)
from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel
from qiskit_qec.noise.paulinoisemodel import PauliNoiseModel


def _measurement_rec_target(
    latest_measurement: Dict, num_measurements: int, reg_name: str, bit_index: int
):
    """Return the stim ``rec[-k]`` target of the most recent measurement into a clbit.

    Args:
        latest_measurement: Maps ``(register name, bit index)`` to the position, within
            the measurement record, of the most recent measurement into that clbit.
        num_measurements: Number of measurements recorded so far.
        reg_name: Name of the classical register.
        bit_index: Index of the bit within its register.

    Returns:
        A stim measurement-record target (negative offset from the end of the record).

    Raises:
        ValueError: If no measurement into that clbit has been recorded yet.
    """
    key = (reg_name, bit_index)
    if key not in latest_measurement:
        raise ValueError(f"{[reg_name, bit_index]} is not in list of measured classical bits")
    return StimTarget_rec(latest_measurement[key] - num_measurements)


def get_stim_circuits(
    circuit: Union[QuantumCircuit, List],
    detectors: List[Dict] = None,
    logicals: List[Dict] = None,
):
    """Converts compatible qiskit circuits to stim circuits.

    Dictionaries are not complete. For the stim definitions see:
    https://github.com/quantumlib/Stim/blob/main/doc/gates.md

    Args:
        circuit: A circuit or a list of circuits. Compatible gates are Paulis,
            controlled Paulis, h, s, and sdg, swap, reset, measure and barrier.
            Compatible noise operators correspond to a single or two qubit pauli
            channel. x, y and z gates may be classically controlled on a single
            clbit with value 1.
        detectors: A list of measurement comparisons. A measurement comparison
            (detector) is either a list of measurements given by the name and index
            of the classical bit or a list of dictionaries, with a mandatory clbits
            key containing the classical bits. A dictionary can contain keys like
            'qubits', 'time', 'basis' etc.
        logicals: A list of logical measurements. A logical measurement is a
            list of classical bits whose total parity is the logical eigenvalue.
            Again it can be a list of dictionaries.

    Returns:
        A tuple ``(stim_circuits, stim_measurement_data)`` with one entry per input
        circuit. Each measurement-data entry is a list of ``[register name, bit index]``
        pairs, one per measurement, in the order the measurements appear in the stim
        circuit.

    Raises:
        Exception: If a circuit contains an unsupported operation, noise channel or
            classical condition.
        ValueError: If a condition, detector or logical refers to a clbit that has
            not been measured.
    """
    if detectors is None:
        detectors = [{}]
    if logicals is None:
        logicals = [{}]

    # Plain lists of clbits are promoted to the dictionary format.
    if len(detectors) > 0 and isinstance(detectors[0], list):
        detectors = [{"clbits": det, "qubits": [], "time": 0} for det in detectors]
    if len(logicals) > 0 and isinstance(logicals[0], list):
        logicals = [{"clbits": log} for log in logicals]

    qiskit_to_stim_dict = {
        "id": "I",
        "x": "X",
        "y": "Y",
        "z": "Z",
        "h": "H",
        "s": "S",
        "sdg": "S_DAG",
        "cx": "CX",
        "cy": "CY",
        "cz": "CZ",
        "swap": "SWAP",
        "reset": "R",
        "measure": "M",
        "barrier": "TICK",
    }
    # Position of each Pauli in the probability list of stim's PAULI_CHANNEL_1 / _2.
    pauli_error_1_stim_order = {
        "id": 0,
        "I": 0,
        "X": 1,
        "x": 1,
        "Y": 2,
        "y": 2,
        "Z": 3,
        "z": 3,
    }
    pauli_error_2_stim_order = {
        "II": 0,
        "IX": 1,
        "IY": 2,
        "IZ": 3,
        "XI": 4,
        "XX": 5,
        "XY": 6,
        "XZ": 7,
        "YI": 8,
        "YX": 9,
        "YY": 10,
        "YZ": 11,
        "ZI": 12,
        "ZX": 13,
        "ZY": 14,
        "ZZ": 15,
    }

    if isinstance(circuit, QuantumCircuit):
        circuit = [circuit]

    stim_circuits = []
    stim_measurement_data = []
    for circ in circuit:
        stim_circuit = StimCircuit()
        measurement_data = []
        # (register name, bit index) -> position of its latest measurement in measurement_data
        latest_measurement = {}
        # Quantum registers are laid out consecutively in order of first use.
        qreg_offset = {}
        prevq_offset = 0

        for instruction in circ.data:
            inst = instruction.operation
            qargs = instruction.qubits
            cargs = instruction.clbits

            for qubit in qargs:
                if qubit._register.name not in qreg_offset:
                    qreg_offset[qubit._register.name] = prevq_offset
                    prevq_offset += qubit._register.size
            qubit_indices = [qubit._index + qreg_offset[qubit._register.name] for qubit in qargs]

            if isinstance(inst, QuantumChannelInstruction):
                error_dict = inst._quantum_error.to_dict()
                pauli_errors_types = error_dict["instructions"]
                pauli_probs = error_dict["probabilities"]
                first_op = pauli_errors_types[0][0]
                first_params = first_op.get("params") or [None]
                if first_op["name"] in pauli_error_1_stim_order:
                    probs = 4 * [0.0]
                    for ptype, prob in zip(pauli_errors_types, pauli_probs):
                        probs[pauli_error_1_stim_order[ptype[0]["name"]]] += prob
                    # stim takes the probabilities of X, Y, Z (identity is implicit)
                    stim_circuit.append("PAULI_CHANNEL_1", qubit_indices, probs[1:])
                elif (
                    isinstance(first_params[0], str)
                    and first_params[0] in pauli_error_2_stim_order
                ):
                    # here the name is always 'pauli' and the params give the Pauli label.
                    # NOTE(review): Qiskit Pauli labels are little-endian (rightmost char acts
                    # on the first qubit) while stim's first char acts on the first target;
                    # confirm the label orientation expected here.
                    probs = 16 * [0.0]
                    for ptype, prob in zip(pauli_errors_types, pauli_probs):
                        probs[pauli_error_2_stim_order[ptype[0]["params"][0]]] += prob
                    stim_circuit.append("PAULI_CHANNEL_2", qubit_indices, probs[1:])
                else:
                    raise Exception("Unexpected operations: " + str([inst, qargs, cargs]))
            elif inst.name not in qiskit_to_stim_dict:
                raise Exception("Unexpected operations: " + str([inst, qargs, cargs]))
            else:
                # Gates and measurements
                if len(cargs) > 0:
                    # keeping track of measurement indices in stim
                    reg_name, bit_index = cargs[0]._register.name, cargs[0]._index
                    latest_measurement[(reg_name, bit_index)] = len(measurement_data)
                    measurement_data.append([reg_name, bit_index])

                stim_name = qiskit_to_stim_dict[inst.name]
                condition = getattr(inst, "condition", None)
                if stim_name == "TICK":
                    # barrier
                    stim_circuit.append("TICK")
                elif condition is not None:
                    # c_if on a single clbit becomes a controlled Pauli whose control is
                    # the latest measurement record of that clbit.
                    if inst.name not in ("x", "y", "z"):
                        raise Exception(
                            "Classically controlled " + inst.name + " gate is not supported"
                        )
                    if condition[1] != 1:
                        raise Exception(
                            "Classically controlled gate must be conditioned on bit value 1"
                        )
                    clbit = condition[0]
                    control = _measurement_rec_target(
                        latest_measurement,
                        len(measurement_data),
                        clbit._register.name,
                        clbit._index,
                    )
                    stim_circuit.append(
                        qiskit_to_stim_dict["c" + inst.name], [control, qubit_indices[0]]
                    )
                else:
                    # gates/measurements acting on qubits
                    stim_circuit.append(stim_name, qubit_indices)

        num_measurements = len(measurement_data)
        if detectors != [{}]:
            for det in detectors:
                stim_record_targets = [
                    _measurement_rec_target(latest_measurement, num_measurements, reg, ind)
                    for reg, ind in det["clbits"]
                ]
                time = det.get("time", [])
                if time != []:
                    # detector coordinates: qubit indices followed by the time
                    stim_circuit.append(
                        "DETECTOR", stim_record_targets, list(det.get("qubits", [])) + [time]
                    )
                else:
                    stim_circuit.append("DETECTOR", stim_record_targets, [])
        if logicals != [{}]:
            for log_ind, log in enumerate(logicals):
                stim_record_targets = [
                    _measurement_rec_target(latest_measurement, num_measurements, reg, ind)
                    for reg, ind in log["clbits"]
                ]
                stim_circuit.append("OBSERVABLE_INCLUDE", stim_record_targets, log_ind)

        stim_circuits.append(stim_circuit)
        stim_measurement_data.append(measurement_data)

    return stim_circuits, stim_measurement_data
def get_counts_via_stim(
    circuits: Union[List, QuantumCircuit],
    shots: int = 4000,
    noise_model: PauliNoiseModel = None,
):
    """Returns a qiskit compatible dictionary of measurement outcomes.

    Args:
        circuits: Qiskit circuit compatible with `get_stim_circuits` or list thereof.
        shots: Number of samples to be generated.
        noise_model: Pauli noise model for any additional noise to be applied.

    Returns:
        counts: Counts dictionary in standard Qiskit form or list thereof. Each key has
        one character per measurement, with the last measurement first. A space
        separates consecutive measurements that belong to different registers. A
        circuit without measurements yields the single key ``""``.
    """
    if noise_model:
        circuits = noisify_circuit(circuits, noise_model)

    single_circuit = isinstance(circuits, QuantumCircuit)
    if single_circuit:
        circuits = [circuits]

    counts = []
    for circuit in circuits:
        stim_circuits, stim_measurement_data = get_stim_circuits(circuit)
        stim_circuit = stim_circuits[0]
        measurement_data = stim_measurement_data[0]
        stim_samples = stim_circuit.compile_sampler().sample(shots=shots)

        # The separator layout depends only on the registers, not on the sample,
        # so it is computed once per circuit.
        reversed_regs = [reg for reg, _ in reversed(measurement_data)]
        separators = [
            " " if idx > 0 and reg != reversed_regs[idx - 1] else ""
            for idx, reg in enumerate(reversed_regs)
        ]

        qiskit_counts = {}
        for stim_sample in stim_samples:
            # zip pairs the separators with the last len(measurement_data) bits, newest first
            qiskit_count = "".join(
                sep + str(int(bit)) for sep, bit in zip(separators, stim_sample[::-1])
            )
            qiskit_counts[qiskit_count] = qiskit_counts.get(qiskit_count, 0) + 1
        counts.append(qiskit_counts)

    if single_circuit:
        counts = counts[0]

    return counts
def iter_flatten_model(
    model: StimDetectorErrorModel,
    handle_error: Callable[[float, List[int], List[int], Dict], None],
    handle_detector_coords: Callable[..., None],
    detectors: List[Dict],
    hyperedges: List[Dict],
):
    """Walk a stim detector error model instruction by instruction, unrolling repeat blocks.

    Adapted from the flattening helper in stim's example code (written for circuits
    from ``stim.Circuit.generated("surface_code:rotated_memory_z", ...)``).

    Args:
        model: The detector error model to walk.
        handle_error: Called as ``handle_error(p, dets, frames, hyperedge)`` once per
            component of every ``error`` instruction. Components are separated by ``^``.
            ``dets`` are absolute detector indices and ``frames`` are logical
            observable indices. ``hyperedge`` is a dict shared by all components of
            one error.
        handle_detector_coords: Called with keyword arguments ``detector_index``,
            ``time``, ``qubits`` and ``det_props`` for every detector target.
        detectors: ``[{}]`` to read detector data from the model's coordinates. In that
            case the last coordinate is the time and the others are qubit indices.
            Otherwise, a list indexed by detector index of dicts with ``time``,
            ``qubits`` and ``clbits`` keys; the remaining keys become ``det_props``.
        hyperedges: Output list. The ``hyperedge`` dict of an error is appended when
            the handler stored more than one entry in it.

    Raises:
        NotImplementedError: For instruction types this function does not handle.
    """
    det_offset = 0
    # Accumulated SHIFT_COORDS offsets (as in stim's own flattening), grown on demand.
    coords_offset = np.zeros(0, dtype=np.float64)
    use_model_coords = detectors == [{}]

    def _helper(m: StimDetectorErrorModel, reps: int):
        nonlocal det_offset, coords_offset
        for _ in range(reps):
            for instruction in m:
                if isinstance(instruction, StimDemRepeatBlock):
                    _helper(instruction.body_copy(), instruction.repeat_count)
                elif isinstance(instruction, StimDemInstruction):
                    if instruction.type == "error":
                        dets: List[int] = []
                        frames: List[int] = []
                        t: StimDemTarget
                        p = instruction.args_copy()[0]
                        hyperedge = {}
                        for t in instruction.targets_copy():
                            if t.is_relative_detector_id():
                                dets.append(t.val + det_offset)
                            elif t.is_logical_observable_id():
                                frames.append(t.val)
                            elif t.is_separator():
                                # Treat each component of a decomposed error as an independent error.
                                handle_error(p, dets, frames, hyperedge)
                                frames = []
                                dets = []
                        # Handle last component.
                        handle_error(p, dets, frames, hyperedge)
                        if len(hyperedge) > 1:
                            hyperedges.append(hyperedge)
                    elif instruction.type == "shift_detectors":
                        det_offset += instruction.targets_copy()[0]
                        shift = np.array(instruction.args_copy(), dtype=np.float64)
                        if len(shift) > len(coords_offset):
                            coords_offset = np.pad(
                                coords_offset, (0, len(shift) - len(coords_offset))
                            )
                        coords_offset[: len(shift)] += shift
                    elif instruction.type == "detector":
                        coords = np.array(instruction.args_copy(), dtype=np.float64)
                        shared = min(len(coords), len(coords_offset))
                        coords[:shared] += coords_offset[:shared]
                        # Each target is its own detector; its index is computed per target.
                        for t in instruction.targets_copy():
                            det_ind = t.val + det_offset
                            if use_model_coords:
                                time = coords[-1]
                                qubits = [int(qubit_ind) for qubit_ind in coords[:-1]]
                                det = {}
                            else:
                                det = detectors[det_ind].copy()
                                time = det.pop("time")
                                qubits = det.pop("qubits")
                                del det["clbits"]
                            handle_detector_coords(
                                detector_index=det_ind,
                                time=time,
                                qubits=qubits,
                                det_props=det,
                            )
                    elif instruction.type == "logical_observable":
                        pass
                    else:
                        raise NotImplementedError(
                            f"Unsupported detector error model instruction: {instruction.type!r}"
                        )
                else:
                    raise NotImplementedError(
                        f"Unsupported detector error model item: {instruction!r}"
                    )

    _helper(model, 1)


def detector_error_model_to_rx_graph(
    model: StimDetectorErrorModel, detectors: List[Dict] = None
) -> Tuple[rx.PyGraph, List[Dict]]:
    """Convert a stim error model into a RustworkX graph.

    It was written assuming the stim circuit does not contain repeat blocks.
    ``iter_flatten_model`` unrolls them, but check the results for user-defined
    stim circuits that contain repeat blocks.

    Args:
        model: The stim detector error model.
        detectors: ``[{}]`` / ``None`` to take the detector data from the model's
            coordinates. There, the coordinate included as the last element of every
            detector is the time. Otherwise, a list of detector dicts as produced
            for ``get_stim_circuits``.

    Returns:
        ``(graph, hyperedges)``. ``graph`` has one ``DecodingGraphNode`` per detector
        plus a boundary node with index ``model.num_detectors``; edges are
        ``DecodingGraphEdge`` objects. ``hyperedges`` lists dicts that map
        ``(node_a, node_b)`` to the edge of each component of a decomposed error.

    Raises:
        NotImplementedError: For error components with more than two detectors.
    """
    if detectors is None:
        detectors = [{}]

    g = rx.PyGraph(multigraph=False)
    index_to_DecodingGraphNode = {}

    def skip_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: Dict):
        pass

    def handle_error(p: float, dets: List[int], frame_changes: List[int], hyperedge: Dict):
        if p == 0:
            return
        if len(dets) == 0:
            return
        if len(dets) == 1:
            # single-detector errors connect to the boundary node
            dets = [dets[0], model.num_detectors]
        if len(dets) > 2:
            raise NotImplementedError(
                f"Error with more than 2 symptoms can't become an edge or boundary edge: {dets!r}."
            )
        if g.has_edge(dets[0], dets[1]):
            # get_edge_data is direction-agnostic on an undirected PyGraph
            old_edge = g.get_edge_data(dets[0], dets[1])
            old_p = old_edge.properties["error_probability"]
            if set(old_edge.fault_ids) == set(frame_changes):
                # merge two independent mechanisms: probability that exactly one fires
                p = p * (1 - old_p) + old_p * (1 - p)
                g.remove_edge(dets[0], dets[1])
            # If the frame changes differ, the code has distance 2. The old edge is not
            # removed, and because the graph is not a multigraph, add_edge below replaces
            # its payload with this error's edge.
            # NOTE(review): an older comment said "keep whichever was first"; confirm intent.
        if p > 0.5:
            p = 1 - p
        if p > 0:
            qubits = list(
                set(index_to_DecodingGraphNode[dets[0]].qubits).intersection(
                    index_to_DecodingGraphNode[dets[1]].qubits
                )
            )
            edge = DecodingGraphEdge(
                qubits=qubits,
                weight=loga((1 - p) / p),
                fault_ids=set(frame_changes),
                properties={"error_probability": p},
            )
            g.add_edge(dets[0], dets[1], edge)
            hyperedge[dets[0], dets[1]] = edge

    def skip_detector_coords(detector_index: int, time, qubits, det_props):
        pass

    def handle_detector_coords(detector_index: int, time, qubits, det_props):
        node = DecodingGraphNode(index=detector_index, time=time, qubits=qubits)
        node.properties = det_props
        index_to_DecodingGraphNode[detector_index] = node
        # NOTE(review): edges use detector indices as node indices, which assumes
        # detectors are declared in index order 0..num_detectors-1.
        g.add_node(node)

    hyperedges = []

    # First pass: nodes only.
    iter_flatten_model(
        model,
        handle_error=skip_error,
        handle_detector_coords=handle_detector_coords,
        detectors=detectors,
        hyperedges=hyperedges,
    )
    trivial_boundary_node = DecodingGraphNode(index=model.num_detectors, time=0, is_boundary=True)
    g.add_node(trivial_boundary_node)
    index_to_DecodingGraphNode[model.num_detectors] = trivial_boundary_node

    # Second pass: edges (needs every node, including the boundary, to exist).
    iter_flatten_model(
        model,
        handle_error=handle_error,
        handle_detector_coords=skip_detector_coords,
        detectors=detectors,
        hyperedges=hyperedges,
    )

    return g, hyperedges


def _output_bits(string: str) -> np.ndarray:
    """Turn a Qiskit result string into an array indexed by clbit position (spaces dropped)."""
    return np.array([int(char) for char in string.replace(" ", "")[::-1]])


def _clbit_positions(clbits) -> Dict:
    """Map ``(register name, index in register)`` to the clbit's position in ``clbits``."""
    return {(clbit._register.name, clbit._index): clind for clind, clbit in enumerate(clbits)}


def string2nodes_with_detectors(
    string: str,
    detectors: List[Dict],
    logicals: List[Dict],
    clbits: QuantumCircuit.clbits,
    det_ref_values: Union[List, int] = 0,
    **kwargs,
):
    """
    Convert output string from circuits into a set of nodes for `DecodingGraph`.

    Args:
        string (string): Results string to convert.
        detectors: A list of measurement comparisons. A measurement comparison
            (detector) is a dictionary with a mandatory clbits key containing the
            classical bits as ``(register name, index)`` pairs (tuples or lists), and
            the 'time' and 'qubits' keys. Any other keys become node properties.
        logicals: A list of logical measurements. A logical measurement is a
            list of classical bits whose total parity is the logical eigenvalue.
            Again it can be a list of dictionaries.
        clbits: classical bits of the qiskit circuit, needed to identify
            measurements in the output string
        det_ref_values: Reference value for the detector outcomes, 0 by default.
            A single integer applies to all detectors.
        kwargs (dict): Any additional keyword arguments.
            logical (str): Logical value whose results are used ('0' as default).
            all_logicals (bool): Whether to include logical nodes
            irrespective of value. (False as default).

    Returns:
        List of ``DecodingGraphNode`` for every detector whose parity differs from its
        reference value, followed by the logical nodes.
    """
    output_bits = _output_bits(string)
    clbit_dict = _clbit_positions(clbits)

    if isinstance(det_ref_values, (int, np.integer)):
        det_ref_values = [det_ref_values] * len(detectors)

    nodes = []
    for ind, det in enumerate(detectors):
        det = det.copy()
        outcomes = [clbit_dict[tuple(clbit_key)] for clbit_key in det.pop("clbits")]
        if sum(output_bits[outcomes]) % 2 != det_ref_values[ind]:
            node = DecodingGraphNode(time=det.pop("time"), qubits=det.pop("qubits"), index=ind)
            node.properties = det
            nodes.append(node)

    nodes.extend(
        string2rawlogicals_with_detectors(
            string=string, logicals=logicals, clbits=clbits, start_ind=len(detectors), **kwargs
        )
    )
    return nodes


def string2rawlogicals_with_detectors(
    string: str,
    logicals: List[Dict],
    clbits: QuantumCircuit.clbits,
    start_ind: int = 0,
    **kwargs,
):
    """
    Convert output string from circuits into raw logical values.

    Args:
        string: Results string to convert.
        logicals: List of dicts with a ``clbits`` key (``(register name, index)``
            pairs) and an optional ``qubits`` key.
        clbits: classical bits of the qiskit circuit.
        start_ind: Node index given to the first logical.
        kwargs: ``logical`` (expected logical value, '0' by default; ints are
            accepted) and ``all_logicals`` (include every logical node regardless of
            value).

    Returns:
        List of logical boundary ``DecodingGraphNode`` objects.
    """
    all_logicals = kwargs.get("all_logicals")
    logical = kwargs.get("logical")
    if logical is None:
        logical = "0"
    logical = str(logical)

    output_bits = _output_bits(string)
    clbit_dict = _clbit_positions(clbits)

    nodes = []
    for index, logical_op in enumerate(logicals, start=start_ind):
        logical_out = int(sum(output_bits[clbit_dict[tuple(q)]] for q in logical_op["clbits"])) % 2
        if all_logicals or str(logical_out) != logical:
            node = DecodingGraphNode(
                is_logical=True,
                is_boundary=True,
                index=index,
            )
            node.qubits = logical_op.get("qubits", [])
            nodes.append(node)
    return nodes


def string2logical_meas(
    string: str,
    outcomes_in_logical: List[Dict],
    clbits: QuantumCircuit.clbits,
):
    """
    Compute the parity of each detector-style logical measurement.

    Args:
        string (string): Results string from qiskit circuit
        outcomes_in_logical: the detector-style logical outcome; dicts with a
            ``clbits`` key of ``(register name, index)`` pairs
        clbits: classical bits of the qiskit circuit, needed to identify
            measurements in the output string

    Returns:
        List with the parity (0 or 1) of each logical in ``outcomes_in_logical``.
    """
    output_bits = _output_bits(string)
    clbit_dict = _clbit_positions(clbits)

    log_outs = []
    for logical_op in outcomes_in_logical:
        logical_out = 0
        for q in logical_op["clbits"]:
            logical_out += output_bits[clbit_dict[tuple(q)]]
        log_outs.append(logical_out % 2)
    return log_outs
def noisify_circuit(circuits: Union[List, QuantumCircuit], noise_model: PauliNoiseModel):
    """
    Inserts error operations into a circuit according to a pauli noise model.

    Handles idling errors in the form of custom gates "idle_#", which are assumed
    to encode the identity gate only. Such gates are replaced by their error::

        qc = QuantumCircuit(1, name='idle_1')
        qc.i(0)
        idle_1 = qc.to_instruction()

    For every operation named in the noise model, the error is inserted after the
    operation for resets and before it for everything else.

    Args:
        circuits: Circuit or list thereof to which noise is added.
        noise_model: Pauli noise model used to define types of errors to add to circuit.

    Returns:
        noisy_circuits: Corresponding circuit or list thereof. Each keeps the
        original's registers, bits, name, metadata and global phase.
    """
    single_circuit = isinstance(circuits, QuantumCircuit)
    if single_circuit:
        circuits = [circuits]

    # create pauli errors for all errors in noise model
    errors = {}
    for op_name, noise in noise_model.to_dict().items():
        errors[op_name] = pauli_error(
            [(pauli.upper(), prob) for pauli, prob in noise["chan"].items()]
        )

    noisy_circuits = []
    for qc in circuits:
        noisy_qc = qc.copy_empty_like()
        for instruction in qc.data:
            operation = instruction.operation
            op_name = operation.name
            # resets get their error afterwards, everything else beforehand
            pre_error = op_name == "reset"
            if pre_error:
                noisy_qc.append(operation, instruction.qubits, instruction.clbits)
            if op_name in errors:
                noisy_qc.append(errors[op_name], instruction.qubits)
            # idle gates only carry their error and are dropped themselves
            if not pre_error and not op_name.startswith("idle_"):
                noisy_qc.append(operation, instruction.qubits, instruction.clbits)
        noisy_circuits.append(noisy_qc)

    if single_circuit:
        noisy_circuits = noisy_circuits[0]

    return noisy_circuits