# Source code for qiskit_qec.decoders.hdrg_decoders

# -*- coding: utf-8 -*-

# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=invalid-name

"""Hard decision renormalization group decoders."""

from abc import ABC
from copy import copy
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

from rustworkx import PyGraph, connected_components, distance_matrix
import numpy as np

from qiskit_qec.decoders.decoding_graph import DecodingGraph
from qiskit_qec.utils import DecodingGraphEdge


class ClusteringDecoder(ABC):
    """
    Generic base class for clustering decoders.

    Subclasses group highlighted nodes of the decoding graph into clusters;
    :meth:`get_corrections` then turns those clusters into corrected values
    of the final logical readout.
    """

    def __init__(
        self,
        code_circuit,
        decoding_graph: DecodingGraph = None,
    ):
        """
        Args:
            code_circuit: Code circuit object to decode.
            decoding_graph (DecodingGraph): Decoding graph for the code. If not
                given, one is built from ``code_circuit``.
        """
        self.code = code_circuit

        # use the code's own index map when it has one, otherwise the identity
        if hasattr(self.code, "code_index"):
            self.code_index = self.code.code_index
        else:
            self.code_index = {j: j for j in range(self.code.n)}

        # compare against None explicitly so that a supplied graph is never
        # silently replaced just because it evaluates as falsy
        if decoding_graph is not None:
            self.decoding_graph = decoding_graph
        else:
            self.decoding_graph = DecodingGraph(self.code)

    def get_corrections(self, string, clusters):
        """
        Turn a set of neutral clusters into corrections.

        Args:
            string (str): Output string of the code
            clusters (dict): Dictionary with the indices of the given node
            as keys and an integer specifying their cluster as the corresponding
            value.
        Returns:
            corrected_logicals (list): A list of integers that are 0 or 1.
        These are the corrected values of the final transversal
        measurement, in the same form as given by the code's `string2raw_logicals`.
        """

        # get the list of bulk (non-logical) nodes for each cluster
        cluster_nodes = {c: [] for c in clusters.values()}
        for n, c in clusters.items():
            node = self.decoding_graph.graph[n]
            if not node.is_logical:
                cluster_nodes[c].append(node)

        # parity of the number of times each logical needs to be flipped:
        # every logical required by a cluster toggles that logical once
        net_logicals = {node.index: 0 for node in self.decoding_graph.logical_nodes}
        for nodes in cluster_nodes.values():
            _, logical_nodes, _ = self.code.check_nodes(nodes, minimal=True)
            for logical_node in logical_nodes:
                net_logicals[logical_node.index] ^= 1

        # apply the net flips to the raw logical readout
        corrected_logicals = self.code.string2raw_logicals(string)
        for log_index, log_value in enumerate(corrected_logicals):
            corrected_logicals[log_index] = (net_logicals[log_index] + int(log_value)) % 2

        return corrected_logicals


