Source code for qiskit_experiments.data_processing.sklearn_discriminators
# This code is part of Qiskit.## (C) Copyright IBM 2022.## This code is licensed under the Apache License, Version 2.0. You may# obtain a copy of this license in the LICENSE.txt file in the root directory# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.## Any modifications or derivative works of this code must retain this# copyright notice, and modified files need to carry a notice indicating# that they have been altered from the originals."""Discriminators that wrap SKLearn."""fromtypingimportAny,List,Dict,TYPE_CHECKINGfromqiskit_experiments.data_processing.discriminatorimportBaseDiscriminatorfromqiskit_experiments.framework.package_depsimportHAS_SKLEARNifTYPE_CHECKING:fromsklearn.discriminant_analysisimport(LinearDiscriminantAnalysis,QuadraticDiscriminantAnalysis,)
class SkLDA(BaseDiscriminator):
    """A wrapper for the scikit-learn linear discriminant analysis.

    .. note::
        This class requires that scikit-learn is installed.
    """

    def __init__(self, lda: "LinearDiscriminantAnalysis"):
        """
        Args:
            lda: The sklearn linear discriminant analysis. This may be a trained or an
                untrained discriminator.

        .. note::
            The constructor only stores ``lda``; it does not import scikit-learn itself.
            Only :meth:`from_config` imports it, guarded by ``HAS_SKLEARN.require_in_call``.
        """
        self._lda = lda

        # Names of the sklearn attributes that ``config`` reads from the wrapped
        # object and ``from_config`` writes back. Attributes that are missing on
        # the object are serialized as ``None`` and skipped on restore.
        self.attributes = [
            "coef_",
            "intercept_",
            "covariance_",
            "explained_variance_ratio_",
            "means_",
            "priors_",
            "scalings_",
            "xbar_",
            "classes_",
            "n_features_in_",
            "feature_names_in_",
        ]

    @property
    def discriminator(self) -> Any:
        """Return the SKLearn object."""
        return self._lda

    def is_trained(self) -> bool:
        """Return True if the discriminator has been trained on data.

        The wrapped object is considered trained when it exposes a ``classes_``
        attribute that is not ``None``.
        """
        return getattr(self._lda, "classes_", None) is not None

    def predict(self, data: List):
        """Wrap the predict method of the LDA.

        Args:
            data: The data to classify; forwarded unchanged to the wrapped object.

        Returns:
            Whatever the wrapped object's ``predict`` returns.
        """
        return self._lda.predict(data)

    def fit(self, data: List, labels: List):
        """Fit the LDA.

        Args:
            data: The independent data.
            labels: The labels corresponding to data.
        """
        self._lda.fit(data, labels)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of the LDA.

        Returns:
            A dictionary with two keys: ``"params"`` holds the result of the wrapped
            object's ``get_params()`` and ``"attributes"`` maps every name in
            ``self.attributes`` to its value on the wrapped object (``None`` if absent).
        """
        attr_conf = {attr: getattr(self._lda, attr, None) for attr in self.attributes}
        return {"params": self._lda.get_params(), "attributes": attr_conf}

    @classmethod
    @HAS_SKLEARN.require_in_call
    def from_config(cls, config: Dict[str, Any]) -> "SkLDA":
        """Deserialize from an object.

        Args:
            config: A dictionary with the keys ``"params"`` and ``"attributes"``,
                as produced by :meth:`config`.

        Returns:
            An instance of ``cls`` wrapping a new ``LinearDiscriminantAnalysis`` that
            has the given parameters set and every non-``None`` attribute restored.
        """
        # Imported lazily so that the module can be loaded without scikit-learn.
        from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

        lda = LinearDiscriminantAnalysis()
        lda.set_params(**config["params"])

        # ``None`` marks an attribute that was absent when the config was created
        # (e.g. an untrained discriminator), so it is left unset here as well.
        for name, value in config["attributes"].items():
            if value is not None:
                setattr(lda, name, value)

        # Use ``cls`` rather than the hard-coded class so subclasses round-trip
        # to their own type.
        return cls(lda)
class SkQDA(BaseDiscriminator):
    """A wrapper for the scikit-learn quadratic discriminant analysis.

    .. note::
        This class requires that scikit-learn is installed.
    """

    def __init__(self, qda: "QuadraticDiscriminantAnalysis"):
        """
        Args:
            qda: The sklearn quadratic discriminant analysis. This may be a trained or an
                untrained discriminator.

        .. note::
            The constructor only stores ``qda``; it does not import scikit-learn itself.
            Only :meth:`from_config` imports it, guarded by ``HAS_SKLEARN.require_in_call``.
        """
        self._qda = qda

        # Names of the sklearn attributes that ``config`` reads from the wrapped
        # object and ``from_config`` writes back. Attributes that are missing on
        # the object are serialized as ``None`` and skipped on restore.
        self.attributes = [
            "coef_",
            "intercept_",
            "covariance_",
            "explained_variance_ratio_",
            "means_",
            "priors_",
            "scalings_",
            "xbar_",
            "classes_",
            "n_features_in_",
            "feature_names_in_",
            "rotations_",
        ]

    @property
    def discriminator(self) -> Any:
        """Return the SKLearn object."""
        return self._qda

    def is_trained(self) -> bool:
        """Return True if the discriminator has been trained on data.

        The wrapped object is considered trained when it exposes a ``classes_``
        attribute that is not ``None``.
        """
        return getattr(self._qda, "classes_", None) is not None

    def predict(self, data: List):
        """Wrap the predict method of the QDA.

        Args:
            data: The data to classify; forwarded unchanged to the wrapped object.

        Returns:
            Whatever the wrapped object's ``predict`` returns.
        """
        return self._qda.predict(data)

    def fit(self, data: List, labels: List):
        """Fit the QDA.

        Args:
            data: The independent data.
            labels: The labels corresponding to data.
        """
        self._qda.fit(data, labels)

    def config(self) -> Dict[str, Any]:
        """Return the configuration of the QDA.

        Returns:
            A dictionary with two keys: ``"params"`` holds the result of the wrapped
            object's ``get_params()`` and ``"attributes"`` maps every name in
            ``self.attributes`` to its value on the wrapped object (``None`` if absent).
        """
        attr_conf = {attr: getattr(self._qda, attr, None) for attr in self.attributes}
        return {"params": self._qda.get_params(), "attributes": attr_conf}

    @classmethod
    @HAS_SKLEARN.require_in_call
    def from_config(cls, config: Dict[str, Any]) -> "SkQDA":
        """Deserialize from an object.

        Args:
            config: A dictionary with the keys ``"params"`` and ``"attributes"``,
                as produced by :meth:`config`.

        Returns:
            An instance of ``cls`` wrapping a new ``QuadraticDiscriminantAnalysis`` that
            has the given parameters set and every non-``None`` attribute restored.
        """
        # Imported lazily so that the module can be loaded without scikit-learn.
        from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

        qda = QuadraticDiscriminantAnalysis()
        qda.set_params(**config["params"])

        # ``None`` marks an attribute that was absent when the config was created
        # (e.g. an untrained discriminator), so it is left unset here as well.
        for name, value in config["attributes"].items():
            if value is not None:
                setattr(qda, name, value)

        # Use ``cls`` rather than the hard-coded class so subclasses round-trip
        # to their own type.
        return cls(qda)