r"""Quantinuum system results.
Includes conversions to traditional distributions over bitstrings if a tagging
convention is used, including conversion to a pytket BackendResult.
Under this convention, tags are assumed to be a name of a bit register unless they fit
the regex pattern `^([a-z][\w_]*)\[(\d+)\]$` (like `my_Reg[12]`) in which case they
are assumed to refer to the nth element of a bit register.
For results of the form `result("<register>", value)`, `value` can be `0` or `1`,
in which case the register is assumed to have length 1, or a list of those values,
in which case the list is taken to be the value of the entire register.
For results of the form `result("<register>[n]", value)`, `value` can only be
`0` or `1`.
The register is assumed to be at least `n+1` in size and unset
elements are assumed to be `0`.
Later writes to the same register or element in the same shot overwrite earlier ones.
To convert to a `BackendResult` all registers must be present in all shots, and register
sizes cannot change between shots.
"""
from __future__ import annotations
import re
from collections import Counter, defaultdict
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal
from typing_extensions import deprecated
if TYPE_CHECKING:
from collections.abc import Iterable
from pytket.backends.backendresult import BackendResult
#: Primitive data types that can be returned by a result
DataPrimitive = int | float | bool
#: Data value that can be returned by a result: a primitive or a list of primitives
DataValue = DataPrimitive | list[DataPrimitive]
TaggedResult = tuple[str, DataValue]
# Pattern to match register index in tag, e.g. "reg[0]"
REG_INDEX_PATTERN = re.compile(r"^([a-z][\w_]*)\[(\d+)\]$")
BitChar = Literal["0", "1"]
@dataclass
class QsysShot(Sequence):
    """Results from a single shot execution.

    A shot is an ordered list of ``(tag, data)`` entries, in the order they were
    appended.
    """

    #: List of tagged results, where each result is a tuple of tag and data value.
    entries: list[TaggedResult] = field(default_factory=list)

    def __init__(self, entries: Iterable[TaggedResult] | None = None):
        """Create a shot from an optional iterable of ``(tag, data)`` entries."""
        # ``@dataclass`` keeps this explicitly defined ``__init__``; it accepts any
        # iterable (e.g. a generator) and copies it into a fresh list.
        self.entries = list(entries) if entries is not None else []

    def append(self, tag: str, data: DataValue) -> None:
        """Record a new ``(tag, data)`` entry at the end of the shot."""
        self.entries.append((tag, data))

    def as_dict(self) -> dict[str, DataValue]:
        """Convert results to a dictionary.

        For duplicate tags, the last value is used.

        Returns:
            dict: A dictionary where the keys are the tags and the
            values are the data.

        Example:
            >>> results = QsysShot()
            >>> results.append("tag1", 1)
            >>> results.append("tag2", 2)
            >>> results.append("tag2", 3)
            >>> results.as_dict()
            {'tag1': 1, 'tag2': 3}
        """
        return dict(self.entries)

    def collate_tags(self) -> dict[str, list[DataValue]]:
        """Group all data recorded under each tag, preserving write order.

        Returns:
            dict: A dictionary from each tag to the list of every value appended
            with that tag, in the order they were appended.

        Example:
            >>> QsysShot([("a", 1), ("b", 0), ("a", 0)]).collate_tags()
            {'a': [1, 0], 'b': [0]}
        """
        collated: dict[str, list[DataValue]] = defaultdict(list)
        for tag, data in self.entries:
            collated[tag].append(data)
        return dict(collated)

    def to_register_bits(self) -> dict[str, str]:
        """Convert results to a dictionary of register bit values.

        Tags matching ``REG_INDEX_PATTERN`` (e.g. ``"c[2]"``) set a single bit of
        register ``c``, growing it with ``"0"`` bits as needed. Any other tag sets
        the whole register to its value: a list of bits, or a single bit for a
        length-1 register. Entries are applied in the order they were appended,
        so later writes overwrite earlier ones.

        Returns:
            dict: Register name -> bitstring, where character ``i`` is bit ``i``.
            Registers appear in order of their first write.

        Raises:
            ValueError: If an entry's data is not a bit (``0`` or ``1``) or, for a
                whole-register tag, a list of bits.

        Example:
            >>> shot = QsysShot([("c[0]", 0), ("c", [0, 0]), ("c[0]", 1)])
            >>> shot.to_register_bits()
            {'c': '10'}
        """
        reg_bits: dict[str, list[BitChar]] = {}
        # Replay the raw entries rather than ``as_dict()``. De-duplicating tags
        # first would lose the relative order of a whole-register write ("c") and
        # an element write ("c[0]") to the same register.
        for tag, data in self.entries:
            match = REG_INDEX_PATTERN.match(tag)
            if match is None:
                # Whole-register write: a list is the full register value, a
                # scalar is a length-1 register.
                values = data if isinstance(data, list) else [data]
                reg_bits[tag] = [_cast_primitive_bit(v) for v in values]
                continue
            reg_name, reg_index_str = match.groups()
            reg_index = int(reg_index_str)
            bitlst = reg_bits.setdefault(reg_name, [])
            if reg_index >= len(bitlst):
                # Unset elements of the register default to "0".
                bitlst.extend(["0"] * (reg_index + 1 - len(bitlst)))
            bitlst[reg_index] = _cast_primitive_bit(data)
        return {reg: "".join(bits) for reg, bits in reg_bits.items()}

    def __len__(self) -> int:
        return len(self.entries)

    def __getitem__(self, index: int | slice) -> TaggedResult | list[TaggedResult]:
        return self.entries[index]

    def __iter__(self) -> Iterator[TaggedResult]:
        return iter(self.entries)
