Source code for qiskit_experiments.curve_analysis.scatter_table
# This code is part of Qiskit.## (C) Copyright IBM 2023.## This code is licensed under the Apache License, Version 2.0. You may# obtain a copy of this license in the LICENSE.txt file in the root directory# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.## Any modifications or derivative works of this code must retain this# copyright notice, and modified files need to carry a notice indicating# that they have been altered from the originals."""Table representation of the x, y data for curve fitting."""from__future__importannotationsimportloggingimportwarningsfromcollections.abcimportIteratorfromtypingimportAnyfromfunctoolsimportreducefromitertoolsimportproductimportnumpyasnpimportpandasaspdfromqiskit.utilsimportdeprecate_funcLOG=logging.getLogger(__name__)
class ScatterTable:
    """A table-like dataset for the intermediate data used for curve fitting.

    Default table columns are defined in the class attribute :attr:`.COLUMNS`.
    This table cannot be expanded with user-provided column names.

    In a standard :class:`.CurveAnalysis` subclass, a ScatterTable instance may be
    stored in the :class:`.ExperimentData` as an artifact.
    Users can retrieve the table data at a later time to rerun a fitting with a homemade
    program or with different fit options, or to visualize the curves in a preferred format.
    This table dataset is designed to seamlessly provide such information
    that an experimentalist may want to reuse for a custom workflow.

    .. note::

        This dataset is not thread safe. Do not use the same instance in multiple threads.

    See the tutorial of :ref:`data_management_with_scatter_table`
    for the role of each table column and how values are typically provided.

    """

    # Column names in storage order. Rows queued by add_row() and dataframes given
    # to from_dataframe() must follow exactly this order.
    COLUMNS = [
        "xval",
        "yval",
        "yerr",
        "series_name",
        "series_id",
        "category",
        "shots",
        "analysis",
    ]

    # Nullable pandas dtype of each column, aligned position-by-position with COLUMNS.
    DTYPES = [
        "Float64",
        "Float64",
        "Float64",
        "string",
        "Int64",
        "string",
        "Int64",
        "string",
    ]

    def __init__(self):
        # Rows queued by add_row(); they are merged into ``_dump`` lazily,
        # the next time the ``dataframe`` property is accessed.
        self._lazy_add_rows = []
        self._dump = pd.DataFrame(columns=self.COLUMNS)

    @classmethod
    def from_dataframe(
        cls,
        data: pd.DataFrame,
    ) -> "ScatterTable":
        """Create new dataset with existing dataframe.

        Args:
            data: Data dataframe object. Its columns must be identical to
                :attr:`.COLUMNS`, including the order.

        Returns:
            A new ScatterTable instance.

        Raises:
            ValueError: When the dataframe columns don't match :attr:`.COLUMNS`.
        """
        if list(data.columns) != cls.COLUMNS:
            raise ValueError("Input dataframe columns don't match with the ScatterTable spec.")
        format_data = cls._format_table(data)
        return cls._create_new_instance(format_data)

    @classmethod
    def _create_new_instance(
        cls,
        data: pd.DataFrame,
    ) -> "ScatterTable":
        # A shortcut for creating instance.
        # This bypasses data formatting and column compatibility check.
        # User who calls this method must guarantee the quality of the input data.
        instance = object.__new__(cls)
        instance._lazy_add_rows = []
        instance._dump = data
        return instance

    @property
    def dataframe(self) -> pd.DataFrame:
        """Dataframe object of data points."""
        if self._lazy_add_rows:
            # Add data when table element is called.
            # Adding rows in loop is extremely slow in pandas.
            tmp_df = pd.DataFrame(self._lazy_add_rows, columns=self.COLUMNS)
            tmp_df = self._format_table(tmp_df)
            if len(self._dump) == 0:
                # Nothing stored yet: adopt the formatted rows directly.
                self._dump = tmp_df
            else:
                self._dump = pd.concat([self._dump, tmp_df], ignore_index=True)
            self._lazy_add_rows.clear()
        return self._dump

    @property
    def x(self) -> np.ndarray:
        """X values."""
        # For backward compatibility with CurveData.x
        return self.dataframe.xval.to_numpy(dtype=float, na_value=np.nan)

    @x.setter
    def x(self, new_values: np.ndarray):
        self.dataframe.loc[:, "xval"] = new_values

    def xvals(
        self,
        series: int | str | None = None,
        category: str | None = None,
        analysis: str | None = None,
        check_unique: bool = True,
    ) -> np.ndarray:
        """Get subset of X values.

        A convenient shortcut for getting X data with filtering.

        Args:
            series: Identifier of the data series, either integer series index or name.
            category: Name of data category.
            analysis: Name of analysis.
            check_unique: Set True to check if multiple series are contained.
                When multiple series are contained, it raises a user warning.

        Returns:
            Numpy array of X values.
        """
        sub_table = self.filter(series, category, analysis)
        if check_unique:
            sub_table._warn_composite_data()
        return sub_table.x

    @property
    def y(self) -> np.ndarray:
        """Y values."""
        # For backward compatibility with CurveData.y
        return self.dataframe.yval.to_numpy(dtype=float, na_value=np.nan)

    @y.setter
    def y(self, new_values: np.ndarray):
        self.dataframe.loc[:, "yval"] = new_values

    def yvals(
        self,
        series: int | str | None = None,
        category: str | None = None,
        analysis: str | None = None,
        check_unique: bool = True,
    ) -> np.ndarray:
        """Get subset of Y values.

        A convenient shortcut for getting Y data with filtering.

        Args:
            series: Identifier of the data series, either integer series index or name.
            category: Name of data category.
            analysis: Name of analysis.
            check_unique: Set True to check if multiple series are contained.
                When multiple series are contained, it raises a user warning.

        Returns:
            Numpy array of Y values.
        """
        sub_table = self.filter(series, category, analysis)
        if check_unique:
            sub_table._warn_composite_data()
        return sub_table.y

    @property
    def y_err(self) -> np.ndarray:
        """Standard deviation of Y values."""
        # For backward compatibility with CurveData.y_err
        return self.dataframe.yerr.to_numpy(dtype=float, na_value=np.nan)

    @y_err.setter
    def y_err(self, new_values: np.ndarray):
        self.dataframe.loc[:, "yerr"] = new_values

    def yerrs(
        self,
        series: int | str | None = None,
        category: str | None = None,
        analysis: str | None = None,
        check_unique: bool = True,
    ) -> np.ndarray:
        """Get subset of standard deviation of Y values.

        A convenient shortcut for getting Y error data with filtering.

        Args:
            series: Identifier of the data series, either integer series index or name.
            category: Name of data category.
            analysis: Name of analysis.
            check_unique: Set True to check if multiple series are contained.
                When multiple series are contained, it raises a user warning.

        Returns:
            Numpy array of Y error values.
        """
        sub_table = self.filter(series, category, analysis)
        if check_unique:
            sub_table._warn_composite_data()
        return sub_table.y_err

    @property
    def series_name(self) -> np.ndarray:
        """Corresponding data name for each data point."""
        return self.dataframe.series_name.to_numpy(dtype=object, na_value=None)

    @series_name.setter
    def series_name(self, new_values: np.ndarray):
        self.dataframe.loc[:, "series_name"] = new_values

    @property
    def series_id(self) -> np.ndarray:
        """Corresponding data UID for each data point."""
        return self.dataframe.series_id.to_numpy(dtype=object, na_value=None)

    @series_id.setter
    def series_id(self, new_values: np.ndarray):
        self.dataframe.loc[:, "series_id"] = new_values

    @property
    def category(self) -> np.ndarray:
        """Array of categories of the data points."""
        return self.dataframe.category.to_numpy(dtype=object, na_value=None)

    @category.setter
    def category(self, new_values: np.ndarray):
        self.dataframe.loc[:, "category"] = new_values

    @property
    def shots(self) -> np.ndarray:
        """Shot number used to acquire each data point."""
        return self.dataframe.shots.to_numpy(dtype=object, na_value=np.nan)

    @shots.setter
    def shots(self, new_values: np.ndarray):
        self.dataframe.loc[:, "shots"] = new_values

    @property
    def analysis(self) -> np.ndarray:
        """Corresponding analysis name for each data point."""
        return self.dataframe.analysis.to_numpy(dtype=object, na_value=None)

    @analysis.setter
    def analysis(self, new_values: np.ndarray):
        self.dataframe.loc[:, "analysis"] = new_values

    def filter(
        self,
        series: int | str | None = None,
        category: str | None = None,
        analysis: str | None = None,
    ) -> ScatterTable:
        """Filter data by series, category, and/or analysis name.

        Conditions that are ``None`` are skipped; the remaining ones are
        applied one after another, so a row must satisfy all of them.

        Args:
            series: Identifier of the data series, either integer series index or name.
                NumPy integer scalars are accepted as a series index as well.
            category: Name of data category.
            analysis: Name of analysis.

        Returns:
            New ScatterTable object with filtered data.

        Raises:
            ValueError: When ``series`` is neither an integer nor a string.
        """
        filt_data = self.dataframe

        if series is not None:
            # np.integer is accepted in addition to int: values obtained by iterating
            # a pandas integer column are NumPy scalars, which are not ``int`` instances.
            if isinstance(series, (int, np.integer)):
                index = filt_data.series_id == series
            elif isinstance(series, str):
                index = filt_data.series_name == series
            else:
                raise ValueError(
                    f"Invalid series identifier {series}. This must be integer or string."
                )
            filt_data = filt_data.loc[index, :]
        if category is not None:
            index = filt_data.category == category
            filt_data = filt_data.loc[index, :]
        if analysis is not None:
            index = filt_data.analysis == analysis
            filt_data = filt_data.loc[index, :]
        return ScatterTable._create_new_instance(filt_data)

    def iter_by_series_id(self) -> Iterator[tuple[int, "ScatterTable"]]:
        """Iterate over subset of data sorted by the data series index.

        Rows whose series index is missing are not part of any subset.

        Yields:
            Tuple of data series index and subset of ScatterTable.
        """
        id_values = self.dataframe.series_id
        for did in id_values.dropna().sort_values().unique():
            yield did, ScatterTable._create_new_instance(self.dataframe.loc[id_values == did, :])

    def iter_groups(
        self,
        *group_by: str,
    ) -> Iterator[tuple[tuple[Any, ...], "ScatterTable"]]:
        """Iterate over the subset sorted by multiple column values.

        Args:
            group_by: Names of columns to group by. At least one name is required.

        Yields:
            Tuple of values for the grouped columns and the corresponding subset
            of the scatter table.

        Raises:
            ValueError: When no column name is given, or when a given column
                doesn't exist in the table.
        """
        if not group_by:
            # Without this guard the mask reduction below would fail with an
            # unhelpful TypeError about reducing an empty iterable.
            raise ValueError("At least one column name must be specified to group by.")

        out = self.dataframe
        try:
            # DataFrame.get returns None for an unknown column, hence AttributeError.
            values_iter = product(*[out.get(col).unique() for col in group_by])
        except AttributeError as ex:
            raise ValueError(
                f"Specified columns don't exist: {group_by} is not a subset of {self.COLUMNS}."
            ) from ex

        for values in sorted(values_iter):
            each_matched = [out.get(c) == v for c, v in zip(group_by, values)]
            all_matched = reduce(lambda x, y: x & y, each_matched)
            if not any(all_matched):
                # This combination of column values has no row in the table.
                continue
            yield values, ScatterTable._create_new_instance(out.loc[all_matched, :])

    def add_row(
        self,
        xval: float | pd.NA = pd.NA,
        yval: float | pd.NA = pd.NA,
        yerr: float | pd.NA = pd.NA,
        series_name: str | pd.NA = pd.NA,
        series_id: int | pd.NA = pd.NA,
        category: str | pd.NA = pd.NA,
        shots: float | pd.NA = pd.NA,
        analysis: str | pd.NA = pd.NA,
    ):
        """Add new data point to the table.

        The row is only queued here; it is formatted and merged into the
        dataframe the next time the table data is accessed.
        Omitted values are stored as missing (``pd.NA``).

        Args:
            xval: X value.
            yval: Y value.
            yerr: Standard deviation of y value.
            series_name: Name of this data series if available.
            series_id: Index of this data series if available.
            category: Data category if available.
            shots: Shot number used to acquire this data point.
            analysis: Analysis name if available.
        """
        # The element order must follow COLUMNS.
        self._lazy_add_rows.append(
            [xval, yval, yerr, series_name, series_id, category, shots, analysis]
        )

    @classmethod
    def _format_table(cls, data: pd.DataFrame) -> pd.DataFrame:
        # Normalize missing values to pd.NA, cast every column to the dtype
        # declared in DTYPES, and renumber the rows from zero.
        return (
            data.replace(np.nan, pd.NA)
            .astype(dict(zip(cls.COLUMNS, cls.DTYPES)))
            .reset_index(drop=True)
        )

    def _warn_composite_data(self):
        # Emit a UserWarning for each of series_name, category and analysis
        # that holds more than one distinct value in this table.
        if len(self.dataframe.series_name.unique()) > 1:
            warnings.warn(
                "Table data contains multiple data series. "
                "You may want to filter the data by a specific series_id or series_name.",
                UserWarning,
            )
        if len(self.dataframe.category.unique()) > 1:
            warnings.warn(
                "Table data contains multiple categories. "
                "You may want to filter the data by a specific category name.",
                UserWarning,
            )
        if len(self.dataframe.analysis.unique()) > 1:
            warnings.warn(
                "Table data contains multiple datasets from different component analyses. "
                "You may want to filter the data by a specific analysis name.",
                UserWarning,
            )

    @property
    @deprecate_func(
        since="0.9",
        additional_msg="Curve data uses dataframe representation. Call .series_id instead.",
        package_name="qiskit-experiments",
        is_property=True,
    )
    def data_allocation(self) -> np.ndarray:
        """Index of corresponding fit model."""
        return self.series_id

    @property
    @deprecate_func(
        since="0.9",
        additional_msg="No alternative is provided. Use .series_name with set operation.",
        package_name="qiskit-experiments",
        is_property=True,
    )
    def labels(self) -> list[str]:
        """List of model names."""
        # Order sensitive: names are returned in ascending order of their series_id.
        name_id_tups = self.dataframe.groupby(["series_name", "series_id"]).groups.keys()
        return [k[0] for k in sorted(name_id_tups, key=lambda k: k[1])]

    @deprecate_func(
        since="0.9",
        additional_msg="Use filter method instead.",
        package_name="qiskit-experiments",
    )
    def get_subset_of(self, index: str | int) -> "ScatterTable":
        """Filter data by series name or index.

        Args:
            index: Series index or name.

        Returns:
            A subset of data corresponding to a particular series.
        """
        return self.filter(series=index)

    def __len__(self) -> int:
        """Return the number of data points stored in the table."""
        return len(self.dataframe)

    def __eq__(self, other):
        # Defer to the other operand instead of failing with AttributeError
        # when compared against something that is not a ScatterTable.
        if not isinstance(other, ScatterTable):
            return NotImplemented
        return self.dataframe.equals(other.dataframe)

    def __json_encode__(self) -> dict[str, Any]:
        # Rows are stored as a mapping of row index to a {column: value} mapping.
        return {
            "class": "ScatterTable",
            "data": self.dataframe.to_dict(orient="index"),
        }

    @classmethod
    def __json_decode__(cls, value: dict[str, Any]) -> "ScatterTable":
        if value.get("class", None) != "ScatterTable":
            raise ValueError("JSON decoded value for ScatterTable is not valid class type.")
        data = value.get("data", {})
        if data:
            tmp_df = pd.DataFrame.from_dict(data, orient="index")
        else:
            # An empty table is encoded as an empty mapping, which carries no column
            # names. Rebuild the columns from the spec so that the column check in
            # from_dataframe() passes and an empty table can be round-tripped.
            tmp_df = pd.DataFrame(columns=cls.COLUMNS)
        return cls.from_dataframe(tmp_df)