# Copyright Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import math
import warnings
from collections.abc import Sequence
from functools import partial
import pyqir
from pyqir import IntPredicate, Value
from pytket import predicates
from pytket.circuit import (
BarrierOp,
Bit,
BitRegister,
CircBox,
Circuit,
ClBitVar,
ClExpr,
ClExprOp,
ClOp,
ClRegVar,
Command,
Conditional,
CopyBitsOp,
Op,
OpType,
Qubit,
RangePredicateOp,
SetBitsOp,
UnitID,
WASMOp,
WiredClExpr,
)
from pytket.circuit.logic_exp import (
BitAnd,
BitEq,
BitNeq,
BitNot,
BitOne,
BitOr,
BitWiseOp,
BitXor,
BitZero,
LogicExp,
RegAdd,
RegAnd,
RegEq,
RegGeq,
RegGt,
RegLeq,
RegLsh,
RegLt,
RegMul,
RegNeq,
RegOr,
RegRsh,
RegSub,
RegXor,
)
from pytket.qasm.qasm import _retrieve_registers
from pytket.transform import Transform
from pytket.unit_id import UnitType
from .gatesets import (
FuncSpec,
)
from .module import tketqirModule
# Lookup tables translating pytket classical operations into pyqir builder
# calls. Unless noted otherwise, every value is a factory taking a pyqir
# builder ``b`` and returning the callable that emits the instruction.

# Binary register expressions (pytket.circuit.logic_exp) whose result is an
# integer of register width.
_TK_CLOPS_TO_PYQIR_REG: dict = {
RegAnd: lambda b: b.and_,
RegOr: lambda b: b.or_,
RegXor: lambda b: b.xor,
RegAdd: lambda b: b.add,
RegSub: lambda b: b.sub,
RegMul: lambda b: b.mul,
RegLsh: lambda b: b.shl,
RegRsh: lambda b: b.lshr,
}
# Register comparisons; each is an ``icmp`` with the predicate pre-bound.
# All ordering predicates are the unsigned variants.
_TK_CLOPS_TO_PYQIR_REG_BOOL: dict = {
RegEq: lambda b: partial(b.icmp, IntPredicate.EQ),
RegNeq: lambda b: partial(b.icmp, IntPredicate.NE),
RegGt: lambda b: partial(b.icmp, IntPredicate.UGT),
RegGeq: lambda b: partial(b.icmp, IntPredicate.UGE),
RegLt: lambda b: partial(b.icmp, IntPredicate.ULT),
RegLeq: lambda b: partial(b.icmp, IntPredicate.ULE),
}
# Two-argument single-bit expressions.
_TK_CLOPS_TO_PYQIR_2_BITS: dict = {
BitAnd: lambda b: b.and_,
BitOr: lambda b: b.or_,
BitXor: lambda b: b.xor,
BitNeq: lambda b: partial(b.icmp, IntPredicate.NE),
BitEq: lambda b: partial(b.icmp, IntPredicate.EQ),
}
# One-argument single-bit expressions. _get_ssa_from_cl_bit_op calls the
# returned ``sub`` as ``sub(1, x)``, i.e. NOT is emitted as ``1 - x``.
_TK_CLOPS_TO_PYQIR_BIT: dict = {
BitNot: lambda b: b.sub,
}
# Argument-free bit expressions mapped directly to their constant value
# (plain ints here, not builder factories).
_TK_CLOPS_TO_PYQIR_2_BITS_NO_PARAM: dict = {
BitOne: 1,
BitZero: 0,
}
# ClExpr operations (ClOp) that map onto a single builder call. The ops left
# as comments below are not in this table; _get_ssa_from_clexpr_arg either
# handles them with dedicated code or rejects them (RegDiv, RegPow).
_TK_CLEXPR_OP_TO_PYQIR: dict = {
ClOp.BitAnd: lambda b: b.and_,
ClOp.BitOr: lambda b: b.or_,
ClOp.BitXor: lambda b: b.xor,
ClOp.BitEq: lambda b: partial(b.icmp, IntPredicate.EQ),
ClOp.BitNeq: lambda b: partial(b.icmp, IntPredicate.NE),
# ClOp.BitNot
# ClOp.BitZero
# ClOp.BitOne
ClOp.RegAnd: lambda b: b.and_,
ClOp.RegOr: lambda b: b.or_,
ClOp.RegXor: lambda b: b.xor,
ClOp.RegEq: lambda b: partial(b.icmp, IntPredicate.EQ),
ClOp.RegNeq: lambda b: partial(b.icmp, IntPredicate.NE),
# ClOp.RegNot
# ClOp.RegZero
# ClOp.RegOne
ClOp.RegLt: lambda b: partial(b.icmp, IntPredicate.ULT),
ClOp.RegGt: lambda b: partial(b.icmp, IntPredicate.UGT),
ClOp.RegLeq: lambda b: partial(b.icmp, IntPredicate.ULE),
ClOp.RegGeq: lambda b: partial(b.icmp, IntPredicate.UGE),
ClOp.RegAdd: lambda b: b.add,
ClOp.RegSub: lambda b: b.sub,
ClOp.RegMul: lambda b: b.mul,
# ClOp.RegDiv
# ClOp.RegPow
ClOp.RegLsh: lambda b: b.shl,
ClOp.RegRsh: lambda b: b.lshr,
# ClOp.RegNeg
}
# ClExpr operations whose arguments are registers. Used to decide whether an
# integer literal argument is emitted at register width or as a single bit.
_TK_CLEXPR_OP_WITH_REG_ARGS = {
ClOp.RegAnd,
ClOp.RegOr,
ClOp.RegXor,
ClOp.RegEq,
ClOp.RegNeq,
ClOp.RegNot,
ClOp.RegLt,
ClOp.RegGt,
ClOp.RegLeq,
ClOp.RegGeq,
ClOp.RegAdd,
ClOp.RegSub,
ClOp.RegMul,
ClOp.RegDiv,
ClOp.RegPow,
ClOp.RegLsh,
ClOp.RegRsh,
ClOp.RegNeg,
}
[docs]
class AbstractQirGenerator:
    """Base class for generating QIR from a pytket circuit.

    Contains the functionality that is not specific to any profile; the
    profile-specific steps are the methods marked ``abc.abstractmethod``.

    NOTE: the class does not derive from ``abc.ABC``, so those markers are
    not enforced when an instance is created.
    """

    def __init__(
        self,
        circuit: Circuit,
        module: tketqirModule,
        wasm_int_type: int,
        qir_int_type: int,
    ) -> None:
        """Set up types, bookkeeping and the always-declared external functions.

        ``wasm_int_type`` and ``qir_int_type`` are bit widths; they are turned
        into pyqir integer types in the module's context.
        """
        self.circuit = circuit
        self.module = module

        # Frequently used pyqir types.
        context = self.module.context
        self.wasm_int_type = pyqir.IntType(context, wasm_int_type)
        self.int_size = qir_int_type
        self.qir_int_type = pyqir.IntType(context, qir_int_type)
        self.qir_i32_type = pyqir.IntType(context, 32)
        self.qir_i64_type = pyqir.IntType(context, 64)
        self.qir_i1p_type = pyqir.PointerType(pyqir.IntType(context, 1))
        self.qir_bool_type = pyqir.IntType(context, 1)
        self.qubit_type = pyqir.qubit_type(context)
        self.result_type = pyqir.result_type(context)

        # Classical register bookkeeping and generation state.
        self.cregs = _retrieve_registers(self.circuit.bits, BitRegister)  # type: ignore
        self.creg_size: dict[str, int] = {}
        self.target_gateset = self.module.gateset.base_gateset
        self.block_count = 0
        self.block_count_sb = 0
        self.has_wasm = False

        # Text substitution tables, exposed through get_wasm_sar/get_azure_sar.
        self.wasm_sar_dict: dict[str, str] = {
            "!llvm.module.flags": (
                'attributes #1 = { "wasm" }\n\n!llvm.module.flags'
            ),
            'attributes #1 = { "irreversible" }\n\nattributes #1 = { "wasm" }': (
                'attributes #1 = { "wasm" }\nattributes #2 = { "irreversible" }'
            ),
            "declare void @__quantum__qis__mz__body(%Qubit*, %Result* writeonly) #1": (
                "declare void @__quantum__qis__mz__body(%Qubit*, %Result* writeonly) #2"
            ),
        }
        self.azure_sar_dict: dict[str, str] = {}

        self.int_type_str = f"i{qir_int_type}"

        # Gates with dedicated conv_* handlers are accepted on top of the
        # module's base gateset (this adds to the gateset object itself).
        for extra_gate in (OpType.PhasedX, OpType.ZZPhase, OpType.ZZMax, OpType.TK2):
            self.target_gateset.add(extra_gate)

        self.reg_const: dict[str, Value] = {}
        self.getset_predicate = predicates.GateSetPredicate(
            set(self.target_gateset),
        )

        module_context = self.module.module.context

        # i1 __quantum__qis__read_result__body(%Result*)
        read_result_type = pyqir.FunctionType(
            pyqir.IntType(module_context, 1),
            [pyqir.result_type(module_context)],
        )
        self.read_bit_from_result = self.module.module.add_external_function(
            "__quantum__qis__read_result__body",
            read_result_type,
        )

        # void __quantum__rt__int_record_output(i<qir_int_type>, i8*)
        record_output_type = pyqir.FunctionType(
            pyqir.Type.void(module_context),
            [
                pyqir.IntType(module_context, qir_int_type),
                pyqir.PointerType(pyqir.IntType(module_context, 8)),
            ],
        )
        self.record_output_i64 = self.module.module.add_external_function(
            "__quantum__rt__int_record_output",
            record_output_type,
        )

        # Lazily declared external functions, indexed by number of qubits.
        n_slots = self.circuit.n_qubits + 1
        self.barrier: list[pyqir.Function | None] = [None] * n_slots
        self.order: list[pyqir.Function | None] = [None] * n_slots
        self.group: list[pyqir.Function | None] = [None] * n_slots
        self.sleep: list[pyqir.Function | None] = [None] * n_slots

        # Lazily declared RNG / shot-number functions.
        self.rngseed: pyqir.Function | None = None
        self.rngnum_bound: pyqir.Function | None = None
        self.rngnum: pyqir.Function | None = None
        self.rngindex: pyqir.Function | None = None
        self.jobnum: pyqir.Function | None = None

        self.wasm: dict[str, pyqir.Function] = {}
        self.additional_quantum_gates: dict[OpType, pyqir.Function] = {}
        # Bound remembered by _add_rngbound_op for later bounded RNG draws.
        self.rng_bound: Value | None = None
