# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2025.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Qiskit Machine Learning estimator primitive.
This module provides a small wrapper around Qiskit's ``StatevectorEstimator``
that switches between two modes:
* Exact mode (``default_precision == 0``): analytic expectation values with
deterministic outputs and zero standard deviation.
* Delegate mode (``default_precision != 0``): defer execution to
``StatevectorEstimator`` (precision-aware reference implementation).
"""
from __future__ import annotations
from collections.abc import Iterable, Mapping
from dataclasses import asdict, is_dataclass
from types import SimpleNamespace
from typing import Any
import numpy as np
from qiskit.primitives import StatevectorEstimator
from qiskit.primitives.containers import (
DataBin,
EstimatorPubLike,
PrimitiveResult,
PubResult,
)
from qiskit.primitives.containers.estimator_pub import EstimatorPub
from qiskit.primitives.primitive_job import PrimitiveJob
from qiskit.quantum_info import Operator, SparsePauliOp, Statevector
[docs]
class QMLEstimator(StatevectorEstimator):
    """V2-based estimator primitive with two modes.

    Modes are selected once, at construction time:

    * ``default_precision == 0.0`` (default): exact mode. Results are
      deterministic (analytic expectation values) with ``stds == 0``.
      Any per-call ``precision`` override is accepted for API compatibility
      but ignored.
    * any non-zero ``default_precision``: delegate mode. Execution is
      delegated to :class:`~qiskit.primitives.StatevectorEstimator`, which
      interprets the precision parameter according to the reference
      primitive behavior.
    """

    def __init__(
        self,
        *,
        default_precision: float = 0.0,
        seed: np.random.Generator | int | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the estimator and fix its execution mode.

        Args:
            default_precision: ``0.0`` selects exact mode; any other value
                selects delegate mode and is forwarded to the parent class.
            seed: Seed (or generator) forwarded to the parent class.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                parent constructor.
        """
        precision = float(default_precision)
        # The mode is decided here and never re-evaluated afterwards.
        self._exact_mode = precision == 0.0
        super().__init__(
            default_precision=0.0 if self._exact_mode else precision,
            seed=seed,
            **kwargs,
        )

        # Provide a mutable, V1-style `options` name space for ML integrations.
        # Anything the parent stored under ``options`` in the instance dict is
        # kept; ``default_precision`` and ``seed`` are only filled in if absent.
        parent_opts = object.__getattribute__(self, "__dict__").get("options", None)
        merged = dict(_options_to_dict(parent_opts))
        merged.setdefault("default_precision", precision)
        merged.setdefault("seed", seed)
        self.options = _OptionsNS(**merged)

    def run(
        self,
        pubs: Iterable[EstimatorPubLike],
        *,
        precision: float | None = None,
    ) -> PrimitiveJob[PrimitiveResult[PubResult]]:
        """Evaluate a collection of estimator PUBs.

        Args:
            pubs: Iterable of PUB-like inputs describing circuits, observables,
                and parameter values.
            precision: Target precision for V2-style estimation. In exact mode,
                this value is ignored and results are deterministic.

        Returns:
            A job that yields a ``PrimitiveResult[PubResult]``.
        """
        if not self._exact_mode:
            return super().run(pubs, precision=precision)

        # Coerce with a fixed precision of 0.0 to satisfy PUB validation; the
        # caller-supplied ``precision`` is intentionally not used in exact mode.
        coerced = [EstimatorPub.coerce(pub, 0.0) for pub in pubs]
        job: PrimitiveJob[PrimitiveResult[PubResult]] = PrimitiveJob(self._run_exact, coerced)
        # Submit before handing the job back to the caller.
        job._submit()  # pylint: disable=protected-access
        return job

    # -------------------- exact-mode implementation --------------------
    def _run_exact(self, pubs: list[EstimatorPub]) -> PrimitiveResult[PubResult]:
        """Evaluate every coerced PUB exactly and bundle the per-PUB results."""
        return PrimitiveResult(
            [self._run_pub_exact(pub) for pub in pubs],
            metadata={"version": 2},
        )

    def _run_pub_exact(self, pub: EstimatorPub) -> PubResult:
        """Compute analytic expectation values for a single PUB.

        Args:
            pub: A coerced estimator PUB.

        Returns:
            A ``PubResult`` whose data holds ``evs`` (expectation values) and
            ``stds`` (all zeros), both shaped like the broadcast of the bound
            circuits against the observables.
        """
        circuit = pub.circuit
        observables = pub.observables
        parameter_values = pub.parameter_values

        bound_circuits = parameter_values.bind_all(circuit)
        # Pair every bound circuit with every observable it is evaluated against.
        bc_circuits, bc_obs = np.broadcast_arrays(bound_circuits, observables)

        evs = np.empty(bc_circuits.shape, dtype=np.float64)
        stds = np.zeros(bc_circuits.shape, dtype=np.float64)
        for idx in np.ndindex(bc_circuits.shape):
            sv = Statevector.from_instruction(bc_circuits[idx])
            obs = _coerce_observable(bc_obs[idx])
            mean = sv.expectation_value(obs)
            evs[idx] = float(np.real_if_close(mean))

        # The previous ``getattr(pub, "metadata", {})`` looked for an attribute
        # on the PUB itself; the reference ``EstimatorPub`` keeps metadata on
        # its circuit, so the reported value was always ``{}``. Honour a
        # pub-level attribute if one exists, otherwise use the circuit's.
        circuit_metadata = getattr(pub, "metadata", None)
        if circuit_metadata is None:
            circuit_metadata = getattr(circuit, "metadata", None)
        if circuit_metadata is None:
            circuit_metadata = {}

        data = DataBin(evs=evs, stds=stds, shape=evs.shape)
        meta = {
            "shots": None,
            "target_precision": 0.0,
            "circuit_metadata": circuit_metadata,
            "exact": True,
        }
        return PubResult(data=data, metadata=meta)
