# This code is part of Qiskit.
#
# (C) Copyright IBM 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=method-hidden,too-many-return-statements,c-extension-no-member
"""Experiment serialization methods."""
import base64
import importlib
import inspect
import io
import json
import math
import traceback
import warnings
import zlib
from collections.abc import Callable
from datetime import datetime
from functools import lru_cache
from importlib.metadata import entry_points
from types import FunctionType, MethodType
from typing import Any
import lmfit
import numpy as np
import scipy.sparse as sps
import uncertainties
from qiskit import qpy, quantum_info
from qiskit.circuit import ParameterExpression, QuantumCircuit, Instruction
from qiskit.exceptions import QiskitError
from qiskit.providers import Backend
from qiskit_experiments.version import __version__
# Testing hook: when set to True, ``_serialize_object`` raises a ValueError
# for any object that falls through to default serialization, i.e. an object
# that does not define ``__json_encode__``.
_strict_serialization = False  # pylint: disable=invalid-name
# Set of top-level Python packages that types may be imported from during
# deserialization. It stays None until ``_load_allowed_packages`` fills it
# from the ``qiskit_experiments.deserialization_packages`` entry points the
# first time it is needed.
_allowed_packages = None  # pylint: disable=invalid-name
@lru_cache
def get_module_version(mod_name: str) -> str:
"""Return the __version__ of a module if defined.
Args:
mod_name: The module to extract the version from.
Returns:
The module version. If the module is `__main__` the
qiskit-experiments version will be returned.
"""
if "." in mod_name:
return get_module_version(mod_name.split(".", maxsplit=1)[0])
# Return qiskit experiments version for classes in this
# module or defined in main
if mod_name in ["qiskit_experiments", "__main__"]:
return __version__
# For other classes attempt to use their module version
# if it is defined
try:
mod = importlib.import_module(mod_name)
return getattr(mod, "__version__", None)
except Exception: # pylint: disable=broad-except
return None
def get_object_version(obj: Any) -> str | None:
    """Return the module version of an object, class, or function.

    Note that if the object is defined in ``__main__`` instead
    of a module the current qiskit-experiments version will be used.

    This function is intentionally not memoized on ``obj``: caching on
    arbitrary instances fails for unhashable objects and keeps every
    instance alive for the lifetime of the cache. The module lookup that
    does the actual work, ``get_module_version``, is already cached.

    Args:
        obj: A type, function, or instance to extract the module version for.

    Returns:
        The module version for the object, or ``None`` if it cannot be
        determined. If the object is defined in ``__main__`` the
        qiskit-experiments version will be returned.
    """
    # Instances are resolved through their class.
    target = obj if istype(obj) else type(obj)
    base_mod = target.__module__.split(".", maxsplit=1)[0]
    return get_module_version(base_mod)
def _show_warning(
msg: str | None = None,
traceback_msg: str | None = None,
mod_name: str | None = None,
save_version: str | None = None,
load_version: str | None = None,
):
"""Show warning for partial deserialization"""
warning_msg = f"{msg} " if msg else ""
if mod_name != "__main__":
mod_name = mod_name.split(".", maxsplit=1)[0]
if save_version != load_version:
warning_msg += (
f"\nNOTE: The current version of module '{mod_name}' ({load_version})"
f" differs from the version used for serialization ({save_version})."
)
if traceback_msg:
warning_msg += f"\nThe following exception was raised:\n{traceback_msg}"
warnings.warn(warning_msg, stacklevel=3)
def _serialize_bytes(data: bytes, compress: bool = True) -> dict[str, Any]:
"""Serialize binary data.
Args:
data: Data to be serialized.
compress: Whether to compress the serialized data.
Returns:
The serialized object value as a dict.
"""
if compress:
data = zlib.compress(data)
value = {
"encoded": base64.standard_b64encode(data).decode("utf-8"),
"compressed": compress,
}
return {"__type__": "b64encoded", "__value__": value}
def _deserialize_bytes(value: dict) -> str:
"""Deserialize binary encoded data.
Args:
value: value to be deserialized.
Returns:
Deserialized string representation.
Raises:
ValueError: If encoded data cannot be deserialized.
"""
try:
encoded = value["encoded"]
compressed = value["compressed"]
decoded = base64.standard_b64decode(encoded)
if compressed:
decoded = zlib.decompress(decoded)
return decoded
except Exception as ex: # pylint: disable=broad-except
raise ValueError("Could not deserialize binary encoded data.") from ex
def _serialize_and_encode(
    data: Any, serializer: Callable, compress: bool = True, **kwargs: Any
) -> dict[str, Any]:
    """Serialize the input data and return it as an encoded-bytes dict.

    Args:
        data: Data to be serialized.
        serializer: Function used to serialize data. It is called as
            ``serializer(buffer, data, **kwargs)`` and must write to the
            binary buffer.
        compress: Whether to compress the serialized data.
        kwargs: Keyword arguments to pass to the serializer.

    Returns:
        The serialized value dict produced by ``_serialize_bytes`` for the
        bytes written by ``serializer``.
    """
    with io.BytesIO() as buff:
        serializer(buff, data, **kwargs)
        # getvalue() returns the whole buffer regardless of stream position.
        serialized_data = buff.getvalue()
    return _serialize_bytes(serialized_data, compress=compress)
