# Copyright Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
import pyqir
from pyqir import BasicBlock, Value
from pytket.circuit import (
Bit,
BitRegister,
CircBox,
Circuit,
Command,
Conditional,
OpType,
Qubit,
)
from .module import tketqirModule
from .qirgenerator import (
AbstractQirGenerator,
)
[docs]
class AdaptiveProfileQirGenerator(AbstractQirGenerator):
"""Generate QIR from a pytket circuit."""
def __init__(
self,
circuit: Circuit,
module: tketqirModule,
wasm_int_type: int,
qir_int_type: int,
trunc: bool,
) -> None:
self.trunc = trunc
super().__init__(circuit, module, wasm_int_type, qir_int_type)
self.set_cregs: dict[str, list] = {} # Keep track of set registers.
self.ssa_vars: dict[
str,
list[tuple[Value, BasicBlock]],
] = {} # Keep track of set ssa variables.
self.list_of_changed_cregs: list[str] = []
for creg in self.circuit.c_registers:
reg_name = creg[0].reg_name
self.reg_const[reg_name] = self.module.module.add_byte_string(
str.encode(reg_name),
)
entry = self.module.module.entry_block
self.module.module.builder.insert_at_end(entry)
self.active_block = entry
self.active_block_main = entry
self.active_block_list = [entry]
# set the prefix for the names of the conditional blocks
self.conditional_bp = "condb"
# set the prefix for the names of the continue blocks
self.continue_bp = "contb"
self.prefix_length = 5
assert self.conditional_bp != self.continue_bp
assert self.conditional_bp != "entry"
assert self.continue_bp != "entry"
# the code is assuming that the prefixes have the same length
assert len(self.conditional_bp) == self.prefix_length
assert len(self.continue_bp) == self.prefix_length
for creg in self.circuit.c_registers:
self._reg2ssa_var(creg)
self.list_of_changed_cregs.append(creg.name)
self.creg_size[creg.name] = creg.size
def _get_bit_from_creg(self, creg: str, index: int) -> Value:
ssa_index = pyqir.const(self.qir_int_type, 2**index)
result = self.module.module.builder.icmp(
pyqir.IntPredicate.EQ,
ssa_index,
self.module.module.builder.and_(ssa_index, self.get_ssa_vars(creg)),
)
return result # noqa: RET504
def _set_bit_in_creg_blocks(self, creg: str, index: int, ssa_bit: Value) -> None:
ssa_int = self.get_ssa_vars(creg)
ssa_index = pyqir.const(self.qir_int_type, 2**index)
# it would be better to do an invert here, but that is not
# (yet) available in pyqir
ssa_int_all_1 = pyqir.const(self.qir_int_type, (2 ** (self.int_size - 1) - 1))
entry_point = self.module.module.entry_point
sb_0 = pyqir.BasicBlock(
self.module.module.context,
f"sb_0_{self.block_count_sb}",
entry_point,
)
sb_1 = pyqir.BasicBlock(
self.module.module.context,
f"sb_1_{self.block_count_sb}",
entry_point,
)
continue_block = pyqir.BasicBlock(
self.module.module.context,
f"{self.active_block_main.name}_{self.block_count_sb}",
entry_point,
)
if self.active_block_main.name[0 : self.prefix_length] != self.conditional_bp:
self.active_block_list.append(continue_block)
self.block_count_sb = self.block_count_sb + 1
self.module.module.builder.condbr(ssa_bit, sb_1, sb_0)
# if bit 1
self.module.module.builder.insert_at_end(sb_1)
result_1 = self.module.module.builder.or_(ssa_index, ssa_int)
self.module.module.builder.br(continue_block)
# if bit 0
self.module.module.builder.insert_at_end(sb_0)
result_0 = self.module.module.builder.and_(
self.module.module.builder.xor(
ssa_index,
ssa_int_all_1,
),
ssa_int,
)
self.module.module.builder.br(continue_block)
# phi and continue
self.active_block = continue_block
self.module.module.builder.insert_at_end(continue_block)
phi = self.module.module.builder.phi(self.qir_int_type)
phi.add_incoming(result_0, sb_0)
phi.add_incoming(result_1, sb_1)
self.set_ssa_vars(creg, phi, False)
def _set_bit_in_creg_zext(self, creg: str, index: int, ssa_bit: Value) -> None:
ssa_int = self.get_ssa_vars(creg)
ssa_bit_i64 = self.module.module.builder.zext(ssa_bit, self.qir_int_type)
ssa_index = pyqir.const(self.qir_int_type, 2**index)
# it would be better to do an invert here, but that is not
# (yet) available in pyqir
ssa_int_all_1 = pyqir.const(self.qir_int_type, (2 ** (self.int_size - 1) - 1))
# if ssa_bit is 1, ((BIT) MUL (2^INDEX) ) OR INT
ssa_result_1 = self.module.module.builder.or_(
self.module.module.builder.mul(ssa_bit_i64, ssa_index),
ssa_int,
)
# if ssa_bit is 0, ((2**63-1) XOR ((1-BIT) MUL (2^INDEX))) and INT
ssa_result_0 = self.module.module.builder.and_(
self.module.module.builder.xor(
ssa_int_all_1,
self.module.module.builder.mul(
self.module.module.builder.sub(
pyqir.const(self.qir_int_type, 1),
ssa_bit_i64,
),
ssa_index,
),
),
ssa_result_1,
)
# set ssa
self.set_ssa_vars(creg, ssa_result_0, False)
def _set_bit_in_creg(self, creg: str, index: int, ssa_bit: Value) -> None:
self._set_bit_in_creg_zext(creg, index, ssa_bit)
[docs]
def get_ssa_vars(self, reg_name: str) -> Value:
if reg_name not in self.ssa_vars:
raise ValueError(f"{reg_name} is not a valid register")
return self.ssa_vars[reg_name][-1][0]
def _get_i64_ssa_reg(self, reg_name: str) -> Value:
if reg_name not in self.ssa_vars:
raise ValueError(f"{reg_name} is not a valid register")
return self.ssa_vars[reg_name][-1][0]
[docs]
def get_ssa_list(self, reg_name: str) -> list:
if reg_name not in self.ssa_vars:
raise ValueError(f"{reg_name} is not a valid register")
return self.ssa_vars[reg_name]
[docs]
def set_ssa_vars(self, reg_name: str, ssa_i64: Value, trunc: bool) -> None:
if reg_name not in self.ssa_vars:
raise ValueError(f"{reg_name} is not a valid register")
if self.trunc and trunc and self.creg_size[reg_name] != self.int_size:
type_register = pyqir.IntType(self.module.context, self.creg_size[reg_name])
ssa_i_trunc = self.module.module.builder.trunc(ssa_i64, type_register)
ssa_i64_zext = self.module.module.builder.zext(
ssa_i_trunc,
self.qir_int_type,
)
self.ssa_vars[reg_name].append((ssa_i64_zext, self.active_block))
else:
self.ssa_vars[reg_name].append((ssa_i64, self.active_block))
self.list_of_changed_cregs.append(reg_name)
def _reg2ssa_var(self, bit_reg: BitRegister) -> Value:
"""Convert a BitRegister to an SSA variable using pyqir types."""
reg_name = bit_reg[0].reg_name
if reg_name not in self.ssa_vars:
if len(bit_reg) > self.int_size:
raise ValueError(
f"Classical register should only have the size of {self.int_size}",
)
ssa_var = pyqir.const(self.qir_int_type, 0)
self.ssa_vars[reg_name] = [(ssa_var, self.active_block)]
return ssa_var
return cast("Value", self.ssa_vars[reg_name])
def _add_phi(self) -> None:
"""
add phi nodes for the previously changed registers.
phi requires ssa variables from both predecessor blocks,
these are not necessarily the blocks where the variables
have been set. The second loop searches for the second variable
and adds that with the other predecessor
"""
for creg in set(self.list_of_changed_cregs):
phi = self.module.module.builder.phi(self.qir_int_type)
ssa_list = self.get_ssa_list(creg)
# the first predecessor if the direct previous (last) entry in the ssa list
phi.add_incoming(ssa_list[-1][0], ssa_list[-1][1])
found_second_block = False
# search for the other ssa variable
for i in range(-2, -len(ssa_list) - 1, -1):
if (
ssa_list[-1][1].name != ssa_list[i][1].name
and ssa_list[i][1].name[0 : self.prefix_length]
!= self.conditional_bp
):
assert (
self.active_block_list[-3].name[0 : self.prefix_length]
!= self.conditional_bp
)
# self.active_block_list[-3] is the second predecessor
phi.add_incoming(ssa_list[i][0], self.active_block_list[-3])
found_second_block = True
break
if not found_second_block:
raise RuntimeError("Second block missing in phi generation")
self.set_ssa_vars(creg, phi, False)
[docs]
def conv_conditional(self, command: Command, op: Conditional) -> None: # noqa: PLR0915, PLR0912
condition_name = command.args[0].reg_name
entry_point = self.module.module.entry_point
condb = pyqir.BasicBlock(
self.module.module.context,
f"{self.conditional_bp}{self.block_count}",
entry_point,
)
contb = pyqir.BasicBlock(
self.module.module.context,
f"{self.continue_bp}{self.block_count}",
entry_point,
)
self.block_count = self.block_count + 1
self.active_block_list.append(condb)
self.active_block_list.append(contb)
inner_op = op.op
if inner_op.type == OpType.CircBox:
assert isinstance(inner_op, CircBox)
conditional_circuit = self._decompose_conditional_circ_box(
inner_op,
command.args[op.width :],
)
condition_name = command.args[0].reg_name
if op.width == 1: # only one conditional bit
condition_bit_index = command.args[0].index[0]
ssa_bool = self._get_bit_from_creg(condition_name, condition_bit_index)
self.list_of_changed_cregs = []
self.active_block = condb
self.active_block_main = condb
if op.value == 1:
self.module.module.builder.condbr(ssa_bool, condb, contb)
self.module.module.builder.insert_at_end(condb)
self.subcircuit_to_module(conditional_circuit)
if op.value == 0:
self.module.module.builder.condbr(ssa_bool, contb, condb)
self.module.module.builder.insert_at_end(condb)
self.subcircuit_to_module(conditional_circuit)
self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)
self.active_block = contb
self.active_block_main = contb
self._add_phi()
else:
for i in range(op.width):
if command.args[i].reg_name != condition_name:
raise ValueError(
"conditional can only work with one entire register",
)
for i in range(op.width - 1):
if command.args[i].index[0] >= command.args[i + 1].index[0]:
raise ValueError(
"conditional can only work with one entire register",
)
if self.circuit.get_c_register(condition_name).size != op.width:
raise ValueError(
"conditional can only work with one entire register",
)
ssa_bool = self.module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.value),
self.get_ssa_vars(condition_name),
)
self.module.module.builder.condbr(ssa_bool, condb, contb)
self.module.module.builder.insert_at_end(condb)
self.list_of_changed_cregs = []
self.active_block = condb
self.active_block_main = condb
self.subcircuit_to_module(conditional_circuit)
self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)
self.active_block = contb
self.active_block_main = contb
self._add_phi()
else:
condition_name = command.args[0].reg_name
if op.width == 1: # only one conditional bit
condition_bit_index = command.args[0].index[0]
ssa_bool = self._get_bit_from_creg(condition_name, condition_bit_index)
self.list_of_changed_cregs = []
self.active_block = condb
self.active_block_main = condb
if op.value == 1:
self.module.module.builder.condbr(ssa_bool, condb, contb)
self.module.module.builder.insert_at_end(condb)
self.command_to_module(op.op, command.args[op.width :])
if op.value == 0:
self.module.module.builder.condbr(ssa_bool, contb, condb)
self.module.module.builder.insert_at_end(condb)
self.command_to_module(op.op, command.args[op.width :])
self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)
self.active_block = contb
self.active_block_main = contb
self._add_phi()
else:
for i in range(op.width):
if command.args[i].reg_name != condition_name:
raise ValueError(
"conditional can only work with one entire register",
)
for i in range(op.width - 1):
if command.args[i].index[0] >= command.args[i + 1].index[0]:
raise ValueError(
"conditional can only work with one entire register",
)
if self.circuit.get_c_register(condition_name).size != op.width:
raise ValueError(
"conditional can only work with one entire register",
)
ssa_bool = self.module.module.builder.icmp(
pyqir.IntPredicate.EQ,
pyqir.const(self.qir_int_type, op.value),
self.get_ssa_vars(condition_name),
)
self.module.module.builder.condbr(ssa_bool, condb, contb)
self.module.module.builder.insert_at_end(condb)
self.list_of_changed_cregs = []
self.active_block = condb
self.active_block_main = condb
self.command_to_module(op.op, command.args[op.width :])
self.module.module.builder.br(contb)
self.module.module.builder.insert_at_end(contb)
self.active_block = contb
self.active_block_main = contb
self._add_phi()
[docs]
def conv_measure(self, bits: list[Bit], qubits: list[Qubit]) -> None:
assert len(bits) == 1
assert len(qubits) == 1
qubit_index = qubits[0].index[0]
self.module.qis.mz(
self.module.module.qubits[qubit_index],
self.module.module.results[qubit_index],
)
ssa_measureresult = self.module.builder.call(
self.read_bit_from_result,
[
self.module.module.results[qubit_index],
],
)
self._set_bit_in_creg(bits[0].reg_name, bits[0].index[0], ssa_measureresult)
[docs]
def record_output(self) -> None:
for creg in self.circuit.c_registers:
reg_name = creg[0].reg_name
self.module.builder.call(
self.record_output_i64,
[
self._get_i64_ssa_reg(reg_name),
self.reg_const[reg_name],
],
)