# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2025.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Qiskit Machine Learning Sampler"""
from __future__ import annotations
from dataclasses import dataclass, is_dataclass, asdict
from typing import Iterable, Mapping, Any, overload, cast
from types import SimpleNamespace
from numpy.typing import NDArray
import numpy as np
from qiskit.circuit import ClassicalRegister, QuantumCircuit
from qiskit.quantum_info import Statevector
from qiskit.primitives import (
StatevectorSampler,
DataBin,
PrimitiveJob,
PrimitiveResult,
SamplerPubLike,
SamplerPubResult,
)
from qiskit.primitives.containers.sampler_pub import SamplerPub
class QMLSampler(StatevectorSampler):
    """V2-based sampler primitive with two modes.

    - ``shots=None`` (default): exact mode, no sampling. Returns deterministic
      probabilities computed from the statevector.
    - ``shots=int``: sampling mode, delegates to ``StatevectorSampler`` with the given
      ``default_shots``.
    """

    def __init__(self, *, shots: int | None = None, **kwargs):
        """Statevector-based sampler supporting exact (analytic) and sampling modes.

        Args:
            shots (int | None): Number of shots for sampling mode. If ``None``, run in
                exact mode.
            **kwargs: Additional arguments forwarded to ``StatevectorSampler``.
        """
        self._exact_mode = shots is None
        if self._exact_mode:
            super().__init__(**kwargs)
        else:
            super().__init__(default_shots=int(shots), **kwargs)
        # Copy any ``options`` attribute already present on the instance into a
        # mutable namespace and record the configured shots (``None`` in exact mode)
        # under ``default_shots`` unless that key already exists.
        parent_opts = object.__getattribute__(self, "__dict__").get("options", None)
        merged = dict(_options_to_dict(parent_opts))
        merged.setdefault("default_shots", shots)
        self.options = _OptionsNS(**merged)

    def run(
        self,
        pubs: Iterable[SamplerPubLike],
        *,
        shots: int | None = None,
    ) -> PrimitiveJob[PrimitiveResult[SamplerPubResult]]:
        """Run the sampler on PUBs.

        In sampling mode this delegates to ``StatevectorSampler.run``. In exact mode
        ``shots`` is ignored and exact probabilities are computed instead.

        Args:
            pubs (Iterable[SamplerPubLike]): Publications to evaluate.
            shots (int | None): Optional override for number of shots (sampling mode only).

        Returns:
            PrimitiveJob[PrimitiveResult[SamplerPubResult]]: Job executing the sampler.
        """
        if not self._exact_mode:
            return super().run(pubs, shots=shots)
        # Exact mode: coerce with a placeholder shots=1 purely so that SamplerPub
        # validation passes; no sampling takes place.
        coerced = [SamplerPub.coerce(pub, shots=1) for pub in pubs]
        job = PrimitiveJob(self._run_exact, coerced)
        job._submit()  # pylint: disable=protected-access
        return job

    # -------------------- exact evaluation --------------------
    def _run_exact(self, pubs: Iterable[SamplerPub]) -> PrimitiveResult[SamplerPubResult]:
        """Deterministically evaluate all PUBs.

        Args:
            pubs (Iterable[SamplerPub]): Fully coerced PUBs.

        Returns:
            PrimitiveResult[SamplerPubResult]: Exact results for each PUB.
        """
        return PrimitiveResult([self._run_pub_exact(pub) for pub in pubs])

    def _run_pub_exact(self, pub: SamplerPub) -> SamplerPubResult:
        """Compute per-register exact probability containers for a single PUB.

        Args:
            pub (SamplerPub): PUB containing circuit and parameters.

        Returns:
            SamplerPubResult: Exact probability results for this PUB. The metadata
            carries ``shots=None``, ``exact=True``, the register ``names`` in
            measurement-info order, and the circuit's own metadata.
        """
        unitary_circ, qargs, meas_info = _preprocess_circuit(pub.circuit)
        bound_circuits = pub.parameter_values.bind_all(unitary_circ)
        shape = bound_circuits.shape

        # One exact joint distribution over the measured qubits per broadcast index.
        joint_probs_per_index = np.empty(shape, dtype=object)
        for index, circ in np.ndenumerate(bound_circuits):
            if qargs:
                joint = Statevector.from_instruction(circ).probabilities_dict(qargs=qargs)
            else:
                # No measured qubits: a single empty outcome with certainty.
                joint = {"": 1.0}
            joint_probs_per_index[index] = joint

        # Build per-register ExactProbArray views (one per broadcast index).
        data_fields: dict[str, Any] = {}
        names: list[str] = []
        for item in meas_info:
            names.append(item.creg_name)
            arr = np.empty(shape, dtype=object)
            for index, joint in np.ndenumerate(joint_probs_per_index):
                arr[index] = ExactProbArray(
                    joint_probs=joint,
                    mask=list(item.qreg_indices),
                    num_bits=item.num_bits,
                    shape=(),
                )
            # Scalar PUBs expose the ExactProbArray directly; broadcast PUBs are
            # wrapped so users can call .get_counts() / .get_probabilities().
            field_value: Any = arr.item() if arr.shape == () else ExactProbNDArray(arr)
            data_fields[item.creg_name] = field_value

        data_bin = DataBin(**data_fields, shape=shape)
        return _ExactSamplerPubResult(
            data_bin,
            metadata={
                "shots": None,
                "exact": True,
                "names": names,
                # SamplerPub has no ``metadata`` attribute; the circuit carries it.
                "circuit_metadata": pub.circuit.metadata,
            },
        )