@abc.abstractmethod
def _get_bit_from_creg(self, creg: str, index: int) -> Value:
    """Return the QIR value of bit ``index`` of classical register ``creg``.

    Profile-specific; to be implemented by subclasses.
    """
@abc.abstractmethod
def _set_bit_in_creg(self, creg: str, index: int, ssa_bit: Value) -> None:
    """Store ``ssa_bit`` as bit ``index`` of classical register ``creg``.

    Profile-specific; to be implemented by subclasses.
    """
def _set_bit(self, bit: Bit, ssa_bit: Value) -> None:
    """Store ``ssa_bit`` at the register position identified by ``bit``."""
    register_name = bit.reg_name
    position = bit.index[0]
    self._set_bit_in_creg(register_name, position, ssa_bit)
[docs]
@abc.abstractmethod
def get_ssa_vars(self, reg_name: str) -> Value:
    """Return the QIR value for classical register ``reg_name``.

    Profile-specific; to be implemented by subclasses.
    """
@abc.abstractmethod
def _get_i64_ssa_reg(self, name: str) -> Value:
    """Return classical register ``name`` as a single integer QIR value.

    Profile-specific; to be implemented by subclasses.
    """
[docs]
@abc.abstractmethod
def set_ssa_vars(self, reg_name: str, ssa_i64: Value, trunc: bool) -> None:
    """Store the integer value ``ssa_i64`` into classical register ``reg_name``.

    Profile-specific; to be implemented by subclasses.
    NOTE(review): ``trunc`` presumably requests truncation of the value to
    the register width -- confirm against the implementations.
    """
