Source code for pytket.circuit.decompose_classical

# Copyright Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for decomposing Circuits containing classical expressions
in to primitive logical operations."""

from heapq import heappop, heappush
from typing import Any, Generic, TypeVar

from pytket.circuit import (
    Circuit,
    ClBitVar,
    ClExpr,
    ClExprOp,
    ClOp,
    ClRegVar,
    Conditional,
    OpType,
    WiredClExpr,
)
from pytket.unit_id import (
    _TEMP_BIT_NAME,
    _TEMP_BIT_REG_BASE,
    _TEMP_REG_SIZE,
    Bit,
    BitRegister,
)
from pytket.circuit.clexpr import check_register_alignments, has_reg_output
from pytket.circuit.logic_exp import Constant, Variable

T = TypeVar("T")


[docs] class DecomposeClassicalError(Exception): """Error with decomposing classical operations."""
[docs] class VarHeap(Generic[T]): """A generic heap implementation.""" def __init__(self) -> None: self._heap: list[T] = [] self._heap_vars: set[T] = set()
[docs] def pop(self) -> T: """Pop from top of heap.""" return heappop(self._heap)
[docs] def push(self, var: T) -> None: """Push var to heap.""" heappush(self._heap, var) self._heap_vars.add(var)
[docs] def is_heap_var(self, var: T) -> bool: """Check if var was generated from heap.""" return var in self._heap_vars
[docs] def fresh_var(self) -> T: """Generate new variable.""" raise NotImplementedError
[docs] class BitHeap(VarHeap[Bit]): """Heap of temporary Bits.""" def __init__(self, _reg_name: str = _TEMP_BIT_NAME): """Initialise new BitHeap. :param _reg_name: Name for register of Bits, defaults to _TEMP_BIT_NAME :type _reg_name: str, optional """ self.reg_name = _reg_name super().__init__() @property def next_index(self) -> int: """Next available bit index, not used by any other heap bit.""" return max((b.index[0] for b in self._heap_vars), default=-1) + 1
[docs] def fresh_var(self) -> Bit: """Return Bit, from heap if available, otherwise create new.""" if self._heap: return self.pop() new_bit = Bit(self.reg_name, self.next_index) self._heap_vars.add(new_bit) return new_bit
[docs] class RegHeap(VarHeap[BitRegister]): """Heap of temporary BitRegisters.""" def __init__(self, _reg_name_base: str = _TEMP_BIT_REG_BASE): """Initialise new RegHeap. :param _reg_name_base: base string for register names, defaults to _TEMP_BIT_REG_BASE :type _reg_name_base: str, optional """ self._reg_name_base = _reg_name_base super().__init__() @property def next_index(self) -> int: """Next available bit index, not used by any other heap register.""" return ( max((int(b.name.split("_")[-1]) for b in self._heap_vars), default=-1) + 1 )
[docs] def fresh_var(self, size: int = _TEMP_REG_SIZE) -> BitRegister: """Return BitRegister, from heap if available, otherwise create new. Optionally set size of created register.""" if self._heap: return self.pop() new_reg = BitRegister(f"{self._reg_name_base}_{self.next_index}", size) self._heap_vars.add(new_reg) return new_reg
[docs] def temp_reg_in_args(args: list[Bit]) -> BitRegister | None: """If there are bits from a temporary register in the args, return it.""" temp_reg_bits = [b for b in args if b.reg_name.startswith(_TEMP_BIT_REG_BASE)] if temp_reg_bits: return BitRegister(temp_reg_bits[0].reg_name, _TEMP_REG_SIZE) return None
VarType = TypeVar("VarType", type[Bit], type[BitRegister]) def _int_to_bools(val: Constant, width: int) -> list[bool]: # map int to bools via litle endian encoding return list(map(bool, map(int, reversed(f"{val:0{width}b}"[-width:])))) def _get_bit_width(x: int) -> int: assert x >= 0 c = 0 while x: x >>= 1 c += 1 return c class _ClExprDecomposer: def __init__( # noqa: PLR0913 self, circ: Circuit, bit_posn: dict[int, int], reg_posn: dict[int, list[int]], args: list[Bit], bit_heap: BitHeap, reg_heap: RegHeap, kwargs: dict[str, Any], ): self.circ: Circuit = circ self.bit_posn: dict[int, int] = bit_posn self.reg_posn: dict[int, list[int]] = reg_posn self.args: list[Bit] = args self.bit_heap: BitHeap = bit_heap self.reg_heap: RegHeap = reg_heap self.kwargs: dict[str, Any] = kwargs # Construct maps from int (i.e. ClBitVar) to Bit, and from int (i.e. ClRegVar) # to BitRegister: self.bit_vars = {i: args[p] for i, p in bit_posn.items()} self.reg_vars = { i: BitRegister(args[p[0]].reg_name, len(p)) for i, p in reg_posn.items() } def add_var(self, var: Variable) -> None: """Add a Bit or BitRegister to the circuit if not already present.""" if isinstance(var, Bit): self.circ.add_bit(var, reject_dups=False) else: assert isinstance(var, BitRegister) for bit in var.to_list(): self.circ.add_bit(bit, reject_dups=False) def set_bits(self, var: Variable, val: int) -> None: """Set the value of a Bit or BitRegister.""" assert val >= 0 if isinstance(var, Bit): assert val >> 1 == 0 self.circ.add_c_setbits([bool(val)], [var], **self.kwargs) else: assert isinstance(var, BitRegister) assert val >> var.size == 0 self.circ.add_c_setreg(val, var, **self.kwargs) def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable: # noqa: PLR0912, PLR0915 """Add the decomposed expression to the circuit and return the Bit or BitRegister that contains the result. :param expr: the expression to decompose :param out_var: where to put the output (if None, create a new scratch location) """ op: ClOp = expr.op heap: VarHeap = self.reg_heap if has_reg_output(op) else self.bit_heap # Eliminate (recursively) subsidiary expressions from the arguments, and convert # all terms to Bit or BitRegister: terms: list[Variable] = [] for arg in expr.args: if isinstance(arg, int): # Assign to a fresh variable fresh_var = heap.fresh_var() self.add_var(fresh_var) self.set_bits(fresh_var, arg) terms.append(fresh_var) elif isinstance(arg, ClBitVar): terms.append(self.bit_vars[arg.index]) elif isinstance(arg, ClRegVar): terms.append(self.reg_vars[arg.index]) else: assert isinstance(arg, ClExpr) terms.append(self.decompose_expr(arg, None)) # Enable reuse of temporary terms: for term in terms: if heap.is_heap_var(term): heap.push(term) if out_var is None: out_var = heap.fresh_var() self.add_var(out_var) match op: case ClOp.BitAnd: self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitNot: self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitOne: assert isinstance(out_var, Bit) self.circ.add_c_setbits([True], [out_var], **self.kwargs) case ClOp.BitOr: self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitXor: self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitZero: assert isinstance(out_var, Bit) self.circ.add_c_setbits([False], [out_var], **self.kwargs) case ClOp.RegAnd: self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegNot: self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegOne: assert isinstance(out_var, BitRegister) self.circ.add_c_setbits( [True] * out_var.size, out_var.to_list(), **self.kwargs ) case ClOp.RegOr: self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegXor: self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegZero: assert isinstance(out_var, BitRegister) self.circ.add_c_setbits( [False] * out_var.size, out_var.to_list(), **self.kwargs ) case _: raise DecomposeClassicalError( f"{op} cannot be decomposed to TKET primitives." ) return out_var def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]: # noqa: PLR0912, PLR0915 """Rewrite a circuit command-wise, decomposing ClExprOp.""" if not check_register_alignments(circ): raise DecomposeClassicalError("Circuit contains non-register-aligned ClExprOp.") bit_heap = BitHeap() reg_heap = RegHeap() # add already used heap variables to heaps for b in circ.bits: if b.reg_name == _TEMP_BIT_NAME: bit_heap._heap_vars.add(b) # noqa: SLF001 elif b.reg_name.startswith(_TEMP_BIT_REG_BASE): reg_heap._heap_vars.add(BitRegister(b.reg_name, _TEMP_REG_SIZE)) # noqa: SLF001 newcirc = Circuit(0, name=circ.name) for qb in circ.qubits: newcirc.add_qubit(qb) for cb in circ.bits: # lose all temporary bits, add back as required later if not ( cb.reg_name.startswith(_TEMP_BIT_NAME) or cb.reg_name.startswith(_TEMP_BIT_REG_BASE) ): newcirc.add_bit(cb) # targets of predicates that need to be relabelled replace_targets: dict[Variable, Variable] = {} modified = False for command in circ: op = command.op optype = op.type args = command.args kwargs = {} if optype == OpType.Conditional: assert isinstance(op, Conditional) bits = args[: op.width] # check if conditional on previously decomposed expression if len(bits) == 1 and bits[0] in replace_targets: assert isinstance(bits[0], Bit) # this op should encode comparison and value assert op.value in (0, 1) replace_bit = replace_targets[bits[0]] # temporary condition bit is available for reuse bit_heap.push(replace_bit) # type: ignore # write new conditional op kwargs = {"condition_bits": [replace_bit], "condition_value": op.value} else: kwargs = {"condition_bits": bits, "condition_value": op.value} args = args[op.width :] op = op.op optype = op.type if optype == OpType.RangePredicate: target = args[-1] assert isinstance(target, Bit) newcirc.add_bit(target, reject_dups=False) temp_reg = temp_reg_in_args(args) # type: ignore # ensure predicate is reading from correct output register if temp_reg in replace_targets: assert temp_reg is not None new_target = replace_targets[temp_reg] for i, a in enumerate(args): if a.reg_name == temp_reg.name: args[i] = Bit(new_target.name, a.index[0]) # type: ignore # operations conditional on this bit should remain so replace_targets[target] = target elif optype == OpType.ClExpr: assert isinstance(op, ClExprOp) wexpr: WiredClExpr = op.expr expr: ClExpr = wexpr.expr bit_posn = wexpr.bit_posn reg_posn = wexpr.reg_posn output_posn = wexpr.output_posn assert len(output_posn) > 0 output0 = args[output_posn[0]] assert isinstance(output0, Bit) out_var: Variable = ( BitRegister(output0.reg_name, len(output_posn)) if has_reg_output(expr.op) else output0 ) decomposer = _ClExprDecomposer( newcirc, bit_posn, reg_posn, args, # type: ignore bit_heap, reg_heap, kwargs, ) comp_var = decomposer.decompose_expr(expr, out_var) if comp_var != out_var: replace_targets[out_var] = comp_var modified = True continue if optype == OpType.Barrier: # add_gate doesn't work for metaops newcirc.add_barrier(args) else: for arg in args: if isinstance(arg, Bit) and arg not in newcirc.bits: newcirc.add_bit(arg) newcirc.add_gate(op, args, **kwargs) return newcirc, modified