class BravyiHaahDecoder(ClusteringDecoder):
    """Decoder based on finding connected components within the decoding graph."""

    def __init__(
        self,
        code_circuit,
        decoding_graph: DecodingGraph = None,
    ):
        """
        Args:
            code_circuit: Code circuit object to decode.
            decoding_graph (DecodingGraph): Decoding graph for the code. If not
                given, one is built from ``code_circuit``.
        """
        super().__init__(code_circuit, decoding_graph)
        # pairwise shortest-path distances between decoding-graph nodes, as
        # computed by rustworkx's `distance_matrix` (rows/columns follow the
        # order of `graph.nodes()`)
        self._distance = distance_matrix(self.decoding_graph.graph)

    def _cluster(self, ns, dist_max):
        """
        Finds connected components in the given nodes, for nodes connected by at most the given distance
        in the given decoding graph.

        Args:
            ns (iterable): Positions of the nodes to cluster within the node list
                of the decoding graph (the same positions used to index
                ``self._distance``).
            dist_max (int): Maximum distance at which two nodes are joined.
        Returns:
            clusters (dict): Maps each node of ``ns`` to the index of its connected
                component if that component is fully neutral, and to ``None`` otherwise.
            con_comp_dict (dict): Maps each component index to the list of its nodes.
        """

        # build the node list once instead of once per node
        dg_nodes = self.decoding_graph.graph.nodes()
        # create empty `DecodingGraph`
        cluster_graph = DecodingGraph(None)
        cg = cluster_graph.graph
        # add all the given nodes to cg, remembering the index maps both ways
        d2c = {}
        c2g = {}
        for n in ns:
            d2c[n] = cg.add_node(dg_nodes[n])
            c2g[d2c[n]] = n
        # add an edge between a pair of the given nodes if their distance is small enough
        for n0 in ns:
            for n1 in ns:
                if n0 < n1:
                    dist = self._distance[n0, n1]
                    if dist <= dist_max:
                        cg.add_edge(d2c[n0], d2c[n1], {"distance": dist})
        # find the connected components of cg
        con_comps = connected_components(cg)

        # use these to define clusters
        clusters = {}
        con_comp_dict = {}
        for c, con_comp in enumerate(con_comps):
            con_comp_dict[c] = []

            # check the neutrality of each connected component
            con_nodes = [cg[n] for n in con_comp]
            neutral, logicals, num_errors = self.code.check_nodes(con_nodes, ignore_extras=True)

            # it's fully neutral if no extra logicals are needed
            # and if the error num is less than the max dist
            fully_neutral = neutral and not logicals
            if num_errors:
                fully_neutral = fully_neutral and num_errors < dist_max

            # if a cluster is neutral, all nodes are labelled with c
            # otherwise, they get None
            for n in con_comp:
                clusters[c2g[n]] = c if fully_neutral else None
                con_comp_dict[c].append(c2g[n])

        return clusters, con_comp_dict

    def cluster(self, nodes):
        """
        Cluster the given nodes by joining them over growing distances until
        every non-logical node belongs to a fully neutral cluster (or the
        distance exceeds the code distance).

        Args:
            nodes (list): List of nodes, of the type produced by `string2nodes`.
        Returns:
            final_clusters (dict): Dictionary with the indices of the given node
            as keys and an integer specifying their cluster as the corresponding
            value. Distinct clusters always get distinct integers, also when
            they are found in different rounds.
        """

        # get positions for nodes and logical nodes (list built once)
        dg_nodes = self.decoding_graph.graph.nodes()
        ns = {dg_nodes.index(node) for node in nodes}
        lns = {dg_nodes.index(node) for node in self.decoding_graph.logical_nodes}

        dist_max = 0
        # component labels restart at 0 in every round; offset them so that
        # clusters from different rounds never share a label (otherwise
        # `get_corrections` would treat them as a single cluster)
        label_offset = 0
        final_clusters = {}
        while ns and dist_max <= self.code.d:
            dist_max += 1
            # add logical nodes to unpaired nodes
            ns = set(ns) | lns

            # cluster nodes given the current distance
            clusters, con_comp = self._cluster(ns, dist_max)
            # record the clustered nodes, and keep unclustered bulk nodes for the next round
            ns = []
            for n, c in clusters.items():
                if c is not None:
                    final_clusters[n] = c + label_offset
                elif not dg_nodes[n].is_logical:
                    ns.append(n)
            label_offset += len(con_comp)

        return final_clusters

    def process(self, string, predecoder=None):
        """
        Process an output string and return corrected final outcomes.

        Args:
            string (str): Output string of the code.
            predecoder (callable): Function that takes in and returns
            a list of nodes. Used to do preprocessing on the nodes
            corresponding to the input string.

        Returns:
            corrected_logicals (list): A list of integers that are 0 or 1.
        These are the corrected values of the final transversal
        measurement, in the same form as given by the code's `string2raw_logicals`.
        """

        # turn string into nodes
        nodes = self.code.string2nodes(string, all_logicals=True)
        # apply predecoder if one is given
        if predecoder:
            nodes = predecoder(nodes)
        # then cluster
        clusters = self.cluster(nodes)

        return self.get_corrections(string, clusters)