def _add_barrier_op(self, index: int, qir_qubits: Sequence) -> None:
    """Emit a call to the barrier function taking ``index`` qubits.

    ``__quantum__qis__barrier<index>__body`` is declared on first use and
    cached in ``self.barrier``.
    """
    barrier_fn = self.barrier[index]
    if barrier_fn is None:
        context = self.module.module.context
        barrier_fn = self.module.module.add_external_function(
            f"__quantum__qis__barrier{index}__body",
            pyqir.FunctionType(
                pyqir.Type.void(context),
                [pyqir.qubit_type(context)] * index,
            ),
        )
        self.barrier[index] = barrier_fn
    self.module.builder.call(barrier_fn, list(qir_qubits))
def _add_group_op(self, index: int, qir_qubits: Sequence) -> None:
    """Emit a call to the group function taking ``index`` qubits.

    ``__quantum__qis__group<index>__body`` is declared on first use and
    cached in ``self.group``.
    """
    group_fn = self.group[index]
    if group_fn is None:
        context = self.module.module.context
        group_fn = self.module.module.add_external_function(
            f"__quantum__qis__group{index}__body",
            pyqir.FunctionType(
                pyqir.Type.void(context),
                [pyqir.qubit_type(context)] * index,
            ),
        )
        self.group[index] = group_fn
    self.module.builder.call(group_fn, list(qir_qubits))
def _add_order_op(self, index: int, qir_qubits: Sequence) -> None:
    """Emit a call to the order function taking ``index`` qubits.

    ``__quantum__qis__order<index>__body`` is declared on first use and
    cached in ``self.order``.
    """
    order_fn = self.order[index]
    if order_fn is None:
        context = self.module.module.context
        order_fn = self.module.module.add_external_function(
            f"__quantum__qis__order{index}__body",
            pyqir.FunctionType(
                pyqir.Type.void(context),
                [pyqir.qubit_type(context)] * index,
            ),
        )
        self.order[index] = order_fn
    self.module.builder.call(order_fn, list(qir_qubits))
def _add_sleep_op(self, index: int, qir_qubits: Sequence, duration: float) -> None:
    """Emit a call to ``__quantum__qis__sleep__body`` with ``duration``.

    The function takes the qubit operand(s) followed by the duration as a
    double. It is declared on first use and cached in ``self.sleep``.

    Raises ValueError if more than one qubit is involved.
    """
    if index > 1:
        raise ValueError("Sleep operation only allowed on one qubit")

    context = self.module.module.context
    sleep_fn = self.sleep[index]
    if sleep_fn is None:
        # qubit operand(s) first, then the float duration
        argument_types = [
            *([pyqir.qubit_type(context)] * index),
            pyqir.Type.double(context),
        ]
        sleep_fn = self.module.module.add_external_function(
            "__quantum__qis__sleep__body",
            pyqir.FunctionType(pyqir.Type.void(context), argument_types),
        )
        self.sleep[index] = sleep_fn

    duration_const = pyqir.const(pyqir.Type.double(context), duration)
    self.module.builder.call(sleep_fn, [*qir_qubits, duration_const])
def _add_rngbound_op(self, qir_creg: Value) -> None:
    """Remember ``qir_creg`` as the bound for subsequent RNG draws.

    No instruction is emitted here; _add_rngnum_op reads ``self.rng_bound``.
    """
    self.rng_bound = qir_creg
def _add_rngseed_op(self, qir_creg: Value) -> None:
    """Emit a call to ``void ___random_seed(i64)`` with ``qir_creg`` as seed.

    The function is declared on first use. If the generator's integer width
    is not 64, the value is zero-extended to i64 first.
    """
    if self.rngseed is None:
        self.rngseed = self.module.module.add_external_function(
            "___random_seed",
            pyqir.FunctionType(
                pyqir.Type.void(self.module.module.context),
                [self.qir_i64_type],
            ),
        )

    seed = qir_creg
    if self.int_size != 64:  # noqa: PLR2004
        seed = self.module.module.builder.zext(qir_creg, self.qir_i64_type)
    self.module.builder.call(self.rngseed, [seed])
def _add_rngnum_op(self, qir_creg_name: str) -> None:
    """Draw a random number and store it in register ``qir_creg_name``.

    Without a remembered bound ``i32 ___random_int()`` is called, otherwise
    ``i32 ___random_int_bounded(i32)``. When the generator's integer width is
    not 32, the bound is truncated to i32 before the call and the result is
    zero-extended to the generator's integer type afterwards.
    """
    needs_resize = self.int_size != 32  # noqa: PLR2004

    if self.rng_bound is None:
        # i32 ___random_int()
        if self.rngnum is None:
            self.rngnum = self.module.module.add_external_function(
                "___random_int",
                pyqir.FunctionType(self.qir_i32_type, []),
            )
        drawn = self.module.builder.call(self.rngnum, [])
    else:
        # i32 ___random_int_bounded(i32)
        if self.rngnum_bound is None:
            self.rngnum_bound = self.module.module.add_external_function(
                "___random_int_bounded",
                pyqir.FunctionType(self.qir_i32_type, [self.qir_i32_type]),
            )
        bound = self.rng_bound
        if needs_resize:
            bound = self.module.module.builder.trunc(
                self.rng_bound, self.qir_i32_type
            )
        drawn = self.module.builder.call(self.rngnum_bound, [bound])

    if needs_resize:
        drawn = self.module.module.builder.zext(drawn, self.qir_int_type)
    self.set_ssa_vars(qir_creg_name, drawn, False)