# -------------------- deterministic probability containers --------------------
class ExactProbArray:
    """Deterministic probability container for one classical register (``shape == ()``).

    The container keeps the full joint distribution over every measured bit plus a
    mask selecting the bits this register exposes. Keeping the joint distribution lets
    :meth:`concatenate_bits` (used by ``join_data()``) rebuild the exact joint
    distribution of several registers rather than a product of marginals.

    Methods:
        - ``get_probabilities() -> dict[str, float]``
        - ``get_counts(shots=None) -> dict[str, int]``; with ``shots=None`` the
          distribution must be dyadic.
    """

    __slots__ = ("_joint_probs", "_mask", "_num_bits", "_shape")

    # Largest fraction of the dyadic grid spacing 2**-k that the acceptance window
    # (width 2 * tol) may cover when testing whether a probability is dyadic. Beyond
    # this, almost every float lies within ``tol`` of a grid point and the test would
    # accept non-dyadic values. With tol=1e-12 this caps the level at k=28.
    _DYADIC_MAX_WINDOW = 1e-3

    def __init__(
        self,
        joint_probs: Mapping[str, float],  # over the full measured bitstring
        mask: list[int],  # LSB-based indices this register exposes
        num_bits: int,
        shape: tuple[int, ...] = (),
    ):
        """Exact probability container for a single classical register.

        Args:
            joint_probs (Mapping[str, float]): Full joint measured-bit distribution.
            mask (list[int]): LSB-ordered indices selecting bits exposed by this register.
            num_bits (int): Width of the classical register.
            shape (tuple[int, ...]): Broadcast shape (default ()).
        """
        self._joint_probs = dict(joint_probs)
        self._mask = list(mask)
        self._num_bits = int(num_bits)
        self._shape = tuple(shape)

    @property
    def shape(self) -> tuple[int, ...]:
        """Return the broadcast shape of this container."""
        return self._shape

    @property
    def num_bits(self) -> int:
        """Return the number of classical bits represented by this container."""
        return self._num_bits

    @property
    def num_shots(self):
        """Return None to indicate that this distribution is analytic, not sampled."""
        return None

    def _project_joint_to_mask(self, probs: Mapping[str, float]) -> dict[str, float]:
        """Marginalize the joint distribution onto this register's bit mask.

        Args:
            probs (Mapping[str, float]): Full joint distribution.

        Returns:
            dict[str, float]: Marginal distribution; keys are MSB-first bitstrings.
        """
        # Joint bit i is the character at position -1 - i (LSB is rightmost), and the
        # mask is LSB-first, so walk it in reverse to emit the key MSB-first.
        msb_first = list(reversed(self._mask))
        out: dict[str, float] = {}
        for bitstr, p in probs.items():
            key = "".join(bitstr[-1 - i] for i in msb_first)
            out[key] = out.get(key, 0.0) + p
        return out

    def get_probabilities(self) -> dict[str, float]:
        """Return exact bitstring probabilities.

        Returns:
            dict[str, float]: Map from bitstring to exact probability.
        """
        return self._project_joint_to_mask(self._joint_probs)

    @classmethod
    def _dyadic_level(cls, p: float, tol: float = 1e-12, kmax: int = 60) -> int | None:
        """Return the smallest ``k`` such that ``p`` is within ``tol`` of a multiple of
        ``2**-k``, or ``None`` if there is none among the levels where the test is
        meaningful (``k <= kmax`` and ``2 * tol * 2**k <= _DYADIC_MAX_WINDOW``)."""
        if p in (0.0, 1.0):
            return 0
        for k in range(kmax + 1):
            scale = 1 << k
            if 2.0 * tol * scale > cls._DYADIC_MAX_WINDOW:
                break
            if abs(p - round(p * scale) / scale) <= tol:
                return k
        return None

    @staticmethod
    def _apportion(probs: Mapping[str, float], total: int) -> dict[str, int]:
        """Split ``total`` integer counts over ``probs`` by the largest-remainder method.

        Every outcome gets ``floor(p * total)``. The leftover counts go to the outcomes
        with the largest fractional parts, with ties favouring the more probable
        outcome. For a normalized distribution the counts sum to ``total``.
        """
        scaled = {key: max(p, 0.0) * total for key, p in probs.items()}
        counts = {key: int(value) for key, value in scaled.items()}
        shortfall = total - sum(counts.values())
        if shortfall > 0:
            order = sorted(
                scaled,
                key=lambda key: (scaled[key] - counts[key], probs[key]),
                reverse=True,
            )
            for key in order[:shortfall]:
                counts[key] += 1
        return counts

    def get_counts(self, shots: int | None = None) -> dict[str, int]:
        """Return integer counts consistent with the exact probabilities.

        Args:
            shots (int | None): Total number of counts to distribute. If ``None``, use
                the smallest ``2**k`` that represents every probability exactly. This
                requires the distribution to be dyadic.

        Returns:
            dict[str, int]: Counts per bitstring. For a normalized distribution they sum
            to the requested total. Outcomes whose count is zero are omitted.

        Raises:
            ValueError: If ``shots`` is ``None`` and the distribution is not dyadic, or
                if ``shots`` is negative.
        """
        probs = self.get_probabilities()
        if shots is None:
            levels = []
            for p in probs.values():
                k = self._dyadic_level(p)
                if k is None:
                    raise ValueError(
                        "ExactProbArray.get_counts: distribution is not dyadic; "
                        "use get_probabilities() for exact values."
                    )
                levels.append(k)
            total = 1 << max(levels, default=0)
        else:
            total = int(shots)
            if total < 0:
                raise ValueError(f"shots must be non-negative, got {shots}.")
        counts = self._apportion(probs, total)
        return {key: count for key, count in counts.items() if count > 0}

    @staticmethod
    def concatenate_bits(items: list["ExactProbArray"]) -> "ExactProbArray":
        """Concatenate multiple ``ExactProbArray`` instances into one wider register.

        The first item's bits become the least-significant bits of the result.

        Args:
            items (list[ExactProbArray]): Containers to concatenate.

        Returns:
            ExactProbArray: Wider register combining all bits.

        Raises:
            ValueError: If ``items`` is empty or the joint distributions differ.
        """
        if not items:
            raise ValueError("No containers to concatenate.")
        joint = items[0]._joint_probs
        for it in items[1:]:
            # Identity check first avoids a full dict comparison for shared joints.
            if it._joint_probs is not joint and it._joint_probs != joint:
                raise ValueError("Cannot join different joint distributions.")
        mask: list[int] = []
        for it in items:
            mask.extend(it._mask)
        num_bits = sum(it._num_bits for it in items)
        return ExactProbArray(joint, mask=mask, num_bits=num_bits, shape=items[0]._shape)