def _decode_and_deserialize(value: dict, deserializer: Callable, name: str | None = None) -> Any:
"""Decode and deserialize input data.
Args:
value: The binary encoded serialized data value.
deserializer: Function used to deserialize data.
name: Object type name for warning message if deserialization fails.
Returns:
Deserialized data.
Raises:
ValueError: If deserialization fails.
"""
try:
with io.BytesIO() as buff:
buff.write(value)
buff.seek(0)
orig = deserializer(buff)
return orig
except Exception as ex: # pylint: disable=broad-except
raise ValueError(f"Could not deserialize <{name}> data.") from ex
def _serialize_safe_float(obj: Any):
"""Recursively serialize basic types safely handing inf and NaN"""
if isinstance(obj, float):
if math.isfinite(obj):
return obj
else:
value = obj
if math.isnan(obj):
value = "NaN"
elif obj == math.inf:
value = "Infinity"
elif obj == -math.inf:
value = "-Infinity"
return {"__type__": "safe_float", "__value__": value}
elif isinstance(obj, (list, tuple)):
return [_serialize_safe_float(i) for i in obj]
elif isinstance(obj, dict):
return {key: _serialize_safe_float(val) for key, val in obj.items()}
elif isinstance(obj, complex):
return {"__type__": "complex", "__value__": _serialize_safe_float([obj.real, obj.imag])}
return obj
def istype(obj: Any) -> bool:
    """Tell whether ``obj`` is a class, a Python function, or a bound method."""
    predicates = (inspect.isclass, inspect.isfunction, inspect.ismethod)
    return any(predicate(obj) for predicate in predicates)
def _serialize_type(type_name: type | FunctionType | MethodType):
    """Describe a type, function, or method by name, module and version.

    Returns:
        A ``{"__type__": "type", "__value__": ...}`` dict holding the
        qualified name, the defining module, and that module's version as
        reported by ``get_module_version``.
    """
    module_name = type_name.__module__
    return {
        "__type__": "type",
        "__value__": {
            "name": type_name.__qualname__,
            "module": module_name,
            "version": get_module_version(module_name),
        },
    }
def _load_allowed_packages():
    """Fill the module-level allow-list of importable packages on first use.

    The modules named by the
    ``qiskit_experiments.deserialization_packages`` entry points are stored
    as a frozenset in ``_allowed_packages``. Later calls are no-ops.
    """
    global _allowed_packages  # pylint: disable=global-statement
    if _allowed_packages is not None:
        return
    registered = entry_points(group="qiskit_experiments.deserialization_packages")
    _allowed_packages = frozenset([entry.module for entry in registered])
def _check_quantum_info_class(cls) -> bool:
"""Check cls is a quantum_info type to be deserialized"""
mod = getattr(cls, "__module__", "")
settings = getattr(cls, "settings", None)
# Class comes from qiskit.quantum_info and has a settings property
return (
mod.startswith("qiskit.quantum_info")
and hasattr(quantum_info, cls.__name__)
and isinstance(settings, property)
)
def _load_quantum_info_type(name: str, module: str) -> Any:
    """Attempt to load a type from the qiskit quantum_info package.

    Args:
        name: Attribute name to look up on ``qiskit.quantum_info``.
        module: Module name recorded in the serialized data. It is only
            used in error messages.

    Returns:
        The class named ``name`` from ``qiskit.quantum_info``.

    Raises:
        QiskitError: If the name is not found, the object has no
            ``__module__``, it was defined outside ``qiskit.quantum_info``,
            or it has no ``settings`` property.
    """
    cls = getattr(quantum_info, name, None)
    if cls is None:
        raise QiskitError(f"Could not load type {name} from module {module}!")
    if not hasattr(cls, "__module__"):
        raise QiskitError(
            f"'{name}' specified to load from {module} does not appear to be a class!"
        )
    if not cls.__module__.startswith("qiskit.quantum_info"):
        # Both halves must be f-strings so the real origin module is shown.
        raise QiskitError(
            f"Could not load type {name} from module {module}. It appears to come "
            f"from {cls.__module__} instead."
        )
    settings = getattr(cls, "settings", None)
    if settings is None or not isinstance(settings, property):
        raise QiskitError(
            f'Class {name} from qiskit.quantum_info does not have a "settings" property. '
            f'Only class from qiskit.quantum_info with a "settings" property can be loaded.'
        )
    return cls