def _add_rngindex_op(self, qir_creg: Value) -> None:
    """Emit a call to ``void ___set_random_index(i32)`` with ``qir_creg``.

    The function is declared on first use. If the generator's integer width
    is not 32, the value is truncated to i32 first.
    """
    if self.rngindex is None:
        self.rngindex = self.module.module.add_external_function(
            "___set_random_index",
            pyqir.FunctionType(
                pyqir.Type.void(self.module.module.context),
                [self.qir_i32_type],
            ),
        )

    index_value = qir_creg
    if self.int_size != 32:  # noqa: PLR2004
        index_value = self.module.module.builder.trunc(qir_creg, self.qir_i32_type)
    self.module.builder.call(self.rngindex, [index_value])
def _add_jobnum_op(self, qir_creg: str) -> None:
    """Store the result of ``i64 ___get_current_shot()`` in register ``qir_creg``.

    The function is declared on first use. If the generator's integer width
    is not 64, the result is truncated to i32 before it is stored.
    """
    if self.jobnum is None:
        self.jobnum = self.module.module.add_external_function(
            "___get_current_shot",
            pyqir.FunctionType(self.qir_i64_type, []),
        )

    shot_number = self.module.builder.call(self.jobnum, [])
    if self.int_size != 64:  # noqa: PLR2004
        shot_number = self.module.module.builder.trunc(
            shot_number, self.qir_i32_type
        )
    self.set_ssa_vars(qir_creg, shot_number, False)
def _decompose_conditional_circ_box(
    self,
    op: CircBox,
    args: list[UnitID],
) -> Circuit:
    """Return a circuit containing the decomposed contents of ``op``.

    A fresh circuit with as many qubits as ``self.circuit`` is created, the
    classical registers of all bits in ``args`` are added to it, the box is
    placed on ``args`` and all boxes are then decomposed. No rebasing to the
    target gateset happens here.
    """
    circuit = Circuit(self.circuit.n_qubits)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # registers are added in the same order on every run (iterating a set
    # of strings is not reproducible across interpreter invocations).
    creg_names = dict.fromkeys(b.reg_name for b in args if isinstance(b, Bit))
    for cr_name in creg_names:
        circuit.add_c_register(self.circuit.get_c_register(cr_name))
    circuit.add_circbox(op, args)
    Transform.DecomposeBoxes().apply(circuit)
    return circuit
def _get_optype_and_params(self, op: Op) -> tuple[OpType, Sequence[float]]:
    """Return the type of ``op`` together with its parameters.

    For ExplicitPredicate, Barrier and CopyBits the parameter list is empty
    and ``op.params`` is not read.
    """
    optype: OpType = op.type
    without_params = (OpType.ExplicitPredicate, OpType.Barrier, OpType.CopyBits)
    params: list = [] if optype in without_params else op.params
    return optype, params
def _to_qis_qubits(self, qubits: list[Qubit]) -> list[Value]:
    """Map each pytket qubit to the module qubit with the same first index."""
    qir_qubits: list[Value] = []
    for qubit in qubits:
        qir_qubits.append(self.module.module.qubits[qubit.index[0]])
    return qir_qubits
def _to_qis_results(self, bits: list[Bit]) -> Value | None:
    """Return the module result for the first bit, or None without bits."""
    if not bits:
        return None
    first_bit = bits[0]
    return self.module.module.results[first_bit.index[0]]
def _to_qis_bits(self, args: list[Bit]) -> Sequence[Value]:
    """Return the module results for all of ``args`` except the last one.

    Every bit is asserted to belong to the register named ``"c"``.
    """
    assert all(b.reg_name == "c" for b in args)
    # Slicing an empty list yields an empty list, so no special case is needed.
    return [self.module.module.results[bit.index[0]] for bit in args[:-1]]
def _get_c_regs_from_com(
    self,
    op: Op,
    args: list[Bit | Qubit],
) -> tuple[list[str], list[str]]:
    """Return the names of the input and output registers of a WASM op.

    ``args`` is consumed front to back, first by the op's input widths and
    then by its output widths. For any op other than a WASMOp both lists
    are empty.

    Raises ValueError if a slice is not exactly one complete register.
    """
    inputs: list[str] = []
    outputs: list[str] = []
    if not isinstance(op, WASMOp):
        return inputs, outputs

    remaining = args
    for collected, widths in (
        (inputs, op.input_widths),
        (outputs, op.output_widths),
    ):
        for width in widths:
            assert width > 0
            com_bits, remaining = remaining[:width], remaining[width:]
            regname = com_bits[0].reg_name
            if com_bits != list(self.cregs[regname]):
                raise ValueError("WASM ops must act on entire registers.")
            collected.append(regname)
    return inputs, outputs
def _get_ssa_from_cl_reg_op(
    self,
    reg: BitRegister | RegAnd | RegOr | RegXor | int,
) -> Value:
    """Recursively translate a register expression into a QIR value.

    Integers become constants of the generator's integer type, registers
    are read as integers, and binary operations listed in
    ``_TK_CLOPS_TO_PYQIR_REG`` emit the matching instruction.

    Raises ValueError for any other expression type.
    """
    kind = type(reg)
    if kind is int:
        return pyqir.const(self.qir_int_type, reg)
    if kind is BitRegister:
        return self._get_i64_ssa_reg(reg.name)  # type: ignore

    instruction_factory = _TK_CLOPS_TO_PYQIR_REG.get(kind)
    if instruction_factory is None:
        raise ValueError(f"unsupported classical register operation: {type(reg)}")

    assert len(reg.args) == 2  # type: ignore # noqa: PLR2004
    lhs = self._get_ssa_from_cl_reg_op(reg.args[0])  # type: ignore
    rhs = self._get_ssa_from_cl_reg_op(reg.args[1])  # type: ignore
    # the operands are emitted first, then the instruction combining them
    return instruction_factory(self.module.builder)(lhs, rhs)  # type: ignore
