# This code is part of Qiskit.
# (C) Copyright IBM 2022.
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Multi state discrimination analysis."""

from typing import List, Tuple, TYPE_CHECKING

import matplotlib
import numpy as np

from qiskit.providers.options import Options
from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, ExperimentData
from qiskit_experiments.data_processing import SkQDA
from qiskit_experiments.visualization import BasePlotter, IQPlotter, MplDrawer, PlotStyle
from qiskit_experiments.framework.package_deps import HAS_SKLEARN

if TYPE_CHECKING:
    from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

class MultiStateDiscriminationAnalysis(BaseAnalysis):
    r"""This class fits a multi-state discriminator to the data.

    The class will report the configuration of the discriminator in the analysis result as well as
    the fidelity of the discrimination reported as

    .. math::

        F = 1 - \frac{1}{d}\sum_{i\neq j}P(i|j)

    Here, :math:`d` is the number of levels that were discriminated while :math:`P(i|j)` is the
    probability of measuring outcome :math:`i` given that state :math:`j` was prepared.

    .. note::
        This class requires that scikit-learn is installed.
    """

    @classmethod
    @HAS_SKLEARN.require_in_call
    def _default_options(cls) -> Options:
        """Return default analysis options.

        Analysis Options:
            plot (bool): Set ``True`` to create figure for fit result.
            plotter (BasePlotter): A plotter instance to visualize the analysis result.
            ax (AxesSubplot): Optional. A matplotlib axis object in which to draw.
            discriminator (BaseDiscriminator): The sklearn discriminator to classify the data.
                The default is a quadratic discriminant analysis.
        """
        options = super()._default_options()
        options.plotter = IQPlotter(MplDrawer())
        options.plotter.set_options(
            discriminator_max_resolution=64,
            style=PlotStyle(figsize=(6, 4), legend_loc=None),
        )
        options.plot = True
        options.ax = None

        # scikit-learn is an optional dependency, so it is imported here rather
        # than at module level; the ``HAS_SKLEARN`` decorator guards this call.
        # pylint: disable=import-outside-toplevel
        from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

        options.discriminator = SkQDA(QuadraticDiscriminantAnalysis())
        return options

    @property
    def plotter(self) -> BasePlotter:
        """A short-cut to the IQ plotter instance."""
        return self._options.plotter

    def _run_analysis(
        self,
        experiment_data: ExperimentData,
    ) -> Tuple[List[AnalysisResultData], List["matplotlib.figure.Figure"]]:
        """Train a discriminator based on the experiment data.

        Args:
            experiment_data: the data obtained from the experiment

        Returns:
            The configuration of the trained discriminator and the IQ plot and the fidelity
            of the discrimination.
        """
        # Fetch the circuit results once instead of on every access.
        results = experiment_data.data()

        # number of states and shots; the shot count of the first entry is
        # used for every prepared state.
        n_states = len(results)
        num_shots = len(results[0]["memory"])

        # Process the data and get labels
        data, fit_state = [], []
        for result in results:
            memory = result["memory"]
            # Keep only the first entry of each shot.
            # NOTE(review): presumably the IQ point of the single measured
            # qubit -- confirm against the experiment that produces the data.
            data.append(np.array([memory[j][0] for j in range(num_shots)]))
            fit_state.append(result["metadata"]["label"])

        # Train a discriminator on the processed data. The label array repeats
        # each state's label once per shot, in the same order as the
        # concatenated data.
        discriminator = self.options.discriminator
        discriminator.fit(
            np.concatenate(data),
            np.asarray([[label] * num_shots for label in fit_state]).flatten().transpose(),
        )

        # Calculate fidelity. First we need to calculate P(i|j):= prob. measuring outcome i given
        # state j was prepared
        predicted_data = [discriminator.predict(state_data) for state_data in data]

        # count per prepared state the number of measured states of each kind and calculate the
        # probability of measuring the wrong state
        prob_wrong = 0
        for i in range(n_states):
            counts = [0] * n_states
            # Each predicted label is used directly as an index into ``counts``.
            for point in predicted_data[i]:
                counts[point] += 1
            for j in range(n_states):
                if j != i:
                    prob_wrong += counts[j] / num_shots

        # calculate the fidelity
        fidelity = 1 - (1 / n_states) * prob_wrong

        # Create analysis results from the discriminator configuration
        analysis_results = [
            AnalysisResultData(name="discriminator_config", value=discriminator.config()),
            AnalysisResultData(name="fidelity", value=fidelity),
        ]

        figures = []
        if self.options.plot:
            figures.append(
                self._levels_plot(discriminator, data, fit_state, fidelity).get_figure()
            )

        return analysis_results, figures

    def _levels_plot(
        self, discriminator, data, fit_state, fidelity
    ) -> "matplotlib.figure.Figure":
        """Helper function for plotting IQ plane for different energy levels.

        Args:
            discriminator: the trained discriminator
            data: the training data
            fit_state: the labels
            fidelity: the fidelity of the classification

        Returns:
            The plotted IQ data.
        """
        # create figure labels
        params_dict = {}
        for state in fit_state:
            params_dict[state] = {"label": f"$|{state}\\rangle$"}

        # Update params_dict to contain any existing series_params values,
        # where they have priority over params_dict.
        params_dict.update(self.plotter.figure_options.series_params)
        self.plotter.set_figure_options(series_params=params_dict)

        # calculate centroids
        centroids = [np.mean(x, axis=0) for x in data]

        for p, c, n in zip(data, centroids, fit_state):
            self.plotter.set_series_data(n, points=p, centroid=c)
        self.plotter.set_supplementary_data(discriminator=discriminator, fidelity=fidelity)

        return self.plotter.figure()