# Source code for pytket.qir.conversion.pytketqirgenerator

# Copyright Quantinuum
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast

import pyqir
from pyqir import Value

from pytket.circuit import (
    Bit,
    BitRegister,
    CircBox,
    Circuit,
    Command,
    Conditional,
    OpType,
    Qubit,
)

from .module import tketqirModule
from .qirgenerator import (
    AbstractQirGenerator,
)


class PytketQirGenerator(AbstractQirGenerator):
    """Generates QIR from a pytket circuit in line with the pytket profile.

    This profile uses the functions `get_creg_bit`, `set_creg_bit`,
    `set_creg_to_int`, `create_creg`, `get_int_from_creg` and
    `mz_to_creg_bit` for the handling of the classical registers.
    The other aspects of the QIR file are identical to the adaptive profile.
    """

    def __init__(
        self,
        circuit: Circuit,
        module: tketqirModule,
        wasm_int_type: int,
        qir_int_type: int,
    ) -> None:
        """Declare the pytket-profile runtime functions and create the cregs.

        Every classical register of ``circuit`` gets a byte-string constant
        holding its name (stored in ``self.reg_const``) and a ``create_creg``
        call emitted at the end of the module's entry block.

        :param circuit: circuit to be converted.
        :param module: module the QIR is written into.
        :param wasm_int_type: forwarded unchanged to the base class.
        :param qir_int_type: bit width of the integer parameter of
            ``create_creg``; also the maximum allowed classical register size.
        :raises ValueError: if a classical register has more than
            ``qir_int_type`` bits.
        """
        super().__init__(circuit, module, wasm_int_type, qir_int_type)

        self.set_cregs: dict[str, list] = {}  # Keep track of set registers.
        self.ssa_vars: dict[str, Value] = {}  # Keep track of set ssa variables.

        # i1 get_creg_bit(i1* creg, i64 index)
        self.get_creg_bit = self.module.module.add_external_function(
            "get_creg_bit",
            pyqir.FunctionType(
                pyqir.IntType(self.module.module.context, 1),
                [self.qir_i1p_type, self.qir_int_type],
            ),
        )

        # void set_creg_bit(i1* creg, i64 index, i1 value)
        self.set_creg_bit = self.module.module.add_external_function(
            "set_creg_bit",
            pyqir.FunctionType(
                pyqir.Type.void(self.module.module.context),
                [
                    self.qir_i1p_type,
                    self.qir_int_type,
                    pyqir.IntType(self.module.module.context, 1),
                ],
            ),
        )

        # void set_creg_to_int(i1* creg, i64 value)
        self.set_creg_to_int = self.module.module.add_external_function(
            "set_creg_to_int",
            pyqir.FunctionType(
                pyqir.Type.void(self.module.module.context),
                [
                    self.qir_i1p_type,
                    self.qir_int_type,
                ],
            ),
        )

        # i1* create_creg(i64 size)
        self.create_creg = self.module.module.add_external_function(
            "create_creg",
            pyqir.FunctionType(
                self.qir_i1p_type,
                [pyqir.IntType(self.module.module.context, qir_int_type)],
            ),
        )

        # i64 get_int_from_creg(i1* creg)
        self.get_int_from_creg = self.module.module.add_external_function(
            "get_int_from_creg",
            pyqir.FunctionType(
                self.qir_int_type,
                [
                    self.qir_i1p_type,
                ],
            ),
        )

        # void mz_to_creg_bit(qubit, i1* creg, int creg_index)
        # measures one qubit to one bit entry in a creg
        self.mz_to_creg_bit = self.module.module.add_external_function(
            "mz_to_creg_bit",
            pyqir.FunctionType(
                pyqir.Type.void(self.module.module.context),
                [
                    pyqir.qubit_type(self.module.module.context),
                    self.qir_i1p_type,
                    self.qir_int_type,
                ],
            ),
        )

        # One byte-string constant per register name; these are passed to
        # record_output_i64 in record_output().
        for creg in self.circuit.c_registers:
            reg_name = creg[0].reg_name
            self.reg_const[reg_name] = self.module.module.add_byte_string(
                str.encode(reg_name),
            )

        # The create_creg calls must be emitted into the entry block, so the
        # builder is positioned there before creating the registers.
        entry = self.module.module.entry_block
        self.module.module.builder.insert_at_end(entry)

        for creg in self.circuit.c_registers:
            self._reg2ssa_var(creg, qir_int_type)

    def get_ssa_vars(self, reg_name: str) -> Value:
        """Return the ``create_creg`` handle of classical register ``reg_name``.

        :raises ValueError: if no handle has been created for ``reg_name``.
        """
        try:
            return self.ssa_vars[reg_name]
        except KeyError:
            raise ValueError(f"{reg_name} is not a valid register") from None

    def _get_i64_ssa_reg(self, name: str) -> Value:
        """Emit a ``get_int_from_creg`` call and return its result.

        The result is the integer value of classical register ``name``.
        """
        return self.module.builder.call(
            self.get_int_from_creg,
            [self.get_ssa_vars(name)],
        )

    def set_ssa_vars(self, reg_name: str, ssa_i64: Value, trunc: bool) -> None:
        """Emit a ``set_creg_to_int`` call storing ``ssa_i64`` in ``reg_name``.

        ``trunc`` is not used by this profile.
        NOTE(review): presumably kept so the signature matches the
        ``AbstractQirGenerator`` interface; confirm against the base class.
        """
        self.module.builder.call(
            self.set_creg_to_int,
            [self.get_ssa_vars(reg_name), ssa_i64],
        )

    def _set_bit_in_creg(self, creg: str, index: int, ssa_bit: Value) -> None:
        """Emit a ``set_creg_bit`` call writing ``ssa_bit`` to ``creg[index]``."""
        self.module.builder.call(
            self.set_creg_bit,
            [
                self.get_ssa_vars(creg),
                pyqir.const(self.qir_int_type, index),
                ssa_bit,
            ],
        )

    def _get_bit_from_creg(self, creg: str, index: int) -> Value:
        """Emit a ``get_creg_bit`` call and return the i1 value of ``creg[index]``."""
        return self.module.builder.call(
            self.get_creg_bit,
            [
                self.get_ssa_vars(creg),
                pyqir.const(self.qir_int_type, index),
            ],
        )

    def _reg2ssa_var(self, bit_reg: BitRegister, int_size: int) -> Value:
        """Convert a BitRegister to an SSA variable using pyqir types.

        The first call for a register emits ``create_creg(len(bit_reg))`` and
        caches the result in ``self.ssa_vars``. Later calls return the cached
        value without emitting anything.

        :raises ValueError: if the register has more than ``int_size`` bits.
        """
        reg_name = bit_reg[0].reg_name
        if reg_name in self.ssa_vars:
            return cast("Value", self.ssa_vars[reg_name])  # type: ignore
        if len(bit_reg) > int_size:
            raise ValueError(
                f"Classical register should only have the size of {int_size}",
            )
        ssa_var = self.module.builder.call(
            self.create_creg,
            [pyqir.const(self.qir_int_type, len(bit_reg))],
        )
        self.ssa_vars[reg_name] = ssa_var
        return ssa_var

    def _check_whole_register_condition(
        self,
        command: Command,
        op: Conditional,
        reg_name: str,
    ) -> None:
        """Raise ``ValueError`` unless the condition bits cover one register.

        The condition bits are the first ``op.width`` arguments of
        ``command``. They must all belong to ``reg_name``, appear in strictly
        increasing index order, and their number must equal that register's
        size.
        """
        message = "conditional can only work with one entire register"
        condition_bits = command.args[: op.width]
        if any(bit.reg_name != reg_name for bit in condition_bits):
            raise ValueError(message)
        if any(
            lower.index[0] >= upper.index[0]
            for lower, upper in zip(condition_bits, condition_bits[1:])
        ):
            raise ValueError(message)
        if self.circuit.get_c_register(reg_name).size != op.width:
            raise ValueError(message)

    def conv_conditional(self, command: Command, op: Conditional) -> None:
        """Emit IR for a conditional operation.

        The guarded operation is emitted into a new block ``condb<n>``. Control
        then joins a new block ``contb<n>``, where emission continues.

        Two kinds of condition are supported:

        * single bit (``op.width == 1``): the bit is read with
          ``get_creg_bit``. The body runs when the bit equals ``op.value``,
          which must be 0 or 1.
        * whole register (``op.width > 1``): the condition bits must be one
          entire register in index order. The body runs when the register's
          integer value equals ``op.value``.

        A conditional ``CircBox`` is decomposed and emitted as a subcircuit.
        Any other operation is emitted as a single command.

        :raises ValueError: if a single-bit condition value is not 0 or 1, or
            a multi-bit condition does not cover one entire register.
        """
        entry_point = self.module.module.entry_point
        condb = pyqir.BasicBlock(
            self.module.module.context,
            f"condb{self.block_count}",
            entry_point,
        )
        contb = pyqir.BasicBlock(
            self.module.module.context,
            f"contb{self.block_count}",
            entry_point,
        )
        self.block_count = self.block_count + 1

        # Only the way the guarded body is emitted depends on the op type;
        # the branching logic below is shared by both cases.
        body_args = command.args[op.width :]
        if op.op.type == OpType.CircBox:
            conditional_circuit = self._decompose_conditional_circ_box(
                cast("CircBox", op.op),
                body_args,
            )

            def emit_body() -> None:
                self.subcircuit_to_module(conditional_circuit)

        else:

            def emit_body() -> None:
                self.command_to_module(op.op, body_args)

        condition_name = command.args[0].reg_name
        if op.width == 1:  # only one conditional bit
            if op.value not in (0, 1):
                # Without this check no condbr would be emitted, leaving
                # condb empty and unterminated (malformed IR).
                raise ValueError(
                    "a single-bit conditional requires a condition value "
                    f"of 0 or 1, got {op.value}",
                )
            ssabool = self._get_bit_from_creg(
                condition_name,
                command.args[0].index[0],
            )
            if op.value == 1:
                self.module.module.builder.condbr(ssabool, condb, contb)
            else:
                # Condition on 0: skip the body when the bit is set.
                self.module.module.builder.condbr(ssabool, contb, condb)
        else:
            self._check_whole_register_condition(command, op, condition_name)
            ssabool = self.module.module.builder.icmp(
                pyqir.IntPredicate.EQ,
                pyqir.const(self.qir_int_type, op.value),
                self._get_i64_ssa_reg(condition_name),
            )
            self.module.module.builder.condbr(ssabool, condb, contb)

        self.module.module.builder.insert_at_end(condb)
        emit_body()
        self.module.module.builder.br(contb)
        self.module.module.builder.insert_at_end(contb)

    def conv_measure(self, bits: list[Bit], qubits: list[Qubit]) -> None:
        """Emit a ``mz_to_creg_bit`` call measuring one qubit into one bit.

        The result is written to bit ``bits[0]`` of its classical register.
        """
        assert len(bits) == 1
        assert len(qubits) == 1
        self.module.builder.call(
            self.mz_to_creg_bit,
            [
                self.module.module.qubits[qubits[0].index[0]],
                self.get_ssa_vars(bits[0].reg_name),
                pyqir.const(self.qir_int_type, bits[0].index[0]),
            ],
        )

    def record_output(self) -> None:
        """Emit a ``record_output_i64`` call for every classical register.

        Each call passes the register's integer value (from
        ``get_int_from_creg``) and the register-name constant created in
        ``__init__``. ``record_output_i64`` is presumably declared by the
        base class; it is not defined in this class.
        """
        for creg in self.circuit.c_registers:
            reg_name = creg[0].reg_name
            self.module.builder.call(
                self.record_output_i64,
                [
                    self._get_i64_ssa_reg(reg_name),
                    self.reg_const[reg_name],
                ],
            )