def _get_ssa_from_cl_bit_op(
    self,
    bit: LogicExp | Bit | BitAnd | BitOr | BitXor | int,
) -> Value:
    """Recursively translate a bit expression into a QIR value.

    Bits are read from their register, integers become i1 constants,
    unary operations (``_TK_CLOPS_TO_PYQIR_BIT``) are emitted as ``1 - x``
    and binary operations come from ``_TK_CLOPS_TO_PYQIR_2_BITS``.

    Raises ValueError for any other expression type.
    """
    kind = type(bit)
    if kind is Bit:
        return self._get_bit_from_creg(bit.reg_name, bit.index[0])  # type: ignore
    if kind is int:
        return pyqir.const(self.qir_bool_type, bit)

    unary_factory = _TK_CLOPS_TO_PYQIR_BIT.get(kind)
    if unary_factory is not None:
        assert len(bit.args) == 1  # type: ignore
        one = pyqir.const(self.qir_bool_type, 1)
        operand = self._get_ssa_from_cl_bit_op(bit.args[0])  # type: ignore
        return unary_factory(self.module.builder)(one, operand)  # type: ignore

    binary_factory = _TK_CLOPS_TO_PYQIR_2_BITS.get(kind)
    if binary_factory is not None:
        assert len(bit.args) == 2  # type: ignore # noqa: PLR2004
        lhs = self._get_ssa_from_cl_bit_op(bit.args[0])  # type: ignore
        rhs = self._get_ssa_from_cl_bit_op(bit.args[1])  # type: ignore
        return binary_factory(self.module.builder)(lhs, rhs)  # type: ignore

    raise ValueError(f"unsupported bitwise operation {type(bit)}")
[docs]
def get_wasm_sar(self) -> dict[str, str]:
    """Return the WASM substitution table (the dict itself, not a copy).

    NOTE(review): "sar" presumably stands for search-and-replace pairs
    applied to the generated IR text -- confirm with the caller.
    """
    substitutions = self.wasm_sar_dict
    return substitutions
[docs]
def get_azure_sar(self) -> dict[str, str]:
    """Return the Azure substitution table (the dict itself, not a copy)."""
    substitutions = self.azure_sar_dict
    return substitutions
[docs]
def conv_RangePredicateOp(self, op: RangePredicateOp, args: list[Bit]) -> None:
    """Emit ``lower <= register <= upper`` and store it in the last argument bit.

    The register is the one the first argument bit belongs to. Equal bounds
    are emitted as a single equality comparison.
    """
    builder = self.module.module.builder
    registername = args[0].reg_name
    lower_qir = pyqir.const(self.qir_int_type, op.lower)

    if op.lower == op.upper:
        # special case handling for REG_EQ
        result = builder.icmp(
            pyqir.IntPredicate.EQ,
            lower_qir,
            self._get_i64_ssa_reg(registername),
        )
    else:
        upper_qir = pyqir.const(self.qir_int_type, op.upper)
        reg_value = self._get_i64_ssa_reg(registername)
        # Both bounds are inclusive. The previous "lower > reg and
        # reg > upper" could never hold. Register values are unsigned, so
        # the unsigned predicates are used, as for the other register
        # comparisons in this module.
        lower_cond = builder.icmp(pyqir.IntPredicate.ULE, lower_qir, reg_value)
        upper_cond = builder.icmp(pyqir.IntPredicate.ULE, reg_value, upper_qir)
        result = builder.and_(lower_cond, upper_cond)

    self._set_bit(args[-1], result)
[docs]
@abc.abstractmethod
def conv_conditional(self, command: Command, op: Conditional) -> None:
    """Convert a command whose operation is the conditional ``op``.

    Profile-specific; to be implemented by subclasses.
    """
[docs]
def conv_WASMOp(self, op: WASMOp, args: list[Bit | Qubit]) -> None:
    """Emit a call to the WASM function named by ``op``.

    Every input register is passed as one integer argument. On first use of
    a function name the external function is declared and its textual
    declaration is registered in ``self.wasm_sar_dict`` so that it gets
    ``#1`` appended. A single output register receives the return value.

    Raises ValueError if the function would return more than one value.
    """
    self.has_wasm = True
    paramreg, resultreg = self._get_c_regs_from_com(op, args)
    ssa_param = [self._get_i64_ssa_reg(p) for p in paramreg]

    if op.func_name not in self.wasm:
        parametertype = [self.qir_int_type] * len(paramreg)
        if len(resultreg) == 0:
            returntype = pyqir.Type.void(self.module.module.context)
            return_str = "void"
        elif len(resultreg) == 1:
            returntype = self.qir_int_type
            return_str = self.int_type_str
        else:
            # The two message parts used to run together ("yetplease").
            raise ValueError(
                "wasm functions which return more than"
                " one value are not supported yet,"
                f" please don't use {op.func_name}",
            )
        self.wasm[op.func_name] = self.module.module.add_external_function(
            f"{op.func_name}",
            pyqir.FunctionType(
                returntype,
                parametertype,
            ),
        )
        # Textual form of the declaration as it appears in the emitted IR.
        param_str = ", ".join([self.int_type_str] * len(paramreg))
        wasm_func_interface = f"declare {return_str} @{op.func_name}({param_str})"
        self.wasm_sar_dict[wasm_func_interface] = f"{wasm_func_interface} #1"

    result = self.module.builder.call(
        self.wasm[op.func_name],
        [*ssa_param],
    )
    if len(resultreg) == 1:
        self.set_ssa_vars(resultreg[0], result, True)
[docs]
def conv_ZZPhase(self, qubits: list[Qubit], op: Op) -> None:
    """Emit ``__quantum__qis__rzz__body(angle, q0, q1)`` for a ZZPhase op.

    The op parameter is multiplied by pi to obtain the angle. The external
    function is declared on first use.
    """
    context = self.module.module.context
    declared = self.additional_quantum_gates
    if OpType.ZZPhase not in declared:
        signature = pyqir.FunctionType(
            pyqir.Type.void(context),
            [
                pyqir.Type.double(context),
                pyqir.qubit_type(context),
                pyqir.qubit_type(context),
            ],
        )
        declared[OpType.ZZPhase] = self.module.module.add_external_function(
            "__quantum__qis__rzz__body",
            signature,
        )

    angle = pyqir.const(
        pyqir.Type.double(context),
        (float(op.params[0]) * math.pi),
    )
    first = self.module.module.qubits[qubits[0].index[0]]
    second = self.module.module.qubits[qubits[1].index[0]]
    self.module.builder.call(declared[OpType.ZZPhase], [angle, first, second])
