# Source code for qiskit_experiments.library.characterization.analysis.multi_state_discrimination_analysis
# This code is part of Qiskit.
#
# (C) Copyright IBM 2022.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Multi state discrimination analysis."""
from typing import List, Tuple, TYPE_CHECKING
import matplotlib
import numpy as np
from qiskit.providers.options import Options
from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, ExperimentData
from qiskit_experiments.data_processing import SkQDA
from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle
from qiskit_experiments.warnings import HAS_SKLEARN
if TYPE_CHECKING:
    from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
class MultiStateDiscriminationAnalysis(BaseAnalysis):
    r"""This class fits a multi-state discriminator to the data.

    The class will report the configuration of the discriminator in the analysis result as well as
    the fidelity of the discrimination reported as

    .. math::

        F = 1 - \frac{1}{d}\sum_{i\neq j}P(i|j)

    Here, :math:`d` is the number of levels that were discriminated while :math:`P(i|j)` is the
    probability of measuring outcome :math:`i` given that state :math:`j` was prepared.

    .. note::

        This class requires that scikit-learn is installed.
    """

    @classmethod
    @HAS_SKLEARN.require_in_call
    def _default_options(cls) -> Options:
        """Return default analysis options.

        Analysis Options:
            plot (bool): Set ``True`` to create figure for fit result.
            plotter (BasePlotter): A plotter instance to visualize the analysis result.
            ax (AxesSubplot): Optional. A matplotlib axis object in which to draw.
            discriminator (BaseDiscriminator): The sklearn discriminator to classify the data.
                The default is a quadratic discriminant analysis.
        """
        options = super()._default_options()
        options.plotter = IQPlotter(MplDrawer())
        options.plotter.set_options(
            discriminator_max_resolution=64,
            style=PlotStyle(figsize=(6, 4), legend_loc=None),
        )
        options.plot = True
        # NOTE(review): ``ax`` is not read anywhere in this class; confirm whether it is
        # consumed elsewhere or should be forwarded to the plotter/drawer.
        options.ax = None

        # Deferred import: at module level sklearn is only imported under TYPE_CHECKING,
        # and ``HAS_SKLEARN.require_in_call`` guards this method.
        from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

        options.discriminator = SkQDA(QuadraticDiscriminantAnalysis())
        return options

    @property
    def plotter(self) -> BasePlotter:
        """A short-cut to the IQ plotter instance."""
        return self._options.plotter

    def _run_analysis(
        self,
        experiment_data: ExperimentData,
    ) -> Tuple[List[AnalysisResultData], List["matplotlib.figure.Figure"]]:
        """Train a discriminator based on the experiment data.

        Each entry of ``experiment_data.data()`` is treated as one prepared state. The first
        memory-slot entry of every shot is used as a training sample, and the entry's
        ``metadata["label"]`` is used as its class label.

        Args:
            experiment_data: the data obtained from the experiment

        Returns:
            The analysis results (the configuration of the trained discriminator and the
            fidelity of the discrimination) and, when the ``plot`` option is set, the IQ plot.
        """
        # Fetch the data once instead of re-querying ``experiment_data`` for every shot.
        circuit_data = experiment_data.data()
        n_states = len(circuit_data)

        # Process the data and get labels. The shot count is taken per prepared state,
        # so states with different numbers of shots are handled correctly.
        data, fit_state = [], []
        for datum in circuit_data:
            data.append(np.array([shot[0] for shot in datum["memory"]]))
            fit_state.append(datum["metadata"]["label"])

        # Train a discriminator on the processed data, with exactly one label per sample.
        discriminator = self.options.discriminator
        discriminator.fit(
            np.concatenate(data),
            np.repeat(np.asarray(fit_state), [len(state_data) for state_data in data]),
        )

        # Calculate fidelity. P(i|j) := prob. of measuring outcome i given state j was prepared,
        # so sum_{i != j} P(i|j) is the fraction of state-j shots that were misclassified.
        # Predictions are compared with the prepared label itself (not the circuit index), so
        # the result does not depend on circuit ordering or on labels being 0..n-1.
        prob_wrong = 0.0
        for state_data, label in zip(data, fit_state):
            predicted = np.asarray(discriminator.predict(state_data))
            prob_wrong += float(np.mean(predicted != label))

        # calculate the fidelity
        fidelity = 1 - prob_wrong / n_states

        # Create analysis results from the discriminator configuration
        analysis_results = [
            AnalysisResultData(name="discriminator_config", value=discriminator.config()),
            AnalysisResultData(name="fidelity", value=fidelity),
        ]

        figures = []
        if self.options.plot:
            figures.append(self._levels_plot(discriminator, data, fit_state, fidelity).get_figure())

        return analysis_results, figures

    def _levels_plot(self, discriminator, data, fit_state, fidelity) -> "matplotlib.figure.Figure":
        """Helper function for plotting IQ plane for different energy levels.

        Args:
            discriminator: the trained discriminator
            data: the training data, one array of samples per prepared state
            fit_state: the labels, one per entry of ``data``
            fidelity: the fidelity of the classification

        Returns:
            The plotted IQ data.
        """
        # create figure labels
        params_dict = {state: {"label": f"$|{state}\\rangle$"} for state in fit_state}

        # Update params_dict to contain any existing series_params values,
        # where they have priority over params_dict.
        params_dict.update(self.plotter.figure_options.series_params)
        self.plotter.set_figure_options(series_params=params_dict)

        # calculate centroids (mean sample of each prepared state)
        centroids = [np.mean(x, axis=0) for x in data]
        for points, centroid, name in zip(data, centroids, fit_state):
            self.plotter.set_series_data(name, points=points, centroid=centroid)

        self.plotter.set_supplementary_data(discriminator=discriminator, fidelity=fidelity)
        return self.plotter.figure()