def _deserialize_ufloat(value: dict[str, Any]) -> uncertainties.UFloat:
    """Rebuild a ufloat from a serialized object value.

    Args:
        value: Mapping whose ``"settings"`` entry holds the nominal
            ``"value"`` and, optionally, the ``"std_dev"``.

    Returns:
        The result of ``uncertainties.ufloat`` for those settings.

    Raises:
        QiskitError: If the settings have no ``"value"`` entry.
    """
    settings = value.get("settings", {})
    if "value" not in settings:
        raise QiskitError(f"Bad ufloat settings: {settings}")
    return uncertainties.ufloat(settings["value"], settings.get("std_dev"))
def _deserialize_type(value: dict):
    """Deserialize a Python type.

    Args:
        value: Serialized type description with ``"name"`` and ``"module"``
            entries and an optional ``"version"``.

    Returns:
        The loaded class (or ``uncertainties.ufloat`` for legacy ufloat
        data). If loading fails, a warning is shown and ``value`` is
        returned unchanged as the partially deserialized result.

    Raises:
        QiskitError: If ``name`` refers to a class member (contains a dot).
        KeyError: If ``value`` has no ``"name"`` or ``"module"`` entry.
    """
    # Resolve both keys before the try block so they are always bound when
    # the warning below is built.
    name = value["name"]
    if "." in name:
        raise QiskitError(f"Deserializing class members is no longer supported: {value}")
    mod = value["module"]

    # Set once it is known that querying the installed version cannot import
    # a package the allow-list has not approved.
    version_lookup_ok = False
    try:
        # These two conditionals handle previously serialized data before
        # dedicated ufloat and qiskit.quantum_info serializers were added.
        #
        # Perhaps they can be removed in the future
        if mod == "uncertainties.core" and name == "Variable":
            return uncertainties.ufloat
        if mod.startswith("qiskit.quantum_info"):
            # The top-level package is qiskit, which this module already imports.
            version_lookup_ok = True
            return _load_quantum_info_type(name, mod)

        _load_allowed_packages()
        package = mod.partition(".")[0]
        if package not in _allowed_packages:
            raise QiskitError(
                f"Import of {package} denied. It must be registered with the "
                "'qiskit_experiments.deserialization_packages' entry point to "
                "allow loading objects from it. See the documentation for "
                "'qiskit_experiments.framework.ExperimentEncoder'."
            )
        version_lookup_ok = True

        scope = importlib.import_module(mod)
        if not hasattr(scope, name):
            raise QiskitError(f"Requested object '{name}' not found in '{mod}'!")
        obj = getattr(scope, name)
        if not inspect.isclass(obj):
            raise QiskitError(f"Requested object '{name}' of '{mod}' is not a class!")
        # Reject names that are merely re-exported from another package.
        if obj.__module__.partition(".")[0] != package:
            raise QiskitError(
                f"Object '{name}' of '{mod}' appears to come from {obj.__module__} instead!"
            )
        return obj
    except Exception as ex:  # pylint: disable=broad-except
        traceback_msg = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))

    # Show warning
    warning_msg = f"Cannot deserialize {name}. The type could not be found in module {mod}."
    save_version = value.get("version", None)
    if version_lookup_ok:
        load_version = get_module_version(mod)
    else:
        # get_module_version imports the package to read its version, which
        # would defeat the allow-list check that just failed. Reusing the
        # saved version skips the import and the version-mismatch note.
        load_version = save_version
    _show_warning(
        warning_msg,
        traceback_msg=traceback_msg,
        mod_name=mod,
        save_version=save_version,
        load_version=load_version,
    )

    # Return partially deserialized value
    return value