# pylint: disable=too-many-return-statements
def _coerce_observable(obs: Any) -> Any:
    """Turn a supported observable encoding into an operator object.

    The result is meant to be passed to
    :meth:`qiskit.quantum_info.Statevector.expectation_value`.

    Handled inputs, checked in this order:

    * ``SparsePauliOp`` / ``Operator`` instances are returned unchanged.
    * Objects whose ``primitive`` attribute is a ``SparsePauliOp`` yield that
      attribute.
    * Non-empty mappings of ``label -> coefficient``.
    * A single label string (coefficient ``1.0``).
    * A non-empty list/tuple whose first item is a 2-element list/tuple,
      read as ``(label, coefficient)`` pairs.
    * Anything else is tried as ``SparsePauliOp(obs)`` and, if that fails for
      any reason, as ``Operator(obs)``.

    Raises:
        ValueError: If ``obs`` is an empty mapping.
    """
    if isinstance(obs, (SparsePauliOp, Operator)):
        return obs

    wrapped = getattr(obs, "primitive", None)
    if isinstance(wrapped, SparsePauliOp):
        return wrapped

    if isinstance(obs, Mapping):
        if obs:
            terms = [(str(label), complex(coeff)) for label, coeff in obs.items()]
            return SparsePauliOp.from_list(terms)
        raise ValueError("Observable mapping is empty.")

    if isinstance(obs, str):
        return SparsePauliOp.from_list([(obs, 1.0)])

    if isinstance(obs, (list, tuple)) and obs:
        # Only the first entry is inspected to recognise the pair-list layout.
        head = obs[0]
        if isinstance(head, (list, tuple)) and len(head) == 2:
            pairs = [(str(label), complex(coeff)) for (label, coeff) in obs]
            return SparsePauliOp.from_list(pairs)

    # Last resort: let the library constructors interpret the input.
    try:
        return SparsePauliOp(obs)
    except Exception:  # pylint: disable=broad-exception-caught
        return Operator(obs)
def _options_to_dict(opts: Any) -> dict[str, Any]:
"""Best-effort conversion of an options-like object into a plain dict."""
if opts is None:
return {}
if is_dataclass(opts):
return dict(asdict(opts)) # type: ignore
if isinstance(opts, Mapping):
return dict(opts)
to_dict = getattr(opts, "to_dict", None)
if callable(to_dict):
try:
return dict(to_dict())
except Exception: # pylint: disable=broad-exception-caught
pass
try:
return dict(vars(opts))
except TypeError:
return {}
class _OptionsNS(SimpleNamespace):
"""Mutable options name space supporting ``update(**kwargs)``."""
def update(self, **kwargs: Any) -> None:
"""Updates options"""
for k, v in kwargs.items():
setattr(self, k, v)