@dataclass
class SpanningForest:
    """
    Spanning forest built and consumed by the peeling decoder.
    """

    # tree vertex -> list of edge indices leading from it to its children
    vertices: Dict[int, List[int]]
    # edge indices of the forest, in the order they were added
    edges: List[int]


@dataclass
class BoundaryEdge:
    """
    Edge on the boundary of a UnionFindDecoderCluster, oriented from the
    cluster it belongs to towards a neighbouring node.
    """

    # index of the edge in the decoding graph
    index: int
    # endpoint of the edge inside the cluster
    cluster_vertex: int
    # endpoint of the edge on the other side
    neighbour_vertex: int
    # payload of the edge
    data: DecodingGraphEdge

    def reverse(self):
        """
        Returns a new boundary edge with the same index and data, but with the
        cluster vertex and neighbour vertex swapped.
        """
        swapped_cluster, swapped_neighbour = self.neighbour_vertex, self.cluster_vertex
        return BoundaryEdge(self.index, swapped_cluster, swapped_neighbour, self.data)


@dataclass
class UnionFindDecoderCluster:
    """
    Cluster for the UnionFindDecoder
    """

    # boundary edges along which the cluster can grow
    boundary: List[BoundaryEdge]
    # non-logical nodes of the cluster that were highlighted in the syndrome
    atypical_nodes: Set[int]
    # logical nodes that belong to the cluster
    boundary_nodes: Set[int]
    # indices of all nodes in the cluster
    nodes: Set[int]
    # indices of edges that have been fully grown within the cluster
    fully_grown_edges: Set[int]
    # qubit tuples of the fully grown edges; their length varies, hence `...`
    edge_support: Set[Tuple[int, ...]]
    # number of single-node clusters merged into this one (used for union by size)
    size: int


@dataclass
class FusionEntry:
    """
    Pending merge recorded while growing clusters in the union find decoder:
    the two endpoints of a fully grown edge together with that edge.
    """

    # endpoint on the side of the cluster that grew the edge
    u: int
    # endpoint on the far side of the edge
    v: int
    # the fully grown boundary edge joining u and v
    connecting_edge: BoundaryEdge