def _serialize_object(obj: Any) -> dict:
    """Serialize a class instance through its ``__json_encode__`` method.

    Objects without ``__json_encode__`` are serialized with empty settings,
    unless strict serialization is enabled.

    Args:
        obj: The object to be serialized.

    Returns:
        A ``{"__type__": "object", "__value__": ...}`` dict holding the
        serialized class, the settings, and the module version.

    Raises:
        ValueError: If ``_strict_serialization`` is set and ``obj`` does
            not define ``__json_encode__``.
    """
    has_json_encode = hasattr(obj, "__json_encode__")
    settings = _serialize_safe_float(obj.__json_encode__() if has_json_encode else {})

    obj_class = type(obj)
    value = {
        "class": _serialize_type(obj_class),
        "settings": settings,
        "version": get_object_version(obj_class),
    }

    # We do not expect to use _serialize_object except for cases where
    # __json_encode__ is defined
    if _strict_serialization and not has_json_encode:
        raise ValueError(f"Unexpected default serialization for {value}")

    return {"__type__": "object", "__value__": value}
def _deserialize_object(value: dict) -> Any:
    """Deserialize class instance saved as settings.

    Args:
        value: Serialized object value with a ``"class"`` entry (already
            resolved by the decoder where possible), a ``"settings"`` entry
            and an optional ``"version"``.

    Returns:
        The instance built by ``cls.__json_decode__(settings)``. If the
        class could not be resolved or decoding fails, ``value`` is
        returned unchanged as the partially deserialized result.
    """
    cls = value.get("class", {})
    if isinstance(cls, dict):
        # Deserialization of class type failed.
        return value
    if not hasattr(cls, "__module__"):
        # Malformed payload (e.g. a bare string or null as the class). The
        # failure reporting below needs cls.__module__, so warn directly
        # and fall back instead of raising AttributeError.
        warnings.warn(
            f"Could not deserialize object: class entry {cls!r} is not a class.",
            stacklevel=2,
        )
        return value

    settings = value.get("settings", {})
    if hasattr(cls, "__json_decode__"):
        try:
            return cls.__json_decode__(settings)
        except Exception as ex:  # pylint: disable=broad-except
            traceback_msg = "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))
            warning_msg = (
                f"Could not deserialize instance of class {cls} from value {settings} "
                "using __json_decode__ method."
            )
    else:
        traceback_msg = None
        warning_msg = (
            f"Could not deserialize instance of class {cls} from settings {settings}. "
            f"{cls}.__json_decode__ does not exist."
        )

    # Display warning msg if deserialization failed
    mod_name = cls.__module__
    load_version = get_object_version(cls)
    save_version = value.get("version")
    _show_warning(
        warning_msg,
        traceback_msg=traceback_msg,
        mod_name=mod_name,
        save_version=save_version,
        load_version=load_version,
    )

    # Return partially deserialized value
    return value
