Source code for qiskit_experiments.library.characterization.analysis.multi_state_discrimination_analysis
# This code is part of Qiskit.## (C) Copyright IBM 2022.## This code is licensed under the Apache License, Version 2.0. You may# obtain a copy of this license in the LICENSE.txt file in the root directory# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.## Any modifications or derivative works of this code must retain this# copyright notice, and modified files need to carry a notice indicating# that they have been altered from the originals."""Multi state discrimination analysis."""fromtypingimportList,Tuple,TYPE_CHECKINGimportmatplotlibimportnumpyasnpfromqiskit.providers.optionsimportOptionsfromqiskit_experiments.frameworkimportBaseAnalysis,AnalysisResultData,ExperimentDatafromqiskit_experiments.data_processingimportSkQDAfromqiskit_experiments.visualizationimportBasePlotter,IQPlotter,MplDrawer,PlotStylefromqiskit_experiments.framework.package_depsimportHAS_SKLEARNifTYPE_CHECKING:fromsklearn.discriminant_analysisimportQuadraticDiscriminantAnalysis
[docs]classMultiStateDiscriminationAnalysis(BaseAnalysis):r"""This class fits a multi-state discriminator to the data. The class will report the configuration of the discriminator in the analysis result as well as the fidelity of the discrimination reported as .. math:: F = 1 - \frac{1}{d}\sum{i\neq j}P(i|j) Here, :math:`d` is the number of levels that were discriminated while :math:`P(i|j)` is the probability of measuring outcome :math:`i` given that state :math:`j` was prepared. .. note:: This class requires that scikit-learn is installed. """@classmethod@HAS_SKLEARN.require_in_calldef_default_options(cls)->Options:"""Return default analysis options. Analysis Options: plot (bool): Set ``True`` to create figure for fit result. plotter (BasePlotter): A plotter instance to visualize the analysis result. ax (AxesSubplot): Optional. A matplotlib axis object in which to draw. discriminator (BaseDiscriminator): The sklearn discriminator to classify the data. The default is a quadratic discriminant analysis. """options=super()._default_options()options.plotter=IQPlotter(MplDrawer())options.plotter.set_options(discriminator_max_resolution=64,style=PlotStyle(figsize=(6,4),legend_loc=None),)options.plot=Trueoptions.ax=Nonefromsklearn.discriminant_analysisimportQuadraticDiscriminantAnalysisoptions.discriminator=SkQDA(QuadraticDiscriminantAnalysis())returnoptions@propertydefplotter(self)->BasePlotter:"""A short-cut to the IQ plotter instance."""returnself._options.plotterdef_run_analysis(self,experiment_data:ExperimentData,)->Tuple[List[AnalysisResultData],List["matplotlib.figure.Figure"]]:"""Train a discriminator based on the experiment data. Args: experiment_data: the data obtained from the experiment Returns: The configuration of the trained discriminator and the IQ plot and the fidelity of the discrimination. """# number of states and shotsn_states=len(experiment_data.data())num_shots=len(experiment_data.data()[0]["memory"])# Process the data and get labelsdata,fit_state=[],[]foriinrange(n_states):state_data=[]forjinrange(num_shots):state_data.append(experiment_data.data()[i]["memory"][j][0])data.append(np.array(state_data))fit_state.append(experiment_data.data()[i]["metadata"]["label"])# Train a discriminator on the processed datadiscriminator=self.options.discriminatordiscriminator.fit(np.concatenate(data),np.asarray([[label]*num_shotsforlabelinfit_state]).flatten().transpose(),)# Calculate fidelity. First we need to calculate P(i|j):= prob. measuring outcome i given# state j was preparedpredicted_data=[discriminator.predict(state_data)forstate_dataindata]# count per prepared state the number of measured states of each kind and calculate the# probability of measuring the wrong stateprob_wrong=0foriinrange(n_states):counts=[0]*n_statesforpointinpredicted_data[i]:counts[point]+=1forjinrange(n_states):ifj!=i:prob_wrong+=counts[j]/num_shots# calculate the fidelityfidelity=1-(1/n_states)*prob_wrong# Crate analysis results from the discriminator configurationanalysis_results=[AnalysisResultData(name="discriminator_config",value=discriminator.config()),AnalysisResultData(name="fidelity",value=fidelity),]figures=[]ifself.options.plot:figures.append(self._levels_plot(discriminator,data,fit_state,fidelity).get_figure())returnanalysis_results,figuresdef_levels_plot(self,discriminator,data,fit_state,fidelity)->matplotlib.figure.Figure:"""Helper function for plotting IQ plane for different energy levels. Args: discriminator: the trained discriminator data: the training data fit_state: the labels fidelity: the fidelity of the classification Returns: The plotted IQ data. """# create figure labelsparams_dict={}forstateinfit_state:params_dict[state]={"label":f"$|{state}\\rangle$"}# Update params_dict to contain any existing series_params values,# where they have priority over params_dict.params_dict.update(self.plotter.figure_options.series_params)self.plotter.set_figure_options(series_params=params_dict)# calculate centroidscentroids=[np.mean(x,axis=0)forxindata]forp,c,ninzip(data,centroids,fit_state):self.plotter.set_series_data(n,points=p,centroid=c)self.plotter.set_supplementary_data(discriminator=discriminator,fidelity=fidelity)returnself.plotter.figure()