class ExactProbNDArray:
    """N-dimensional wrapper around a NumPy object array of ``ExactProbArray``.

    Exposes SamplerV2-like methods on the whole array:

    - ``get_counts(loc=None, shots=None)``
    - ``get_probabilities(loc=None)``

    Indexing follows NumPy semantics. ``obj[idx]`` returns a single ``ExactProbArray``
    for a full index and another ``ExactProbNDArray`` for a partial index or slice.
    """

    __slots__ = ("_arr",)

    def __init__(self, arr: NDArray[np.object_]):
        """N-dimensional wrapper for arrays of ``ExactProbArray``.

        Args:
            arr (np.ndarray): Object array of ``ExactProbArray`` elements.
        """
        self._arr: NDArray[np.object_] = arr

    # --- array-like protocol ---
    @property
    def shape(self) -> tuple[int, ...]:
        """Return the shape of the underlying array."""
        return self._arr.shape

    @overload
    def __getitem__(self, idx: int | tuple[int, ...]) -> ExactProbArray: ...
    @overload
    def __getitem__(self, idx: slice | tuple[Any, ...]) -> "ExactProbNDArray": ...
    def __getitem__(self, idx):
        """Index the underlying array with NumPy semantics.

        Args:
            idx: Any NumPy index (int, tuple, slice, ...).

        Returns:
            ExactProbArray | ExactProbNDArray: The selected element for a full index,
            otherwise the selected sub-array wrapped again.
        """
        out = self._arr[idx]
        if isinstance(out, np.ndarray):
            return ExactProbNDArray(out)
        return out  # single ExactProbArray

    # Optional, used by some user code
    @property
    def num_shots(self):
        """Return None to indicate that all elements represent analytic distributions."""
        return None

    @property
    def num_bits(self) -> int:
        """Return the bit width of the first ``ExactProbArray`` element, or 0 if none."""
        for raw in self._arr.flat:
            if isinstance(raw, ExactProbArray):
                return raw.num_bits
        return 0

    # --- Sampler-style methods ---
    @overload
    def get_probabilities(self, loc: int | tuple[int, ...]) -> dict[str, float]: ...
    @overload
    def get_probabilities(self, loc: None = None) -> NDArray[np.object_]: ...
    def get_probabilities(self, loc: int | tuple[int, ...] | None = None):
        """Return exact probabilities for one position or for a (sub-)array.

        Args:
            loc (int | tuple[int, ...] | None): Optional index. A full index returns
                one dict. ``None`` or a partial index returns an object array of dicts
                shaped like the selection.

        Returns:
            dict[str, float] | np.ndarray: Probabilities for the selection.
        """
        if loc is None:
            selected = self._arr
        else:
            selected = self._arr[loc]
            if not isinstance(selected, np.ndarray):
                return cast(ExactProbArray, selected).get_probabilities()
        out: NDArray[np.object_] = np.empty(selected.shape, dtype=object)
        for idx in np.ndindex(selected.shape):
            out[idx] = cast(ExactProbArray, selected[idx]).get_probabilities()
        return out

    @overload
    def get_counts(
        self, loc: int | tuple[int, ...], shots: int | None = None
    ) -> dict[str, int]: ...
    @overload
    def get_counts(self, loc: None = None, shots: int | None = None) -> dict[str, int]: ...
    def get_counts(self, loc: int | tuple[int, ...] | None = None, shots: int | None = None):
        """Return counts for one position or the union across positions.

        With ``loc=None`` this follows ``BitArray`` semantics and unions the counts of
        all positions. A partial index unions the positions it selects. For
        per-position counts, index first (e.g. ``obj[i].get_counts()``).

        Args:
            loc (int | tuple[int, ...] | None): Optional index.
            shots (int | None): Number of shots per position, forwarded to
                ``ExactProbArray.get_counts``.

        Returns:
            dict[str, int]: Counts for the selected element, or their union.
        """
        if loc is None:
            selected = self._arr
        else:
            selected = self._arr[loc]
            if not isinstance(selected, np.ndarray):
                return cast(ExactProbArray, selected).get_counts(shots=shots)
        total: dict[str, int] = {}
        # Elements may raise for non-dyadic distributions when shots is None;
        # callers can use get_probabilities() instead.
        for raw in selected.flat:
            for key, value in cast(ExactProbArray, raw).get_counts(shots=shots).items():
                total[key] = total.get(key, 0) + value
        return total