class UnionFindDecoder(ClusteringDecoder):
    """
    Decoder based on growing clusters around syndrome errors to
    "convert" them into erasure errors, which can be corrected easily,
    by the peeling decoder for compatible codes or by the standard HDRG
    method in general. To avoid using the peeling decoder, and instead
    use the standard method for clustering decoders to get corrections,
    set `use_peeling=False`. Growth unit is 0.5 by default, but can be
    changed with `growth_unit`. To use half the minimum boundary edge
    weight for each clustering round, set `growth_unit=None`.
    """

    def __init__(
        self,
        code,
        decoding_graph: DecodingGraph = None,
        use_peeling=True,
        use_is_cluster_neutral=False,
        growth_unit=0.5,
    ) -> None:
        """
        Args:
            code: Code circuit object to decode.
            decoding_graph (DecodingGraph): Decoding graph for the code; built
                from ``code`` if not given.
            use_peeling (bool): If True, :meth:`process` corrects via the peeling
                decoder; otherwise via :meth:`get_corrections`.
            use_is_cluster_neutral (bool): If True, neutrality of merged clusters is
                decided by the code's ``is_cluster_neutral`` instead of ``check_nodes``.
            growth_unit (float): Amount added to each boundary edge's growth per
                round. A falsy value (e.g. ``None``) uses half the minimum boundary
                edge weight instead.
        """
        super().__init__(code, decoding_graph=decoding_graph)
        self.graph = self.decoding_graph.graph
        self.clusters: Dict[int, UnionFindDecoderCluster] = {}
        self.odd_cluster_roots: List[int] = []
        self.use_peeling = use_peeling
        self.use_is_cluster_neutral = use_is_cluster_neutral
        self._clusters4peeling = []
        self.growth_unit = growth_unit
        # growth unit actually in use during the current call to `cluster`
        self._growth_unit = None

    def process(self, string: str, predecoder=None):
        """
        Process an output string and return corrected final outcomes.

        Args:
            string (str): Output string of the code.
            predecoder (callable): Function that takes in and returns
            a list of nodes. Used to do preprocessing on the nodes
            corresponding to the input string.

        Returns:
            corrected_logicals (list): A list of integers that are 0 or 1.
        These are the corrected values of the final logical measurement.
        """

        if self.use_peeling:
            highlighted_nodes = self.code.string2nodes(string, all_logicals=True)
            if predecoder:
                highlighted_nodes = predecoder(highlighted_nodes)

            # call cluster to do the clustering, but actually use the peeling form
            self.cluster(highlighted_nodes)
            clusters = self._clusters4peeling

            # determine the net logical z
            measured_logicals = {}
            for node in self.decoding_graph.logical_nodes:
                measured_logicals[node.index] = node.qubits
            net_z_logicals = {tuple(z_logical): 0 for z_logical in measured_logicals.values()}
            for cluster_nodes, _ in clusters:
                erasure = self.graph.subgraph(cluster_nodes)
                flipped_qubits = self.peeling(erasure)
                for qubit_to_be_corrected in flipped_qubits:
                    for z_logical in net_z_logicals:
                        if qubit_to_be_corrected in z_logical:
                            net_z_logicals[z_logical] += 1
            for z_logical, num in net_z_logicals.items():
                net_z_logicals[z_logical] = num % 2

            # apply this to the raw readout
            corrected_z_logicals = []
            raw_logicals = self.code.string2raw_logicals(string)
            for j, z_logical in measured_logicals.items():
                raw_logical = int(raw_logicals[j])
                corrected_logical = (raw_logical + net_z_logicals[tuple(z_logical)]) % 2
                corrected_z_logicals.append(corrected_logical)

            return corrected_z_logicals
        else:
            # turn string into nodes and cluster
            nodes = self.code.string2nodes(string, all_logicals=True)
            if predecoder:
                nodes = predecoder(nodes)
            clusters = self.cluster(nodes)
            return self.get_corrections(string, clusters)

    def cluster(self, nodes: List):
        """
        Create clusters using the union-find algorithm.

        Args:
            nodes (List): List of non-typical nodes in the syndrome graph,
            of the type produced by `string2nodes`.

        Returns:
            clusters (dict): Dictionary with the indices of
            the given node as keys and an integer specifying their cluster as the corresponding
            value.
        """
        if self.growth_unit:
            self._growth_unit = self.growth_unit
        else:
            self._growth_unit = 0

        node_indices = [self.decoding_graph.node_index(node) for node in nodes]
        # set for O(1) membership; the list keeps the order used to create clusters
        highlighted = set(node_indices)
        for node_index in self.graph.node_indexes():
            self.graph[node_index].properties["syndrome"] = node_index in highlighted
            self.graph[node_index].properties["root"] = node_index

        for edge in self.graph.edges():
            edge.properties["growth"] = 0
            edge.properties["fully_grown"] = False

        self.clusters: Dict[int, UnionFindDecoderCluster] = {}
        self.odd_cluster_roots = []
        for node_index in node_indices:
            self._create_new_cluster(node_index)

        # bounded number of grow-and-merge rounds
        j = 0
        while self.odd_cluster_roots and j < 2 * self.code.d * (self.code.T + 1):
            self._grow_and_merge_clusters()
            j += 1

        # compile info into standard clusters dict
        clusters = {}
        for c, cluster in self.clusters.items():
            # determine which nodes exactly are in the neutral cluster
            neutral_nodes = list(cluster.atypical_nodes | cluster.boundary_nodes)
            # put them in the required dict
            for n in neutral_nodes:
                clusters[n] = c

        # also compile into form required for peeling
        self._clusters4peeling = []
        for cluster in self.clusters.values():
            if not cluster.atypical_nodes:
                continue
            self._clusters4peeling.append(
                (list(cluster.nodes), list(cluster.atypical_nodes | cluster.boundary_nodes))
            )

        return clusters

    def find(self, u: int) -> int:
        """
        Find() function as described in the paper that returns the root
        of the cluster of a node, including path compression.

        Args:
            u (int): The index of the node in the decoding graph.

        Returns:
            root (int): The root of the cluster of node u.
        """
        if self.graph[u].properties["root"] == u:
            return self.graph[u].properties["root"]

        self.graph[u].properties["root"] = self.find(self.graph[u].properties["root"])
        return self.graph[u].properties["root"]

    def _create_new_cluster(self, node_index):
        """
        Create a single-node cluster for ``node_index``. Non-logical nodes are
        registered as odd cluster roots, logical nodes as boundary nodes.
        """
        node = self.graph[node_index]
        if not node.is_logical:
            self.odd_cluster_roots.insert(0, node_index)
        boundary_edges = []
        for edge_index, neighbour, data in self.neighbouring_edges(node_index):
            boundary_edges.append(BoundaryEdge(edge_index, node_index, neighbour, copy(data)))
        self.clusters[node_index] = UnionFindDecoderCluster(
            boundary=boundary_edges,
            fully_grown_edges=set(),
            edge_support=set(),
            atypical_nodes={node_index} if not node.is_logical else set(),
            boundary_nodes={node_index} if node.is_logical else set(),
            nodes={node_index},
            size=1,
        )

    def _grow_and_merge_clusters(self) -> None:
        """Run one round: grow all odd clusters, then merge those that touch."""
        fusion_edge_list = self._grow_clusters()
        self._merge_clusters(fusion_edge_list)

    def _grow_clusters(self) -> List[FusionEntry]:
        """
        Grow every "odd" cluster by half an edge.

        Returns:
            fusion_edge_list (List[FusionEntry]): List of edges that connect two
            clusters that will be merged in the next step.
        """
        fusion_edge_list: List[FusionEntry] = []

        # with no fixed growth unit, grow by half the smallest boundary edge weight
        # (clamped away from zero)
        if not self.growth_unit:
            min_weight = np.inf
            for root in self.odd_cluster_roots:
                cluster = self.clusters[root]
                for edge in cluster.boundary:
                    min_weight = max(min(min_weight, edge.data.weight), 1e-6)
            self._growth_unit = min_weight / 2

        for root in self.odd_cluster_roots:
            cluster = self.clusters[root]
            for edge in cluster.boundary:
                edge.data.properties["growth"] += self._growth_unit
                if (
                    edge.data.properties["growth"] >= edge.data.weight
                    and not edge.data.properties["fully_grown"]
                ):
                    neighbour_root = self.find(edge.neighbour_vertex)
                    if neighbour_root not in self.clusters:
                        # the neighbour is not in any cluster yet: give it its own
                        boundary_edges = []
                        for edge_index, neighbour_neighbour, data in self.neighbouring_edges(
                            edge.neighbour_vertex
                        ):
                            boundary_edges.append(
                                BoundaryEdge(
                                    edge_index, edge.neighbour_vertex, neighbour_neighbour, data
                                )
                            )
                        self.graph[edge.neighbour_vertex].properties["root"] = edge.neighbour_vertex
                        self.clusters[edge.neighbour_vertex] = UnionFindDecoderCluster(
                            boundary=boundary_edges,
                            fully_grown_edges=set(),
                            edge_support=set(),
                            atypical_nodes=set(),
                            boundary_nodes=(
                                {edge.neighbour_vertex}
                                if self.graph[edge.neighbour_vertex].is_logical
                                else set()
                            ),
                            nodes={edge.neighbour_vertex},
                            size=1,
                        )
                    fusion_entry = FusionEntry(
                        u=edge.cluster_vertex, v=edge.neighbour_vertex, connecting_edge=edge
                    )
                    fusion_edge_list.append(fusion_entry)
        return fusion_edge_list

    def _merge_clusters(self, fusion_edge_list: List[FusionEntry]):
        """
        Merges the clusters based on the fusion_edge_list computed in _grow_clusters().
        Updates the odd_clusters list by recomputing the neutrality of the newly merged clusters.

        Args:
            fusion_edge_list (List[FusionEntry]): List of edges that connect two
            clusters that was computed in _grow_clusters().
        """
        new_neutral_clusters = []
        for entry in fusion_edge_list:
            root_u, root_v = self.find(entry.u), self.find(entry.v)
            if root_u == root_v:
                continue
            # union by size: the larger cluster absorbs the smaller one
            new_root = root_v if self.clusters[root_v].size > self.clusters[root_u].size else root_u
            root_to_update = root_v if new_root == root_u else root_u

            if new_root in new_neutral_clusters or root_to_update in new_neutral_clusters:
                continue

            cluster = self.clusters[new_root]
            other_cluster = self.clusters.pop(root_to_update)

            entry.connecting_edge.data.properties["growth"] = 0
            entry.connecting_edge.data.properties["fully_grown"] = True
            cluster.fully_grown_edges.add(entry.connecting_edge.index)
            cluster.edge_support.add(
                tuple(self.graph.get_edge_data_by_index(entry.connecting_edge.index).qubits)
            )

            # Merge boundaries
            cluster.boundary += other_cluster.boundary
            cluster.boundary.remove(entry.connecting_edge)
            cluster.boundary.remove(entry.connecting_edge.reverse())

            cluster.nodes |= other_cluster.nodes
            cluster.atypical_nodes |= other_cluster.atypical_nodes
            cluster.boundary_nodes |= other_cluster.boundary_nodes
            cluster.fully_grown_edges |= other_cluster.fully_grown_edges
            cluster.edge_support |= other_cluster.edge_support
            cluster.size += other_cluster.size

            # see if the cluster is neutral and update odd_cluster_roots accordingly
            fully_neutral = False
            if self._growth_unit:
                # assume non-neutral while growing along 0-weight edges
                # the cluster counts as neutral if it is neutral on its own
                # or together with one of its boundary nodes
                for nodes in [
                    [self.graph[node] for node in cluster.atypical_nodes],
                    [
                        self.graph[node]
                        for node in cluster.atypical_nodes
                        | (set(list(cluster.boundary_nodes)[:1]) if cluster.boundary_nodes else set())
                    ],
                ]:
                    if self.use_is_cluster_neutral:
                        neutral = self.code.is_cluster_neutral(nodes)
                    else:
                        neutral, extras, num = self.code.check_nodes(nodes)
                        for node in extras:
                            neutral = neutral and (not node.is_boundary)
                        neutral = neutral and num <= len(cluster.edge_support)
                    # accumulate over both candidate node sets in either mode
                    fully_neutral = fully_neutral or neutral

            if fully_neutral:
                if new_root in self.odd_cluster_roots:
                    self.odd_cluster_roots.remove(new_root)
                    new_neutral_clusters.append(new_root)
            else:
                if new_root not in self.odd_cluster_roots:
                    self.odd_cluster_roots.append(new_root)

            if root_to_update in self.odd_cluster_roots:
                self.odd_cluster_roots.remove(root_to_update)
            self.graph[root_to_update].properties["root"] = new_root
            self.odd_cluster_roots = sorted(
                self.odd_cluster_roots, key=lambda c: self.clusters[c].size
            )

    @staticmethod
    def _peeling_root(erasure: PyGraph, visited) -> int:
        """
        Pick a root for a new tree of the peeling spanning forest among the
        vertices not in ``visited``: a highlighted logical vertex if any, else a
        highlighted vertex, else any vertex. Returns None if all are visited.
        """
        candidates = [vertex for vertex in erasure.node_indices() if vertex not in visited]
        if not candidates:
            return None
        for vertex in candidates:
            if erasure[vertex].is_logical and erasure[vertex].properties["syndrome"]:
                return vertex
        for vertex in candidates:
            if erasure[vertex].properties["syndrome"]:
                return vertex
        return candidates[0]

    def peeling(self, erasure: PyGraph) -> List[int]:
        """
        Runs the peeling decoder on the erasure provided.
        A spanning forest is grown from a highlighted vertex (preferring a
        logical one); if the erasure is not connected, further trees are
        started from unvisited vertices, so construction always terminates.

        It works by first producing a spanning forest of the erasure and then
        going backwards through the edges of the tree computing the error based on the syndrome.
        Based on arXiv:1703.01517.

        Note: the "syndrome" properties of the erasure's nodes are modified.

        Args:
            erasure (PyGraph): subgraph of the syndrome graph that represents the erasure.

        Returns:
            errors (List[int]): List of qubit indices on which Pauli errors occurred.
        """
        tree = SpanningForest(vertices={}, edges=[])

        # Construct spanning forest
        # Pick starting vertex
        root = self._peeling_root(erasure, tree.vertices)
        if root is not None:
            tree.vertices[root] = []

        # Expand forest until it covers every vertex, adding at most one edge per
        # tree vertex per pass
        num_vertices = len(erasure.nodes())
        while len(tree.vertices) < num_vertices:
            grown = False
            for node in list(tree.vertices):
                if len(tree.vertices) >= num_vertices:
                    break
                for edge, (_, neighbour, _) in dict(erasure.incident_edge_index_map(node)).items():
                    if neighbour not in tree.vertices:
                        tree.edges.append(edge)
                        tree.vertices[neighbour] = []
                        tree.vertices[node].append(edge)
                        grown = True
                        break
            if not grown:
                # remaining vertices are unreachable from the current trees
                tree.vertices[self._peeling_root(erasure, tree.vertices)] = []

        # peel edges from the leaves inwards, pushing syndrome towards the root
        edges = set()
        for edge in tree.edges[::-1]:
            endpoints = erasure.get_edge_endpoints_by_index(edge)
            pendant_vertex = endpoints[0] if not tree.vertices[endpoints[0]] else endpoints[1]
            tree_vertex = endpoints[0] if pendant_vertex == endpoints[1] else endpoints[1]
            tree.vertices[tree_vertex].remove(edge)
            if erasure[pendant_vertex].properties["syndrome"]:
                edges.add(edge)
                erasure[tree_vertex].properties["syndrome"] = not erasure[tree_vertex].properties[
                    "syndrome"
                ]
                erasure[pendant_vertex].properties["syndrome"] = False

        flipped_qubits = []
        for edge in edges:
            qubits = erasure.get_edge_data_by_index(edge).qubits
            if qubits:
                flipped_qubits.append(qubits[0])
        return flipped_qubits

    def neighbouring_edges(self, node_index) -> List[Tuple[int, int, DecodingGraphEdge]]:
        """Returns all of the neighbouring edges of a node in the decoding graph.

        Args:
            node_index (int): The index of the node in the graph.

        Returns:
            neighbouring_edges (List[Tuple[int, int, DecodingGraphEdge]]): List of neighbouring edges

            In following format::

                {
                    index of edge in graph,
                    index of neighbour node in graph,
                    data payload of the edge
                }

        """
        return [
            (edge, neighbour, data)
            for edge, (_, neighbour, data) in dict(
                self.graph.incident_edge_index_map(node_index)
            ).items()
        ]