[docs]
def conv_phasedx(self, qubits: list[Qubit], op: Op) -> None:
    """Emit ``__quantum__qis__phasedx__body(a0, a1, q)`` for a PhasedX op.

    Both op parameters are multiplied by pi. The external function is
    declared on first use.
    """
    context = self.module.module.context
    declared = self.additional_quantum_gates
    if OpType.PhasedX not in declared:
        signature = pyqir.FunctionType(
            pyqir.Type.void(context),
            [
                pyqir.Type.double(context),
                pyqir.Type.double(context),
                pyqir.qubit_type(context),
            ],
        )
        declared[OpType.PhasedX] = self.module.module.add_external_function(
            "__quantum__qis__phasedx__body",
            signature,
        )

    angles = [
        pyqir.const(pyqir.Type.double(context), (float(op.params[k]) * math.pi))
        for k in (0, 1)
    ]
    target = self.module.module.qubits[qubits[0].index[0]]
    self.module.builder.call(declared[OpType.PhasedX], [*angles, target])
[docs]
def conv_tk2(self, qubits: list[Qubit], op: Op) -> None:
    """Emit ``__quantum__qis__rxxyyzz__body(a0, a1, a2, q0, q1)`` for a TK2 op.

    All three op parameters are multiplied by pi. The external function is
    declared on first use.
    """
    context = self.module.module.context
    declared = self.additional_quantum_gates
    if OpType.TK2 not in declared:
        signature = pyqir.FunctionType(
            pyqir.Type.void(context),
            [
                pyqir.Type.double(context),
                pyqir.Type.double(context),
                pyqir.Type.double(context),
                pyqir.qubit_type(context),
                pyqir.qubit_type(context),
            ],
        )
        declared[OpType.TK2] = self.module.module.add_external_function(
            "__quantum__qis__rxxyyzz__body",
            signature,
        )

    angles = [
        pyqir.const(pyqir.Type.double(context), (float(op.params[k]) * math.pi))
        for k in (0, 1, 2)
    ]
    first = self.module.module.qubits[qubits[0].index[0]]
    second = self.module.module.qubits[qubits[1].index[0]]
    self.module.builder.call(declared[OpType.TK2], [*angles, first, second])
[docs]
def conv_zzmax(self, qubits: list[Qubit]) -> None:
    """Emit ``__quantum__qis__zzmax__body(q0, q1)`` for a ZZMax op.

    The external function is declared on first use.
    """
    declared = self.additional_quantum_gates
    if OpType.ZZMax not in declared:
        context = self.module.module.context
        signature = pyqir.FunctionType(
            pyqir.Type.void(context),
            [pyqir.qubit_type(context), pyqir.qubit_type(context)],
        )
        declared[OpType.ZZMax] = self.module.module.add_external_function(
            "__quantum__qis__zzmax__body",
            signature,
        )

    first = self.module.module.qubits[qubits[0].index[0]]
    second = self.module.module.qubits[qubits[1].index[0]]
    self.module.builder.call(declared[OpType.ZZMax], [first, second])
[docs]
@abc.abstractmethod
def conv_measure(self, bits: list[Bit], qubits: list[Qubit]) -> None:
    """Convert a measurement acting on ``qubits`` and ``bits``.

    Profile-specific; to be implemented by subclasses.
    """
def _get_ssa_from_clexpr_arg(  # noqa: PLR0911, PLR0912
    self,
    arg: int | ClBitVar | ClRegVar | ClExpr,
    bit_posn: dict[int, int],
    reg_posn: dict[int, list[int]],
    cmd_args: list[Bit],
    expect_register: bool,
) -> tuple[bool, Value]:
    """Recursively translate a ClExpr argument into a QIR value.

    Returns ``(is_bit, value)``: ``is_bit`` is True when ``value`` is a
    single bit (i1) and False when it has register width.
    ``expect_register`` only matters for integer literals, where it selects
    between the two widths.

    Raises ValueError for an integer literal other than 0/1 used as a bit
    and for unsupported or unknown operations.
    """
    if isinstance(arg, int):
        if not expect_register:
            if arg not in (0, 1):
                raise ValueError(f"Invalid bit value {arg}")
            return True, pyqir.const(self.qir_bool_type, arg)
        return False, pyqir.const(self.qir_int_type, arg)

    if isinstance(arg, ClBitVar):
        wired_bit: Bit = cmd_args[bit_posn[arg.index]]
        return True, self._get_bit_from_creg(wired_bit.reg_name, wired_bit.index[0])

    if isinstance(arg, ClRegVar):
        # The register is identified through the first bit wired to it.
        first_wired = cmd_args[reg_posn[arg.index][0]]
        return False, self._get_i64_ssa_reg(first_wired.reg_name)

    assert isinstance(arg, ClExpr)
    expr_op: ClOp = arg.op
    takes_registers = expr_op in _TK_CLEXPR_OP_WITH_REG_ARGS
    ssa_args = []
    for sub_arg in arg.args:
        _, sub_value = self._get_ssa_from_clexpr_arg(
            sub_arg,
            bit_posn,
            reg_posn,
            cmd_args,
            takes_registers,
        )
        ssa_args.append(sub_value)

    # Bit operations
    if expr_op in (ClOp.BitAnd, ClOp.BitOr, ClOp.BitXor, ClOp.BitEq, ClOp.BitNeq):
        return True, _TK_CLEXPR_OP_TO_PYQIR[expr_op](self.module.builder)(*ssa_args)
    if expr_op == ClOp.BitNot:
        # Implemented as x --> 1 - x
        assert len(ssa_args) == 1
        one = pyqir.const(self.qir_bool_type, 1)
        return True, self.module.builder.sub(one, ssa_args[0])
    if expr_op == ClOp.BitZero:
        assert len(ssa_args) == 0
        return True, pyqir.const(self.qir_bool_type, 0)
    if expr_op == ClOp.BitOne:
        assert len(ssa_args) == 0
        return True, pyqir.const(self.qir_bool_type, 1)

    # Register operations with a register-width result
    if expr_op in (
        ClOp.RegAnd,
        ClOp.RegOr,
        ClOp.RegXor,
        ClOp.RegAdd,
        ClOp.RegSub,
        ClOp.RegMul,
        ClOp.RegLsh,
        ClOp.RegRsh,
    ):
        assert len(ssa_args) == 2  # noqa: PLR2004
        return False, _TK_CLEXPR_OP_TO_PYQIR[expr_op](self.module.builder)(*ssa_args)

    # Register comparisons produce a single bit
    if expr_op in (
        ClOp.RegEq,
        ClOp.RegNeq,
        ClOp.RegLt,
        ClOp.RegGt,
        ClOp.RegLeq,
        ClOp.RegGeq,
    ):
        assert len(ssa_args) == 2  # noqa: PLR2004
        return True, _TK_CLEXPR_OP_TO_PYQIR[expr_op](self.module.builder)(*ssa_args)

    if expr_op == ClOp.RegNot:
        # Implemented as x --> 2^self.int_size - 1 - x
        assert len(ssa_args) == 1
        all_ones = pyqir.const(self.qir_int_type, (1 << self.int_size) - 1)
        return False, self.module.builder.sub(all_ones, ssa_args[0])
    if expr_op == ClOp.RegZero:
        assert len(ssa_args) == 0
        return False, pyqir.const(self.qir_int_type, 0)
    if expr_op == ClOp.RegOne:
        # Sets all bits in the register to 1
        assert len(ssa_args) == 0
        return False, pyqir.const(self.qir_int_type, (1 << self.int_size) - 1)
    if expr_op == ClOp.RegNeg:
        # Implemented as x --> 0 - x
        assert len(ssa_args) == 1
        zero = pyqir.const(self.qir_int_type, 0)
        return False, self.module.builder.sub(zero, ssa_args[0])

    if expr_op in (ClOp.RegDiv, ClOp.RegPow):
        # https://github.com/Quantinuum/pytket-qir/issues/181
        raise ValueError(f"Classical operation {expr_op} not supported")
    raise ValueError("Invalid classical operation")