@deprecated("Use QsysShot instead.")
class HResult(QsysShot):
    """Legacy name for :class:`QsysShot`, kept for backward compatibility.

    Prefer ``QsysShot`` in new code.
    """
def _cast_primitive_bit(data: DataValue) -> BitChar:
if isinstance(data, int) and data in {0, 1}:
return str(data) # type: ignore[return-value]
msg = f"Expected bit data for register value found {data}"
raise ValueError(msg)
@dataclass
class QsysResult(Sequence):
    """Results accumulated over multiple shots."""

    #: List of QsysShot objects, each representing a single shot's results.
    results: list[QsysShot]

    def __init__(
        self, results: Iterable[QsysShot | Iterable[TaggedResult]] | None = None
    ):
        """Create a result from QsysShot objects or raw ``(tag, data)`` iterables."""
        # Raw iterables of tagged results are wrapped in a QsysShot.
        self.results = [
            res if isinstance(res, QsysShot) else QsysShot(res) for res in results or []
        ]

    def register_counts(
        self, strict_names: bool = False, strict_lengths: bool = False
    ) -> dict[str, Counter[str]]:
        """Convert results to a dictionary of register counts.

        Args:
            strict_names: Whether to enforce that all shots have the same
                registers.
            strict_lengths: Whether to enforce that all register bitstrings have
                the same length.

        Returns:
            dict: A dictionary where the keys are the register names
            and the values are the counts of the register bitstrings.

        Raises:
            ValueError: If a strictness check requested by the flags fails.
        """
        return {
            reg: Counter(bitstrs)
            for reg, bitstrs in self.register_bitstrings(
                strict_lengths=strict_lengths, strict_names=strict_names
            ).items()
        }

    def register_bitstrings(
        self, strict_names: bool = False, strict_lengths: bool = False
    ) -> dict[str, list[str]]:
        """Convert results to a dictionary from register name to list of bitstrings
        over the shots.

        Args:
            strict_names: Whether to enforce that all shots have the same
                registers.
            strict_lengths: Whether to enforce that all register bitstrings have
                the same length.

        Returns:
            dict: Register name -> bitstrings, one per shot in which that register
            was written, in shot order.

        Raises:
            ValueError: If ``strict_lengths`` is set and a register's bitstring
                length differs between shots, or if ``strict_names`` is set and a
                shot's registers differ from those of the first shot.
        """
        shot_dct: dict[str, list[str]] = defaultdict(list)
        # Registers of the first shot, which every later shot must match exactly.
        # Comparing against the accumulated keys instead would miss a register
        # that first appears in a later shot, because it is added before the check.
        first_regs: set[str] | None = None
        for shot in self.results:
            bitstrs = shot.to_register_bits()
            for reg, bitstr in bitstrs.items():
                if (
                    strict_lengths
                    and reg in shot_dct
                    and len(shot_dct[reg][0]) != len(bitstr)
                ):
                    msg = "All register bitstrings must have the same length."
                    raise ValueError(msg)
                shot_dct[reg].append(bitstr)
            if strict_names:
                if first_regs is None:
                    first_regs = set(bitstrs)
                elif bitstrs.keys() != first_regs:
                    msg = "All shots must have the same registers."
                    raise ValueError(msg)
        return dict(shot_dct)

    def __len__(self) -> int:
        return len(self.results)

    def __getitem__(self, index: int | slice) -> QsysShot | list[QsysShot]:
        return self.results[index]

    def __iter__(self) -> Iterator[QsysShot]:
        return iter(self.results)

    def to_pytket(self) -> BackendResult:
        """Convert results to a pytket BackendResult.

        Registers are concatenated in order of first appearance. Within a
        register, bit ``i`` becomes ``Bit(reg, i)``.

        Returns:
            BackendResult: A BackendResult object with the shots.

        Raises:
            ImportError: If pytket is not installed.
            ValueError: If a register's bitstrings have different lengths or not all
                registers are present in all shots.
        """
        try:
            from pytket.backends.backendresult import BackendResult
            from pytket.circuit import Bit  # public path, not private ``_tket``
            from pytket.utils.outcomearray import OutcomeArray
        except ImportError as e:
            msg = "Pytket is an optional dependency, install with the `pytket` extra"
            raise ImportError(msg) from e

        reg_shots = self.register_bitstrings(strict_lengths=True, strict_names=True)
        registers = list(reg_shots)
        # Lengths are uniform across shots (strict_lengths), so any bitstring
        # gives the register width.
        reg_sizes: dict[str, int] = {
            reg: len(next(iter(reg_shots[reg]), "")) for reg in registers
        }
        bits = [Bit(reg, i) for reg in registers for i in range(reg_sizes[reg])]
        # Every register is present in every shot (strict_names), so index i
        # selects shot i for every register.
        int_shots = [
            int("".join(reg_shots[reg][i] for reg in registers), 2)
            for i in range(len(self.results))
        ]
        return BackendResult(
            shots=OutcomeArray.from_ints(int_shots, width=len(bits)), c_bits=bits
        )

    def _collated_shots_iter(self) -> Iterator[dict[str, list[DataValue]]]:
        # For each shot, group its entries by tag, keeping every value in
        # the order it was written.
        for shot in self.results:
            collated: dict[str, list[DataValue]] = defaultdict(list)
            for tag, data in shot:
                collated[tag].append(data)
            yield dict(collated)

    def collated_shots(self) -> list[dict[str, list[DataValue]]]:
        """For each shot generate a dictionary of tags to collated data."""
        return list(self._collated_shots_iter())

    def collated_counts(self) -> Counter[tuple[tuple[str, str], ...]]:
        """Calculate counts of bit strings for each tag by collating across shots
        using `QsysResult.collated_shots`.

        Each `result` entry per shot is treated as appending to the bitstring for
        that tag. If the result value is a list, it is flattened and appended to
        the bitstring.

        Example:
            >>> shots = [QsysShot([("a", 1), ("a", 0)]), QsysShot([("a", [0, 1])])]
            >>> res = QsysResult(shots)
            >>> res.collated_counts()
            Counter({(('a', '10'),): 1, (('a', '01'),): 1})

        Raises:
            ValueError: If any value is not a bit, e.g. a float.
        """
        return Counter(
            tuple((tag, _flat_bitstring(data)) for tag, data in d.items())
            for d in self._collated_shots_iter()
        )
@deprecated("Use QsysResult instead.")
class HShots(QsysResult):
    """Legacy name for :class:`QsysResult`, kept for backward compatibility.

    Prefer ``QsysResult`` in new code.
    """
def _flat_bitstring(data: Iterable[DataValue]) -> str:
    """Render ``data`` as one bitstring, expanding nested lists in order.

    Raises:
        ValueError: If any flattened value is rejected by ``_cast_primitive_bit``.
    """
    bit_chars = [_cast_primitive_bit(value) for value in _flatten(data)]
    return "".join(bit_chars)
def _flatten(itr: Iterable[DataValue]) -> Iterable[DataPrimitive]:
for i in itr:
if isinstance(i, list):
yield from _flatten(i)
else:
yield i