# --- helpers -------------------------------------------------
def _options_to_dict(opts) -> dict:
"""Convert an options object to a plain dict.
Args:
opts: Any options-like object.
Returns:
dict: Extracted key–value pairs.
"""
if opts is None:
return {}
if is_dataclass(opts):
return asdict(opts) # type: ignore
if hasattr(opts, "__dict__"):
return {k: v for k, v in vars(opts).items() if not k.startswith("_")}
# Fallback: probe attributes
d = {}
for k in dir(opts):
if k.startswith("_"):
continue
v = getattr(opts, k)
if callable(v):
continue
d[k] = v
return d
class _OptionsNS(SimpleNamespace):
"""Mutable, dict-like options name space with an update(**kwargs) helper."""
def update(self, **kwargs):
"""Update options in place.
Args:
**kwargs: Key–value pairs to update.
"""
for k, v in kwargs.items():
setattr(self, k, v)
# ---------------- measurement mapping from StatevectorSampler --------------------
@dataclass
class _MeasureInfo:
"""Return a map from each final classical bit to the qubit index it measures, assuming
only final measurements."""
creg_name: str
num_bits: int # measured bit-width of this register
qreg_indices: list[int] # LSB-order indices into the joint measured-qubit list
def _final_measurement_mapping(circuit: QuantumCircuit) -> dict[tuple[ClassicalRegister, int], int]:
    """Map each finally-measured classical bit to the qubit index it measures.

    The circuit is scanned backwards. A measurement counts as "final" only if its
    clbit has not been claimed by a later measurement and its qubit has not been
    touched by a later non-measurement instruction. Barriers and delays are
    transparent.

    Args:
        circuit (QuantumCircuit): Circuit with final measurements.

    Returns:
        dict[(ClassicalRegister, int), int]: Mapping from (register, offset) to qubit
        index, with one entry per register the clbit belongs to.
    """
    untouched_qubits = set(range(circuit.num_qubits))
    unclaimed_clbits = set(range(circuit.num_clbits))
    final_map: dict[tuple[ClassicalRegister, int], int] = {}
    transparent_ops = ("barrier", "delay")
    for instruction in circuit[::-1]:
        op_name = instruction.operation.name
        if op_name == "measure":
            clbit_loc = circuit.find_bit(instruction.clbits[0])
            clbit_index = clbit_loc.index
            qubit_index = circuit.find_bit(instruction.qubits[0]).index
            if clbit_index in unclaimed_clbits and qubit_index in untouched_qubits:
                # clbit_loc.registers lists (ClassicalRegister, offset) pairs.
                final_map.update(dict.fromkeys(clbit_loc.registers, qubit_index))
                unclaimed_clbits.remove(clbit_index)
        elif op_name not in transparent_ops:
            untouched_qubits.difference_update(
                circuit.find_bit(qubit).index for qubit in instruction.qubits
            )
        # Nothing more can be final once either pool is exhausted.
        if not (unclaimed_clbits and untouched_qubits):
            break
    return final_map