[docs]
def conv_clexprop(self, op: ClExprOp, args: list[Bit]) -> None:
    """Emit the classical expression of ``op`` and store its result.

    Every register variable of the expression must be wired to exactly the
    bits of one register of the circuit. A single-bit result is written to
    the one output bit, a register-width result to the register made up of
    the output bits.

    Raises ValueError if a register variable or the outputs do not
    correspond to a complete register.
    """
    wired: WiredClExpr = op.expr
    expression: ClExpr = wired.expr
    bit_posn: dict[int, int] = wired.bit_posn
    reg_posn: dict[int, list[int]] = wired.reg_posn
    output_posn: list[int] = wired.output_posn
    known_registers = set(self.cregs.values())

    def register_made_of(units):  # noqa: ANN001, ANN202
        """Return the circuit register consisting of exactly ``units``, if any."""
        for candidate in known_registers:
            if candidate.to_list() == units:
                return candidate
        return None

    # We require that all register variables correspond to actual complete
    # registers.
    for var_index, positions in reg_posn.items():
        if register_made_of([args[j] for j in positions]) is None:
            raise ValueError(
                f"ClExprOp ({wired}) contains a register variable (r{var_index}) "
                "that is not wired to any BitRegister in the circuit.",
            )

    result_is_bit, result_value = self._get_ssa_from_clexpr_arg(
        expression,
        bit_posn,
        reg_posn,
        args,
        expression.op in _TK_CLEXPR_OP_WITH_REG_ARGS,
    )

    if result_is_bit:
        assert len(output_posn) == 1
        target: Bit = args[output_posn[0]]
        self._set_bit_in_creg(target.reg_name, target.index[0], result_value)
        return

    assert len(output_posn) > 0
    destination = register_made_of([args[i] for i in output_posn])
    if destination is None:
        raise ValueError(
            f"ClExprOp ({wired}) has outputs that do not "
            "correspond to any BitRegister in the circuit.",
        )
    self.set_ssa_vars(destination.name, result_value, True)
[docs]
def conv_SetBitsOp(self, bits: list[Bit], op: SetBitsOp) -> None:
    """Write each constant of ``op.values`` to the corresponding bit."""
    values = op.values
    assert len(values) == len(bits)
    for target, value in zip(bits, values, strict=False):
        self._set_bit(target, pyqir.const(self.qir_bool_type, int(value)))
[docs]
def conv_CopyBitsOp(self, args: list) -> None:
assert len(args) % 2 == 0
half_length = len(args) // 2
for i, o in zip(args[:half_length], args[half_length:], strict=False):
output_instruction = self._get_bit_from_creg(i.reg_name, i.index[0])
self._set_bit(o, output_instruction)
[docs]
def conv_BarrierOp(self, qubits: list[Qubit], op: BarrierOp) -> None:
    """Convert a barrier, dispatching on the first five characters of ``op.data``.

    ``order`` and ``group`` map to the matching functions; ``sleep`` reads
    its duration from ``op.data[6:-1]``. Anything else becomes a plain
    barrier, with a warning if the data is not empty.
    """
    assert qubits[0].reg_name == "q"
    qir_qubits = self._to_qis_qubits(qubits)
    width = len(qubits)
    prefix = op.data[0:5]

    if prefix == "order":
        self._add_order_op(width, qir_qubits)
    elif prefix == "group":
        self._add_group_op(width, qir_qubits)
    elif prefix == "sleep":
        duration = float(op.data[6:-1])
        self._add_sleep_op(width, qir_qubits, duration)
    else:
        # ignore all other op.data
        if op.data != "":
            warnings.warn(
                f"Unknown op data `{op.data}` converted to barrier", stacklevel=2
            )
        self._add_barrier_op(width, qir_qubits)