class ExperimentEncoder(json.JSONEncoder):
    """JSON Encoder for Qiskit Experiments.

    .. warning::

        It is recommended only to deserialize data with
        :class:`ExperimentDecoder` from trusted sources. For custom classes,
        the deserialization procedure involves dynamic execution of code based
        on the content of the serialized data. The deserialization code
        includes some safeguards:

        1. Only modules registered with the
           ``qiskit_experiments.deserialization_packages`` entry point are
           imported dynamically.
        2. Only classes (as determined by Python's ``inspect.isclass``
           function) are referenced from the imported modules for further
           processing.
        3. These classes are checked to ensure that they were
           defined by the registered modules they were loaded from.
        4. For the referenced classes, only the ``__json_decode__`` method
           is called with the serialized data.

        Even with these safeguards, loading a payload involves instantiating
        registered classes with arbitrary inputs. These classes were not
        written assuming malicious input.

        Note that versions of Qiskit Experiments older than 0.14 could load
        arbitrary functions like ``subprocess.run`` and pass them data from
        the deserialization payload.

    This class extends the default Python JSONEncoder by including built-in
    support for

    * complex numbers, inf and NaN floats, sets, and dataclasses.
    * NumPy ndarrays and SciPy sparse matrices.
    * Qiskit ``QuantumCircuit``.
    * Any class that implements a ``__json_encode__`` method

    Custom classes can be serialized by this encoder by implementing a
    ``__json_encode__`` method. The serialization procedure is as follows:

    The ``__json_encode__`` method should have signature

    .. code-block:: python

        def __json_encode__(self) -> Any:
            # return a JSON serializable object value

    The value returned by ``__json_encode__`` must be an object that can be
    serialized by the JSON encoder (for example a ``dict`` containing
    other JSON serializable objects).

    To deserialize this object using the :class:`ExperimentDecoder` the
    class must also provide a ``__json_decode__`` class method that can
    convert the value returned by ``__json_encode__`` back to the object.
    This method should have signature

    .. code-block:: python

        @classmethod
        def __json_decode__(cls, value: Any) -> Self:
            # recover the object from the `value` returned by __json_encode__

    Additionally, the custom class's package metadata must register the top
    level import package as a ``qiskit_experiments.deserialization_packages``
    Python entry point. In ``pyproject.toml`` the entry point registration
    would look like this:

    .. code-block:: toml

        [project.entry-points."qiskit_experiments.deserialization_packages"]
        custom-package-name = "custom_package"

    where ``custom_package`` is the Python import module that the custom class
    is below (the import path before the first ``.``; ``custom_package`` for
    class ``MyClass`` if it is normally imported as ``from
    custom_package.subpackage import MyClass`` for example). The entry point
    name, ``custom-package-name`` in the example, is not used and can be set to
    any descriptive name.

    If the object has no ``__json_encode__`` method and all other special cases
    (numpy arrays, Qiskit quantum info classes, etc.) do not apply, the
    object is serialized as though ``__json_encode__`` returned an empty dict.
    Without a ``__json_decode__`` method, the object will be loaded by
    :class:`ExperimentDecoder` as a dictionary containing the name of the
    object's class and module. This incomplete loading of the object may lead
    to other code execution problems.

    .. note::

        Serialization of custom classes works for user-defined classes in
        Python scripts, notebooks, or third party modules. Note however
        that these will only be able to be de-serialized if that class
        can be imported from the same scope at the time the
        :class:`ExperimentDecoder` is invoked. For scripts and notebook, the
        scope is named ``__main__`` which is registered by default with the
        ``qiskit_experiments.deserialization_packages`` entry point.
    """

    def default(self, obj: Any) -> Any:  # pylint: disable=arguments-renamed
        """Return a JSON serializable representation of ``obj``.

        The checks run in order and the first match wins; objects that match
        none of them fall back to the base encoder and finally to
        ``_serialize_object``.
        """
        if istype(obj):
            return _serialize_type(obj)
        if hasattr(obj, "__json_encode__"):
            return _serialize_object(obj)
        if isinstance(obj, complex):
            return _serialize_safe_float(obj)
        if isinstance(obj, set):
            return {"__type__": "set", "__value__": list(obj)}
        if isinstance(obj, np.ndarray):
            value = _serialize_and_encode(obj, np.save, allow_pickle=False)
            return {"__type__": "ndarray", "__value__": value}
        if isinstance(obj, sps.spmatrix):
            value = _serialize_and_encode(obj, sps.save_npz, compress=False)
            return {"__type__": "spmatrix", "__value__": value}
        if isinstance(obj, bytes):
            return _serialize_bytes(obj)
        if isinstance(obj, datetime):
            return {"__type__": "datetime", "__value__": obj.isoformat()}
        if isinstance(obj, np.number):
            return obj.item()
        if isinstance(obj, uncertainties.UFloat):
            # This could be UFloat (AffineScalarFunc) or Variable.
            # UFloat is a base class of Variable that contains parameter correlation.
            # i.e. Variable is special subclass for single number.
            # Since this object is not serializable, we will drop correlation information
            # during serialization. Then both can be serialized as Variable.
            # Note that UFloat doesn't have a tag.
            return {
                "__type__": "ufloat",
                "__value__": {
                    "value": _serialize_safe_float(obj.nominal_value),
                    "std_dev": _serialize_safe_float(obj.std_dev),
                },
            }
        if isinstance(obj, lmfit.Model):
            # LMFIT Model object. Delegate serialization to LMFIT.
            return {
                "__type__": "LMFIT.Model",
                "__value__": obj.dumps(),
            }
        if isinstance(obj, Instruction):
            # Serialize gate by storing it in a circuit.
            circuit = QuantumCircuit(obj.num_qubits, obj.num_clbits)
            circuit.append(obj, range(obj.num_qubits), range(obj.num_clbits))
            value = _serialize_and_encode(
                data=circuit, serializer=lambda buff, data: qpy.dump(data, buff)
            )
            return {"__type__": "Instruction", "__value__": value}
        if isinstance(obj, QuantumCircuit):
            value = _serialize_and_encode(
                data=obj, serializer=lambda buff, data: qpy.dump(data, buff)
            )
            return {"__type__": "QuantumCircuit", "__value__": value}
        if isinstance(obj, ParameterExpression):
            value = _serialize_and_encode(
                data=obj,
                serializer=qpy._write_parameter_expression,
                compress=False,
            )
            return {"__type__": "ParameterExpression", "__value__": value}
        if _check_quantum_info_class(obj.__class__):
            return {
                "__type__": "qiskit.quantum_info",
                "__value__": {
                    "class": obj.__class__.__name__,
                    "settings": _serialize_safe_float(obj.settings),
                },
            }
        if isinstance(obj, Backend):
            # Backends are not serialized; they are stored as null.
            return None
        try:
            return super().default(obj)
        except TypeError:
            return _serialize_object(obj)