def _preprocess_circuit(circuit: QuantumCircuit):
    """Strip final measurements and describe how registers map onto measured qubits.

    Args:
        circuit (QuantumCircuit): Circuit with final measurements.

    Returns:
        tuple:
            QuantumCircuit: Circuit with final measurements removed.
            list[int]: Sorted measured qubit indices (the joint measured-qubit list).
            list[_MeasureInfo]: One entry per classical register with at least one
            final measurement, with masks ordered by bit offset (LSB first).

    NOTE(review): clbits of a register that are not finally measured are skipped rather
    than reported as a constant 0. A register's bitstring can therefore be narrower
    than the register, and registers without any final measurement are omitted. Confirm
    this is the intended SamplerV2-compatible behaviour.
    """
    mapping = _final_measurement_mapping(circuit)
    qargs = sorted(set(mapping.values()))
    joint_position = {qubit: pos for pos, qubit in enumerate(qargs)}
    unitary_circ = circuit.remove_final_measurements(inplace=False)

    # Group (offset within register, joint position) pairs by register name,
    # keeping the circuit's register order.
    grouped: dict[str, list[tuple[int, int]]] = {creg.name: [] for creg in circuit.cregs}
    for (creg, offset), qubit in mapping.items():
        grouped[creg.name].append((offset, joint_position[qubit]))

    meas_info: list[_MeasureInfo] = []
    for name, pairs in grouped.items():
        if pairs:
            mask = [joint for _, joint in sorted(pairs, key=lambda pair: pair[0])]
            meas_info.append(_MeasureInfo(creg_name=name, num_bits=len(mask), qreg_indices=mask))
    return unitary_circ, qargs, meas_info
# ---------------------- PubResult subclass with safe join_data ----------------------
class _ExactSamplerPubResult(SamplerPubResult):
    """SamplerPubResult variant whose ``join_data()`` understands ``ExactProbArray`` and
    ``ExactProbNDArray`` containers."""

    def join_data(self, names: Iterable[str] | None = None):
        """Join named per-register probability containers into one wider register.

        The first name's bits become the least-significant bits of the result.

        Args:
            names (Iterable[str] | None): Register names to join. Defaults to the
                ``names`` entry of this result's metadata.

        Returns:
            ExactProbArray | ExactProbNDArray: A single container for a scalar result,
            otherwise an ND wrapper with one concatenated container per position.

        Raises:
            ValueError: If ``names`` is empty or a name is not present in the data.
        """
        if names is None:
            names = self.metadata.get("names", [])
        names = list(names)
        if not names:
            raise ValueError("names is empty")
        for n in names:
            if not hasattr(self.data, n):
                raise ValueError(f"name does not exist: {n}")
        fields = [getattr(self.data, n) for n in names]

        shape = self.data.shape
        if shape == ():
            # Scalar: every field is a single ExactProbArray.
            return ExactProbArray.concatenate_bits(fields)

        # ND case: ND fields are indexed per position; a scalar ExactProbArray field
        # (which has no __getitem__) is broadcast to every position.
        out = np.empty(shape, dtype=object)
        for idx in np.ndindex(shape):
            items = [
                field[idx] if isinstance(field, ExactProbNDArray) else field
                for field in fields
            ]
            out[idx] = ExactProbArray.concatenate_bits(items)
        return ExactProbNDArray(out)