[docs]
def conv_RNGJobOpR(self, optype: OpType, bits: list[Bit]) -> None:
    """Convert an op that writes a register: RNGNum or JobShotNum.

    All ``bits`` must belong to the same register, which receives the value.

    Raises AssertionError for any other op type.
    """
    creg_name = bits[0].reg_name
    for x in bits:
        assert creg_name == x.reg_name
    if optype == OpType.RNGNum:
        self._add_rngnum_op(creg_name)
    elif optype == OpType.JobShotNum:
        self._add_jobnum_op(creg_name)
    else:
        # Raised explicitly: the former ``assert not "..."`` carried no
        # message and vanished under ``python -O``, silently ignoring the op.
        raise AssertionError(f"unexpected op type for RNG: {optype}")
[docs]
def conv_RNGJobOp(self, optype: OpType, bits: list[Bit]) -> None:
    """Convert an op that reads a register: RNGSeed, RNGBound or RNGIndex.

    All ``bits`` must belong to the same register, whose integer value is
    passed on to the matching helper.

    Raises AssertionError for any other op type.
    """
    creg_name = bits[0].reg_name
    for x in bits:
        assert creg_name == x.reg_name
    creg = self._get_i64_ssa_reg(creg_name)
    if optype == OpType.RNGSeed:
        self._add_rngseed_op(creg)
    elif optype == OpType.RNGBound:
        self._add_rngbound_op(creg)
    elif optype == OpType.RNGIndex:
        self._add_rngindex_op(creg)
    else:
        # Raised explicitly: the former ``assert not "..."`` carried no
        # message and vanished under ``python -O``, silently ignoring the op.
        raise AssertionError(f"unexpected op type for RNG: {optype}")
[docs]
def conv_other(
    self,
    bits: list[Bit],
    qubits: list[Qubit],
    op: Op,
    args: list,
) -> None:
    """Convert an op through the module's gateset mapping.

    The gate function is looked up on ``self.module.qis`` by name (with the
    function spec appended unless it is BODY) and called with, in order of
    precedence: the bit operands, the parameters (times pi) followed by the
    qubits, the qubits followed by the first result, or the qubits alone.
    """
    optype, params = self._get_optype_and_params(op)
    pi_params = [p * math.pi for p in params]
    qubits_qis = self._to_qis_qubits(qubits)
    results = self._to_qis_results(bits)
    bits_qis: Sequence[Value] | None = None
    if type(optype) is BitWiseOp:
        bits_qis = self._to_qis_bits(args)

    gate = self.module.gateset.tk_to_gateset(optype)
    qis_name = gate.func_name.value
    if gate.func_spec != FuncSpec.BODY:
        qis_name = gate.func_name.value + "_" + gate.func_spec.value
    get_gate = getattr(self.module.qis, qis_name)

    if bits_qis:
        call_args = [*bits_qis]
    elif params:
        call_args = [*pi_params, *qubits_qis]
    elif results:
        call_args = [*qubits_qis, results]
    else:
        call_args = [*qubits_qis]
    get_gate(*call_args)
[docs]
def command_to_module(self, op: Op, args: list) -> tketqirModule:  # noqa: PLR0912
    """Populate the PyQir module from one pytket command.

    Dispatches on the class or type of ``op`` to the matching ``conv_*``
    method; anything without a dedicated handler goes to ``conv_other``.
    Phase ops are ignored.

    Raises ValueError for a Conditional op.
    """
    qubits = []
    bits = []
    for unit in args:
        if unit.type == UnitType.qubit:
            qubits.append(unit)
        if unit.type == UnitType.bit:
            bits.append(unit)

    # Gates whose handler takes (qubits, op).
    parametrised_gates = {
        OpType.ZZPhase: self.conv_ZZPhase,
        OpType.PhasedX: self.conv_phasedx,
        OpType.TK2: self.conv_tk2,
    }

    if isinstance(op, RangePredicateOp):
        self.conv_RangePredicateOp(op, args)
    elif isinstance(op, Conditional):
        raise ValueError("conditional ops can't contain conditional ops")
    elif isinstance(op, WASMOp):
        self.conv_WASMOp(op, args)
    elif op.type in parametrised_gates:
        parametrised_gates[op.type](qubits, op)
    elif op.type == OpType.ZZMax:
        self.conv_zzmax(qubits)
    elif op.type == OpType.Measure:
        self.conv_measure(bits, qubits)
    elif op.type == OpType.Phase:
        # ignore phase op
        pass
    elif op.type in (OpType.RNGNum, OpType.JobShotNum):
        self.conv_RNGJobOpR(op.type, bits)
    elif op.type in (OpType.RNGSeed, OpType.RNGBound, OpType.RNGIndex):
        self.conv_RNGJobOp(op.type, bits)
    elif isinstance(op, ClExprOp):
        self.conv_clexprop(op, args)
    elif isinstance(op, SetBitsOp):
        self.conv_SetBitsOp(bits, op)
    elif isinstance(op, CopyBitsOp):
        self.conv_CopyBitsOp(args)
    elif isinstance(op, BarrierOp):
        self.conv_BarrierOp(qubits, op)
    else:
        self.conv_other(bits, qubits, op, args)
    return self.module
[docs]
@abc.abstractmethod
def record_output(self) -> None:
    """Record the output of the program.

    Profile-specific; to be implemented by subclasses. Called by
    ``circuit_to_module`` after all commands when output recording is
    requested.
    """
[docs]
def circuit_to_module(
    self,
    circuit: Circuit,
    record_output: bool = False,
) -> tketqirModule:
    """Populate the PyQir module from a pytket circuit.

    Conditional commands go to ``conv_conditional``, everything else to
    ``command_to_module``. With ``record_output`` set, the output is
    recorded after the last command.
    """
    for command in circuit:
        operation = command.op
        if not isinstance(operation, Conditional):
            self.command_to_module(operation, command.args)
            continue
        self.conv_conditional(command, operation)

    if record_output:
        self.record_output()
    return self.module
[docs]
def subcircuit_to_module(
    self,
    circuit: Circuit,
) -> tketqirModule:
    """Populate the PyQir module from a pytket subcircuit.

    Every command goes straight to ``command_to_module``, which rejects
    Conditional ops.
    """
    for command in circuit:
        operation, operands = command.op, command.args
        self.command_to_module(operation, operands)
    return self.module