class ExperimentDecoder(json.JSONDecoder):
    """JSON Decoder for Qiskit Experiments.

    .. warning::

        It is recommended to use this class only on trusted data. See the
        warning in the :class:`ExperimentEncoder` documentation for more
        details.

    This class extends the default Python JSONDecoder by including built-in
    support for all objects that can be serialized using the
    :class:`ExperimentEncoder`.

    See :class:`ExperimentEncoder` class documentation for further details.
    """

    # Maps the string tags used for non-finite floats back to float values.
    _NaNs = {"NaN": math.nan, "Infinity": math.inf, "-Infinity": -math.inf}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj):
        """Convert a decoded JSON object back into the value it encodes.

        Args:
            obj: A dict produced by the JSON parser. Nested dicts have
                already been passed through this hook.

        Returns:
            The deserialized value for dicts tagged with a known
            ``"__type__"``, otherwise ``obj`` unchanged.
        """
        # Plain dicts, and dicts that merely happen to contain a "__type__"
        # key without a "__value__", are not tagged payloads.
        if "__type__" not in obj or "__value__" not in obj:
            return obj

        obj_type = obj["__type__"]
        obj_val = obj["__value__"]

        if obj_type == "complex":
            return obj_val[0] + 1j * obj_val[1]
        if obj_type == "ndarray":
            return _decode_and_deserialize(obj_val, np.load, name=obj_type)
        if obj_type == "spmatrix":
            return _decode_and_deserialize(obj_val, sps.load_npz, name=obj_type)
        if obj_type == "b64encoded":
            return _deserialize_bytes(obj_val)
        if obj_type == "set":
            return set(obj_val)
        if obj_type == "datetime":
            return datetime.fromisoformat(obj_val)
        if obj_type == "ufloat":
            return uncertainties.ufloat(obj_val["value"], obj_val["std_dev"])
        if obj_type == "LMFIT.Model":
            tmp = lmfit.Model(func=None)
            load_obj = tmp.loads(s=obj_val)
            return load_obj
        if obj_type == "Instruction":
            circuit = _decode_and_deserialize(obj_val, qpy.load, name="QuantumCircuit")[0]
            return circuit.data[0].operation
        if obj_type == "QuantumCircuit":
            return _decode_and_deserialize(obj_val, qpy.load, name=obj_type)[0]
        if obj_type == "ParameterExpression":
            return _decode_and_deserialize(
                obj_val, qpy._read_parameter_expression, name=obj_type
            )
        if obj_type == "qiskit.quantum_info" and hasattr(quantum_info, obj_val["class"]):
            cls = getattr(quantum_info, obj_val["class"])
            return cls(**obj_val["settings"])
        if obj_type == "safe_float":
            return self._NaNs.get(obj_val, obj_val)

        # Remaining checks look at a "class" entry, which only exists when
        # the value is a dict; other values of unknown tags are left alone.
        legacy_cls = obj_val.get("class") if isinstance(obj_val, dict) else None
        if legacy_cls is not None:
            if _check_quantum_info_class(legacy_cls):
                return legacy_cls(**obj_val["settings"])
            if legacy_cls is uncertainties.ufloat:
                return _deserialize_ufloat(obj_val)
        if obj_type == "object":
            return _deserialize_object(obj_val)
        if obj_type == "type":
            return _deserialize_type(obj_val)
        return obj