Source code for pytket.qasm.qasm

# Copyright 2019-2024 Cambridge Quantum Computing
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import re
import uuid
from collections import OrderedDict
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from dataclasses import dataclass
from decimal import Decimal
from importlib import import_module
from itertools import chain, groupby
from typing import Any, NewType, TextIO, TypeVar, Union, cast

from lark import Discard, Lark, Token, Transformer, Tree
from sympy import Expr, Symbol, pi

from pytket._tket.circuit import (
    BarrierOp,
    ClassicalExpBox,
    ClExpr,
    ClExprOp,
    Command,
    Conditional,
    CopyBitsOp,
    MultiBitOp,
    RangePredicateOp,
    SetBitsOp,
    WASMOp,
    WiredClExpr,
)
from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE
from pytket.circuit import (
    Bit,
    BitRegister,
    Circuit,
    Op,
    OpType,
    Qubit,
    QubitRegister,
    UnitID,
)
from pytket.circuit.clexpr import (
    check_register_alignments,
    has_reg_output,
    wired_clexpr_from_logic_exp,
)
from pytket.circuit.decompose_classical import int_to_bools
from pytket.circuit.logic_exp import (
    BitLogicExp,
    BitWiseOp,
    LogicExp,
    PredicateExp,
    RegEq,
    RegLogicExp,
    RegNeg,
    RegWiseOp,
    create_logic_exp,
    create_predicate_exp,
)
from pytket.passes import (
    AutoRebase,
    DecomposeBoxes,
    RemoveRedundancies,
    scratch_reg_resize_pass,
)
from pytket.qasm.grammar import grammar
from pytket.wasm import WasmFileHandler


class QASMParseError(Exception):
    """Error while parsing QASM input."""

    def __init__(self, msg: str, line: int | None = None, fname: str | None = None):
        self.msg = msg
        self.line = line
        self.fname = fname

        ctx = "" if fname is None else f"\nFile:{fname}: "
        ctx += "" if line is None else f"\nLine:{line}. "

        super().__init__(f"{msg}{ctx}")


class QASMUnsupportedError(Exception):
    pass


Value = Union[int, float, str]
T = TypeVar("T")

_BITOPS = set(op.value for op in BitWiseOp)
_BITOPS.update(("+", "-"))  # both are parsed to XOR
_REGOPS = set(op.value for op in RegWiseOp)

Arg = Union[list, str]


NOPARAM_COMMANDS = {
    "CX": OpType.CX,  # built-in gate equivalent to "cx"
    "cx": OpType.CX,
    "x": OpType.X,
    "y": OpType.Y,
    "z": OpType.Z,
    "h": OpType.H,
    "s": OpType.S,
    "sdg": OpType.Sdg,
    "t": OpType.T,
    "tdg": OpType.Tdg,
    "sx": OpType.SX,
    "sxdg": OpType.SXdg,
    "cz": OpType.CZ,
    "cy": OpType.CY,
    "ch": OpType.CH,
    "csx": OpType.CSX,
    "ccx": OpType.CCX,
    "c3x": OpType.CnX,
    "c4x": OpType.CnX,
    "ZZ": OpType.ZZMax,
    "measure": OpType.Measure,
    "reset": OpType.Reset,
    "id": OpType.noop,
    "barrier": OpType.Barrier,
    "swap": OpType.SWAP,
    "cswap": OpType.CSWAP,
}
PARAM_COMMANDS = {
    "p": OpType.U1,  # alias. https://github.com/Qiskit/qiskit-terra/pull/4765
    "u": OpType.U3,  # alias. https://github.com/Qiskit/qiskit-terra/pull/4765
    "U": OpType.U3,  # built-in gate equivalent to "u3"
    "u3": OpType.U3,
    "u2": OpType.U2,
    "u1": OpType.U1,
    "rx": OpType.Rx,
    "rxx": OpType.XXPhase,
    "ry": OpType.Ry,
    "rz": OpType.Rz,
    "RZZ": OpType.ZZPhase,
    "rzz": OpType.ZZPhase,
    "Rz": OpType.Rz,
    "U1q": OpType.PhasedX,
    "crz": OpType.CRz,
    "crx": OpType.CRx,
    "cry": OpType.CRy,
    "cu1": OpType.CU1,
    "cu3": OpType.CU3,
    "Rxxyyzz": OpType.TK2,
}

NOPARAM_EXTRA_COMMANDS = {
    "v": OpType.V,
    "vdg": OpType.Vdg,
    "cv": OpType.CV,
    "cvdg": OpType.CVdg,
    "csxdg": OpType.CSXdg,
    "bridge": OpType.BRIDGE,
    "iswapmax": OpType.ISWAPMax,
    "zzmax": OpType.ZZMax,
    "ecr": OpType.ECR,
    "cs": OpType.CS,
    "csdg": OpType.CSdg,
}

PARAM_EXTRA_COMMANDS = {
    "tk2": OpType.TK2,
    "iswap": OpType.ISWAP,
    "phasediswap": OpType.PhasedISWAP,
    "yyphase": OpType.YYPhase,
    "xxphase3": OpType.XXPhase3,
    "eswap": OpType.ESWAP,
    "fsim": OpType.FSim,
}

_tk_to_qasm_noparams = dict((item[1], item[0]) for item in NOPARAM_COMMANDS.items())
_tk_to_qasm_noparams[OpType.CX] = "cx"  # prefer "cx" to "CX"
_tk_to_qasm_params = dict((item[1], item[0]) for item in PARAM_COMMANDS.items())
_tk_to_qasm_params[OpType.U3] = "u3"  # prefer "u3" to "U"
_tk_to_qasm_params[OpType.Rz] = "rz"  # prefer "rz" to "Rz"
_tk_to_qasm_extra_noparams = dict(
    (item[1], item[0]) for item in NOPARAM_EXTRA_COMMANDS.items()
)
_tk_to_qasm_extra_params = dict(
    (item[1], item[0]) for item in PARAM_EXTRA_COMMANDS.items()
)

_classical_gatestr_map = {"AND": "&", "OR": "|", "XOR": "^"}


_all_known_gates = (
    set(NOPARAM_COMMANDS.keys())
    .union(PARAM_COMMANDS.keys())
    .union(PARAM_EXTRA_COMMANDS.keys())
    .union(NOPARAM_EXTRA_COMMANDS.keys())
)
_all_string_maps = {
    key: val.name
    for key, val in chain(
        PARAM_COMMANDS.items(),
        NOPARAM_COMMANDS.items(),
        PARAM_EXTRA_COMMANDS.items(),
        NOPARAM_EXTRA_COMMANDS.items(),
    )
}

unit_regex = re.compile(r"([a-z][a-zA-Z0-9_]*)\[([\d]+)\]")
regname_regex = re.compile(r"^[a-z][a-zA-Z0-9_]*$")


def _extract_reg(var: Token) -> tuple[str, int]:
    match = unit_regex.match(var.value)
    if match is None:
        raise QASMParseError(
            f"Invalid register definition '{var.value}'. Register definitions "
            "must follow the pattern '<name> [<size in integer>]'. "
            "For example, 'q [5]'. QASM register names must begin with a "
            "lowercase letter and may only contain lowercase and uppercase "
            "letters, numbers, and underscores."
        )
    return match.group(1), int(match.group(2))


def _load_include_module(
    header_name: str, flter: bool, decls_only: bool
) -> dict[str, dict]:
    try:
        if decls_only:
            include_def: dict[str, dict] = import_module(
                f"pytket.qasm.includes._{header_name}_decls"
            )._INCLUDE_DECLS
        else:
            include_def = import_module(
                f"pytket.qasm.includes._{header_name}_defs"
            )._INCLUDE_DEFS
    except ModuleNotFoundError as e:
        raise QASMParseError(
            f"Header {header_name} is not known and cannot be loaded."
        ) from e
    return {
        gate: include_def[gate]
        for gate in include_def
        if not flter or gate not in _all_known_gates
    }


def _bin_par_exp(op: "str") -> Callable[["CircuitTransformer", list[str]], str]:
    def f(self: "CircuitTransformer", vals: list[str]) -> str:
        return f"({vals[0]} {op} {vals[1]})"

    return f


def _un_par_exp(op: "str") -> Callable[["CircuitTransformer", list[str]], str]:
    def f(self: "CircuitTransformer", vals: list[str]) -> str:
        return f"({op}{vals[0]})"

    return f


def _un_call_exp(op: "str") -> Callable[["CircuitTransformer", list[str]], str]:
    def f(self: "CircuitTransformer", vals: list[str]) -> str:
        return f"{op}({vals[0]})"

    return f


def _hashable_uid(arg: list) -> tuple[str, int]:
    return arg[0], arg[1][0]


Reg = NewType("Reg", str)
CommandDict = dict[str, Any]


@dataclass
class ParsMap:
    pars: Iterable[str]

    def __iter__(self) -> Iterable[str]:
        return self.pars


class CircuitTransformer(Transformer):
    def __init__(
        self,
        return_gate_dict: bool = False,
        maxwidth: int = 32,
        use_clexpr: bool = True,
    ) -> None:
        super().__init__()
        self.q_registers: dict[str, int] = {}
        self.c_registers: dict[str, int] = {}
        self.gate_dict: dict[str, dict] = {}
        self.wasm: WasmFileHandler | None = None
        self.include = ""
        self.return_gate_dict = return_gate_dict
        self.maxwidth = maxwidth
        self.use_clexpr = use_clexpr

    def _fresh_temp_bit(self) -> list:
        if _TEMP_BIT_NAME in self.c_registers:
            idx = self.c_registers[_TEMP_BIT_NAME]
        else:
            idx = 0
        self.c_registers[_TEMP_BIT_NAME] = idx + 1

        return [_TEMP_BIT_NAME, [idx]]

    def _reset_context(self, reset_wasm: bool = True) -> None:
        self.q_registers = {}
        self.c_registers = {}
        self.gate_dict = {}
        self.include = ""
        if reset_wasm:
            self.wasm = None

    def _get_reg(self, name: str) -> Reg:
        return Reg(name)

    def _get_uid(self, iarg: Token) -> list:
        name, idx = _extract_reg(iarg)
        return [name, [idx]]

    def _get_arg(self, arg: Token) -> Arg:
        if arg.type == "IARG":
            return self._get_uid(arg)
        return self._get_reg(arg.value)

    def unroll_all_args(self, args: Iterable[Arg]) -> Iterator[list[Any]]:
        for arg in args:
            if isinstance(arg, str):
                size = (
                    self.q_registers[arg]
                    if arg in self.q_registers
                    else self.c_registers[arg]
                )
                yield [[arg, [idx]] for idx in range(size)]
            else:
                yield [arg]

    def margs(self, tree: Iterable[Token]) -> Iterator[Arg]:
        return map(self._get_arg, tree)

    def iargs(self, tree: Iterable[Token]) -> Iterator[list]:
        return map(self._get_uid, tree)

    def args(self, tree: Iterable[Token]) -> Iterator[list]:
        return ([tok.value, [0]] for tok in tree)

    def creg(self, tree: list[Token]) -> None:
        name, size = _extract_reg(tree[0])
        if size > self.maxwidth:
            raise QASMUnsupportedError(
                f"Circuit contains classical register {name} of size {size} > "
                f"{self.maxwidth}: try setting the `maxwidth` parameter to a larger "
                "value."
            )
        self.c_registers[Reg(name)] = size

    def qreg(self, tree: list[Token]) -> None:
        name, size = _extract_reg(tree[0])
        self.q_registers[Reg(name)] = size

    def meas(self, tree: list[Token]) -> Iterable[CommandDict]:
        for args in zip(*self.unroll_all_args(self.margs(tree))):
            yield {"args": list(args), "op": {"type": "Measure"}}

    def barr(self, tree: list[Arg]) -> Iterable[CommandDict]:
        args = [q for qs in self.unroll_all_args(tree[0]) for q in qs]
        signature: list[str] = []
        for arg in args:
            if arg[0] in self.c_registers:
                signature.append("C")
            elif arg[0] in self.q_registers:
                signature.append("Q")
            else:
                raise QASMParseError(
                    "UnitID " + str(arg) + " in Barrier arguments is not declared."
                )
        yield {
            "args": args,
            "op": {"signature": signature, "type": "Barrier"},
        }

    def reset(self, tree: list[Token]) -> Iterable[CommandDict]:
        for qb in next(self.unroll_all_args(self.margs(tree))):
            yield {"args": [qb], "op": {"type": "Reset"}}

    def pars(self, vals: Iterable[str]) -> ParsMap:
        return ParsMap(map(str, vals))

    def mixedcall(self, tree: list) -> Iterator[CommandDict]:
        child_iter = iter(tree)

        optoken = next(child_iter)
        opstr = optoken.value
        next_tree = next(child_iter)
        try:
            args = next(child_iter)
            pars = cast(ParsMap, next_tree).pars
        except StopIteration:
            args = next_tree
            pars = []

        treat_as_barrier = [
            "sleep",
            "order2",
            "order3",
            "order4",
            "order5",
            "order6",
            "order7",
            "order8",
            "order9",
            "order10",
            "order11",
            "order12",
            "order13",
            "order14",
            "order15",
            "order16",
            "order17",
            "order18",
            "order19",
            "order20",
            "group2",
            "group3",
            "group4",
            "group5",
            "group6",
            "group7",
            "group8",
            "group9",
            "group10",
            "group11",
            "group12",
            "group13",
            "group14",
            "group15",
            "group16",
            "group17",
            "group18",
            "group19",
            "group20",
        ]
        # other opaque gates, which are not handled as barrier
        # ["RZZ", "Rxxyyzz", "Rxxyyzz_zphase", "cu", "cp", "rccx", "rc3x", "c3sqrtx"]

        args = list(args)

        if opstr in treat_as_barrier:
            params = [f"{par}" for par in pars]
        else:
            params = [f"({par})/pi" for par in pars]
        if opstr in self.gate_dict:
            op: dict[str, Any] = {}
            if opstr in treat_as_barrier:
                op["type"] = "Barrier"
                param_sorted = ",".join(params)

                op["data"] = f"{opstr}({param_sorted})"

                op["signature"] = [arg[0] for arg in args]
            else:
                gdef = self.gate_dict[opstr]
                op["type"] = "CustomGate"
                box = {
                    "type": "CustomGate",
                    "id": str(uuid.uuid4()),
                    "gate": gdef,
                }
                box["params"] = params
                op["box"] = box
                params = []  # to stop duplication in to op
        else:
            try:
                optype = _all_string_maps[opstr]
            except KeyError as e:
                raise QASMParseError(
                    f"Cannot parse gate of type: {opstr}", optoken.line
                ) from e
            op = {"type": optype}
            if params:
                op["params"] = params
            # Operations needing special handling:
            if optype.startswith("Cn"):
                # n-controlled rotations have variable signature
                op["n_qb"] = len(args)
            elif optype == "Barrier":
                op["signature"] = ["Q"] * len(args)

        for arg in zip(*self.unroll_all_args(args)):
            yield {"args": list(arg), "op": op}

    def gatecall(self, tree: list) -> Iterable[CommandDict]:
        return self.mixedcall(tree)

    def exp_args(self, tree: Iterable[Token]) -> Iterable[Reg]:
        for arg in tree:
            if arg.type == "ARG":
                yield self._get_reg(arg.value)
            else:
                raise QASMParseError(
                    "Non register arguments not supported for extern call.", arg.line
                )

    def _logic_exp(self, tree: list, opstr: str) -> LogicExp:
        args, line = self._get_logic_args(tree)
        openum: type[BitWiseOp] | type[RegWiseOp]
        if opstr in _BITOPS and opstr not in _REGOPS:
            openum = BitWiseOp
        elif (
            opstr in _REGOPS
            and opstr not in _BITOPS
            or all(isinstance(arg, int) for arg in args)
        ):
            openum = RegWiseOp
        elif all(isinstance(arg, (Bit, BitLogicExp, int)) for arg in args):
            if all(arg in (0, 1) for arg in args if isinstance(arg, int)):
                openum = BitWiseOp
            else:
                raise QASMParseError(
                    "Bits can only be operated with (0, 1) literals."
                    f" Incomaptible arguments {args}",
                    line,
                )
        else:
            openum = RegWiseOp
        if openum is BitWiseOp and opstr in ("+", "-"):
            op: BitWiseOp | RegWiseOp = BitWiseOp.XOR
        else:
            op = openum(opstr)
        return create_logic_exp(op, args)

    def _get_logic_args(
        self, tree: Sequence[Token | LogicExp]
    ) -> tuple[list[LogicExp | Bit | BitRegister | int], int | None]:
        args: list[LogicExp | Bit | BitRegister | int] = []
        line = None
        for tok in tree:
            if isinstance(tok, LogicExp):
                args.append(tok)
            elif isinstance(tok, Token):
                line = tok.line
                if tok.type == "INT":
                    args.append(int(tok.value))
                elif tok.type == "IARG":
                    args.append(Bit(*_extract_reg(tok)))
                elif tok.type == "ARG":
                    args.append(BitRegister(tok.value, self.c_registers[tok.value]))
                else:
                    raise QASMParseError(f"Could not pass argument {tok}")
            else:
                raise QASMParseError(f"Could not pass argument {tok}")
        return args, line

    par_add = _bin_par_exp("+")
    par_sub = _bin_par_exp("-")
    par_mul = _bin_par_exp("*")
    par_div = _bin_par_exp("/")
    par_pow = _bin_par_exp("**")

    par_neg = _un_par_exp("-")

    sqrt = _un_call_exp("sqrt")
    sin = _un_call_exp("sin")
    cos = _un_call_exp("cos")
    tan = _un_call_exp("tan")
    ln = _un_call_exp("ln")

    b_and = lambda self, tree: self._logic_exp(tree, "&")
    b_not = lambda self, tree: self._logic_exp(tree, "~")
    b_or = lambda self, tree: self._logic_exp(tree, "|")
    xor = lambda self, tree: self._logic_exp(tree, "^")
    lshift = lambda self, tree: self._logic_exp(tree, "<<")
    rshift = lambda self, tree: self._logic_exp(tree, ">>")
    add = lambda self, tree: self._logic_exp(tree, "+")
    sub = lambda self, tree: self._logic_exp(tree, "-")
    mul = lambda self, tree: self._logic_exp(tree, "*")
    div = lambda self, tree: self._logic_exp(tree, "/")
    ipow = lambda self, tree: self._logic_exp(tree, "**")

    def neg(self, tree: list[Token | LogicExp]) -> RegNeg:
        arg = self._get_logic_args(tree)[0][0]
        assert isinstance(arg, (RegLogicExp, BitRegister, int))
        return RegNeg(arg)

    def cond(self, tree: list[Token]) -> PredicateExp:
        op: BitWiseOp | RegWiseOp
        arg: Bit | BitRegister
        if tree[1].type == "IARG":
            arg = Bit(*_extract_reg(tree[1]))
            op = BitWiseOp(str(tree[2]))
        else:
            arg = BitRegister(tree[1].value, self.c_registers[tree[1].value])
            op = RegWiseOp(str(tree[2]))

        return create_predicate_exp(op, [arg, int(tree[3].value)])

    def ifc(self, tree: Sequence) -> Iterable[CommandDict]:
        condition = cast(PredicateExp, tree[0])

        var, val = condition.args
        condition_bits = []

        if isinstance(var, Bit):
            assert condition.op in (BitWiseOp.EQ, BitWiseOp.NEQ)
            assert val in (0, 1)
            condition_bits = [var.to_list()]

        else:
            assert isinstance(var, BitRegister)
            reg_bits = next(self.unroll_all_args([var.name]))
            if isinstance(condition, RegEq):
                # special case for base qasm
                condition_bits = reg_bits
            else:
                pred_val = cast(int, val)
                minval = 0
                maxval = (1 << self.maxwidth) - 1
                if condition.op == RegWiseOp.LT:
                    maxval = pred_val - 1
                elif condition.op == RegWiseOp.GT:
                    minval = pred_val + 1
                if condition.op in (RegWiseOp.LEQ, RegWiseOp.EQ, RegWiseOp.NEQ):
                    maxval = pred_val
                if condition.op in (RegWiseOp.GEQ, RegWiseOp.EQ, RegWiseOp.NEQ):
                    minval = pred_val

                condition_bit = self._fresh_temp_bit()
                yield {
                    "args": reg_bits + [condition_bit],
                    "op": {
                        "classical": {
                            "lower": minval,
                            "n_i": len(reg_bits),
                            "upper": maxval,
                        },
                        "type": "RangePredicate",
                    },
                }
                condition_bits = [condition_bit]
                val = int(condition.op != RegWiseOp.NEQ)

        for com in filter(lambda x: x is not None and x is not Discard, tree[1]):
            com["args"] = condition_bits + com["args"]
            com["op"] = {
                "conditional": {
                    "op": com["op"],
                    "value": val,
                    "width": len(condition_bits),
                },
                "type": "Conditional",
            }

            yield com

    def cop(self, tree: Sequence[Iterable[CommandDict]]) -> Iterable[CommandDict]:
        return tree[0]

    def _calc_exp_io(
        self, exp: LogicExp, out_args: list
    ) -> tuple[list[list], dict[str, Any]]:
        all_inps: list[tuple[str, int]] = []
        for inp in exp.all_inputs_ordered():
            if isinstance(inp, Bit):
                all_inps.append((inp.reg_name, inp.index[0]))
            else:
                assert isinstance(inp, BitRegister)
                for bit in inp:
                    all_inps.append((bit.reg_name, bit.index[0]))
        outs = (_hashable_uid(arg) for arg in out_args)
        o = []
        io = []
        for out in outs:
            if out in all_inps:
                all_inps.remove(out)
                io.append(out)
            else:
                o.append(out)

        exp_args = list(
            map(lambda x: [x[0], [x[1]]], chain.from_iterable((all_inps, io, o)))
        )
        numbers_dict = {
            "n_i": len(all_inps),
            "n_io": len(io),
            "n_o": len(o),
        }
        return exp_args, numbers_dict

    def _cexpbox_dict(self, exp: LogicExp, args: list[list]) -> CommandDict:
        box = {
            "exp": exp.to_dict(),
            "id": str(uuid.uuid4()),
            "type": "ClassicalExpBox",
        }
        args, numbers = self._calc_exp_io(exp, args)
        box.update(numbers)
        return {
            "args": args,
            "op": {
                "box": box,
                "type": "ClassicalExpBox",
            },
        }

    def _clexpr_dict(self, exp: LogicExp, out_args: list[list]) -> CommandDict:
        # Convert the LogicExp to a serialization of a command containing the
        # corresponding ClExprOp.
        wexpr, args = wired_clexpr_from_logic_exp(
            exp, [Bit.from_list(arg) for arg in out_args]
        )
        return {
            "op": {
                "type": "ClExpr",
                "expr": wexpr.to_dict(),
            },
            "args": [arg.to_list() for arg in args],
        }

    def _logic_exp_as_cmd_dict(
        self, exp: LogicExp, out_args: list[list]
    ) -> CommandDict:
        return (
            self._clexpr_dict(exp, out_args)
            if self.use_clexpr
            else self._cexpbox_dict(exp, out_args)
        )

    def assign(self, tree: list) -> Iterable[CommandDict]:
        child_iter = iter(tree)
        out_args = list(next(child_iter))
        args_uids = list(self.unroll_all_args(out_args))

        exp_tree = next(child_iter)

        exp: str | list | LogicExp | int = ""
        line = None
        if isinstance(exp_tree, Token):
            if exp_tree.type == "INT":
                exp = int(exp_tree.value)
            elif exp_tree.type in ("ARG", "IARG"):
                exp = self._get_arg(exp_tree)
            line = exp_tree.line
        elif isinstance(exp_tree, Generator):
            # assume to be extern (wasm) call
            chained_uids = list(chain.from_iterable(args_uids))
            com = next(exp_tree)
            com["args"].pop()  # remove the wasmstate from the args
            com["args"] += chained_uids
            com["args"].append(["_w", [0]])
            com["op"]["wasm"]["n"] += len(chained_uids)
            com["op"]["wasm"]["width_o_parameter"] = [
                self.c_registers[reg] for reg in out_args
            ]

            yield com
            return
        else:
            exp = exp_tree

        assert len(out_args) == 1
        out_arg = out_args[0]
        args = args_uids[0]
        if isinstance(out_arg, list):
            if isinstance(exp, LogicExp):
                yield self._logic_exp_as_cmd_dict(exp, args)
            elif isinstance(exp, (int, bool)):
                assert exp in (0, 1, True, False)
                yield {
                    "args": args,
                    "op": {"classical": {"values": [bool(exp)]}, "type": "SetBits"},
                }
            elif isinstance(exp, list):
                yield {
                    "args": [exp] + args,
                    "op": {"classical": {"n_i": 1}, "type": "CopyBits"},
                }
            else:
                raise QASMParseError(f"Unexpected expression in assignment {exp}", line)
        else:
            reg = out_arg
            if isinstance(exp, RegLogicExp):
                yield self._logic_exp_as_cmd_dict(exp, args)
            elif isinstance(exp, BitLogicExp):
                yield self._logic_exp_as_cmd_dict(exp, args[:1])
            elif isinstance(exp, int):
                yield {
                    "args": args,
                    "op": {
                        "classical": {
                            "values": int_to_bools(exp, self.c_registers[reg])
                        },
                        "type": "SetBits",
                    },
                }

            elif isinstance(exp, str):
                width = min(self.c_registers[exp], len(args))
                yield {
                    "args": [[exp, [i]] for i in range(width)] + args[:width],
                    "op": {"classical": {"n_i": width}, "type": "CopyBits"},
                }

            else:
                raise QASMParseError(f"Unexpected expression in assignment {exp}", line)

    def extern(self, tree: list[Any]) -> Any:
        # TODO parse extern defs
        return Discard

    def ccall(self, tree: list) -> Iterable[CommandDict]:
        return self.cce_call(tree)

    def cce_call(self, tree: list) -> Iterable[CommandDict]:
        nam = tree[0].value
        params = list(tree[1])
        if self.wasm is None:
            raise QASMParseError(
                "Cannot include extern calls without a wasm module specified.",
                tree[0].line,
            )
        n_i_vec = [self.c_registers[reg] for reg in params]

        wasm_args = list(chain.from_iterable(self.unroll_all_args(params)))

        wasm_args.append(["_w", [0]])

        yield {
            "args": wasm_args,
            "op": {
                "type": "WASM",
                "wasm": {
                    "func_name": nam,
                    "ww_n": 1,
                    "n": sum(n_i_vec),
                    "width_i_parameter": n_i_vec,
                    "width_o_parameter": [],  # this will be set in the assign function
                    "wasm_file_uid": str(self.wasm),
                },
            },
        }

    def transform(self, tree: Tree) -> dict[str, Any]:
        self._reset_context()
        return cast(dict[str, Any], super().transform(tree))

    def gdef(self, tree: list) -> None:
        child_iter = iter(tree)

        gate = next(child_iter).value

        next_tree = next(child_iter)
        symbols, args = [], []
        if isinstance(next_tree, ParsMap):
            symbols = list(next_tree.pars)
            args = list(next(child_iter))
        else:
            args = list(next_tree)

        symbol_map = {sym: sym * pi for sym in map(Symbol, symbols)}
        rename_map = {Qubit.from_list(qb): Qubit("q", i) for i, qb in enumerate(args)}

        new = CircuitTransformer(maxwidth=self.maxwidth)
        circ_dict = new.prog(child_iter)

        circ_dict["qubits"] = args
        gate_circ = Circuit.from_dict(circ_dict)

        # check to see whether gate definition was generated by pytket converter
        # if true, add op as pytket Op
        existing_op: bool = False
        if gate in NOPARAM_EXTRA_COMMANDS:
            qubit_args = [
                Qubit(gate + "q" + str(index), 0) for index in list(range(len(args)))
            ]
            comparison_circ = _get_gate_circuit(
                NOPARAM_EXTRA_COMMANDS[gate], qubit_args
            )
            if circuit_to_qasm_str(
                comparison_circ, maxwidth=self.maxwidth
            ) == circuit_to_qasm_str(gate_circ, maxwidth=self.maxwidth):
                existing_op = True
        elif gate in PARAM_EXTRA_COMMANDS:
            qubit_args = [
                Qubit(gate + "q" + str(index), 0) for index in list(range(len(args)))
            ]
            comparison_circ = _get_gate_circuit(
                PARAM_EXTRA_COMMANDS[gate],
                qubit_args,
                [Symbol("param" + str(index) + "/pi") for index in range(len(symbols))],
            )
            # checks that each command has same string
            existing_op = all(
                str(g) == str(c)
                for g, c in zip(
                    gate_circ.get_commands(), comparison_circ.get_commands()
                )
            )
        if not existing_op:
            gate_circ.symbol_substitution(symbol_map)
            gate_circ.rename_units(cast(dict[UnitID, UnitID], rename_map))

            self.gate_dict[gate] = {
                "definition": gate_circ.to_dict(),
                "args": symbols,
                "name": gate,
            }

    opaq = gdef

    def oqasm(self, tree: list) -> Any:
        return Discard

    def incl(self, tree: list[Token]) -> None:
        self.include = str(tree[0].value).split(".")[0]
        self.gate_dict.update(_load_include_module(self.include, True, False))

    def prog(self, tree: Iterable) -> dict[str, Any]:
        outdict: dict[str, Any] = {
            "commands": list(
                chain.from_iterable(
                    filter(lambda x: x is not None and x is not Discard, tree)
                )
            )
        }
        if self.return_gate_dict:
            return self.gate_dict
        outdict["qubits"] = [
            [reg, [i]] for reg, size in self.q_registers.items() for i in range(size)
        ]
        outdict["bits"] = [
            [reg, [i]] for reg, size in self.c_registers.items() for i in range(size)
        ]
        outdict["implicit_permutation"] = [[q, q] for q in outdict["qubits"]]
        outdict["phase"] = "0.0"
        self._reset_context()
        return outdict


def parser(maxwidth: int, use_clexpr: bool) -> Lark:
    return Lark(
        grammar,
        start="prog",
        debug=False,
        parser="lalr",
        cache=True,
        transformer=CircuitTransformer(maxwidth=maxwidth, use_clexpr=use_clexpr),
    )


g_parser = None
g_maxwidth = 32
g_use_clexpr = False


def set_parser(maxwidth: int, use_clexpr: bool) -> None:
    global g_parser, g_maxwidth, g_use_clexpr
    if (g_parser is None) or (g_maxwidth != maxwidth) or g_use_clexpr != use_clexpr:  # type: ignore
        g_parser = parser(maxwidth=maxwidth, use_clexpr=use_clexpr)
        g_maxwidth = maxwidth
        g_use_clexpr = use_clexpr


[docs] def circuit_from_qasm( input_file: Union[str, "os.PathLike[Any]"], encoding: str = "utf-8", maxwidth: int = 32, use_clexpr: bool = True, ) -> Circuit: """A method to generate a tket Circuit from a qasm file. :param input_file: path to qasm file; filename must have ``.qasm`` extension :param encoding: file encoding (default utf-8) :param maxwidth: maximum allowed width of classical registers (default 32) :param use_clexpr: whether to use ClExprOp to represent classical expressions :return: pytket circuit """ ext = os.path.splitext(input_file)[-1] if ext != ".qasm": raise TypeError("Can only convert .qasm files") with open(input_file, encoding=encoding) as f: try: circ = circuit_from_qasm_io(f, maxwidth=maxwidth, use_clexpr=use_clexpr) except QASMParseError as e: raise QASMParseError(e.msg, e.line, str(input_file)) return circ
[docs] def circuit_from_qasm_str( qasm_str: str, maxwidth: int = 32, use_clexpr: bool = True ) -> Circuit: """A method to generate a tket Circuit from a qasm string. :param qasm_str: qasm string :param maxwidth: maximum allowed width of classical registers (default 32) :param use_clexpr: whether to use ClExprOp to represent classical expressions :return: pytket circuit """ global g_parser set_parser(maxwidth=maxwidth, use_clexpr=use_clexpr) assert g_parser is not None cast(CircuitTransformer, g_parser.options.transformer)._reset_context( reset_wasm=False ) circ = Circuit.from_dict(g_parser.parse(qasm_str)) # type: ignore[arg-type] cpass = scratch_reg_resize_pass(maxwidth) cpass.apply(circ) return circ
[docs] def circuit_from_qasm_io( stream_in: TextIO, maxwidth: int = 32, use_clexpr: bool = True ) -> Circuit: """A method to generate a tket Circuit from a qasm text stream""" return circuit_from_qasm_str( stream_in.read(), maxwidth=maxwidth, use_clexpr=use_clexpr )
[docs] def circuit_from_qasm_wasm( input_file: Union[str, "os.PathLike[Any]"], wasm_file: Union[str, "os.PathLike[Any]"], encoding: str = "utf-8", maxwidth: int = 32, use_clexpr: bool = True, ) -> Circuit: """A method to generate a tket Circuit from a qasm string and external WASM module. :param input_file: path to qasm file; filename must have ``.qasm`` extension :param wasm_file: path to WASM file containing functions used in qasm :param encoding: encoding of qasm file (default utf-8) :param maxwidth: maximum allowed width of classical registers (default 32) :return: pytket circuit """ global g_parser wasm_module = WasmFileHandler(str(wasm_file)) set_parser(maxwidth=maxwidth, use_clexpr=use_clexpr) assert g_parser is not None cast(CircuitTransformer, g_parser.options.transformer).wasm = wasm_module return circuit_from_qasm( input_file, encoding=encoding, maxwidth=maxwidth, use_clexpr=use_clexpr )
[docs] def circuit_to_qasm( circ: Circuit, output_file: str, header: str = "qelib1", maxwidth: int = 32 ) -> None: """Convert a Circuit to QASM and write it to a file. Classical bits in the pytket circuit must be singly-indexed. Note that this will not account for implicit qubit permutations in the Circuit. :param circ: pytket circuit :param output_file: path to output qasm file :param header: qasm header (default "qelib1") :param maxwidth: maximum allowed width of classical registers (default 32) """ with open(output_file, "w") as out: circuit_to_qasm_io(circ, out, header=header, maxwidth=maxwidth)
def _filtered_qasm_str(qasm: str) -> str: # remove any c registers starting with _TEMP_BIT_NAME # that are not being used somewhere else lines = qasm.split("\n") def_matcher = re.compile(rf"creg ({_TEMP_BIT_NAME}\_*\d*)\[\d+\]") arg_matcher = re.compile(rf"({_TEMP_BIT_NAME}\_*\d*)\[\d+\]") unused_regs = dict() for i, line in enumerate(lines): if reg := def_matcher.match(line): # Mark a reg temporarily as unused unused_regs[reg.group(1)] = i elif args := arg_matcher.findall(line): # If the line contains scratch bits that are used as arguments # mark these regs as used for arg in args: if arg in unused_regs: unused_regs.pop(arg) # remove unused reg defs redundant_lines = sorted(unused_regs.values(), reverse=True) for line_index in redundant_lines: del lines[line_index] return "\n".join(lines) def is_empty_customgate(op: Op) -> bool: return op.type == OpType.CustomGate and op.get_circuit().n_gates == 0 # type: ignore def check_can_convert_circuit(circ: Circuit, header: str, maxwidth: int) -> None: if any( circ.n_gates_of_type(typ) for typ in ( OpType.RangePredicate, OpType.MultiBit, OpType.ExplicitPredicate, OpType.ExplicitModifier, OpType.SetBits, OpType.CopyBits, OpType.ClassicalExpBox, ) ) and (not hqs_header(header)): raise QASMUnsupportedError( "Complex classical gates not supported with qelib1: try converting with " "`header=hqslib1`" ) if any(bit.index[0] >= maxwidth for bit in circ.bits): raise QASMUnsupportedError( f"Circuit contains a classical register larger than {maxwidth}: try " "setting the `maxwidth` parameter to a higher value." ) set_circ_register = set([creg.name for creg in circ.c_registers]) for b in circ.bits: if b.reg_name not in set_circ_register: raise QASMUnsupportedError( f"Circuit contains an invalid classical register {b.reg_name}." ) # Empty CustomGates should have been removed by DecomposeBoxes(). for cmd in circ: assert not is_empty_customgate(cmd.op) if isinstance(cmd.op, Conditional): assert not is_empty_customgate(cmd.op.op) if not check_register_alignments(circ): raise QASMUnsupportedError( "Circuit contains classical expressions on registers whose arguments or " "outputs are not register-aligned." )
[docs] def circuit_to_qasm_str( circ: Circuit, header: str = "qelib1", include_gate_defs: set[str] | None = None, maxwidth: int = 32, ) -> str: """Convert a Circuit to QASM and return the string. Classical bits in the pytket circuit must be singly-indexed. Note that this will not account for implicit qubit permutations in the Circuit. :param circ: pytket circuit :param header: qasm header (default "qelib1") :param output_file: path to output qasm file :param include_gate_defs: optional set of gates to include :param maxwidth: maximum allowed width of classical registers (default 32) :return: qasm string """ qasm_writer = QasmWriter( circ.qubits, circ.bits, header, include_gate_defs, maxwidth ) circ1 = circ.copy() DecomposeBoxes().apply(circ1) check_can_convert_circuit(circ1, header, maxwidth) for command in circ1: assert isinstance(command, Command) qasm_writer.add_op(command.op, command.args) return qasm_writer.finalize()
TypeReg = TypeVar("TypeReg", BitRegister, QubitRegister) def _retrieve_registers( units: list[UnitID], reg_type: type[TypeReg] ) -> dict[str, TypeReg]: if any(len(unit.index) != 1 for unit in units): raise NotImplementedError("OPENQASM registers must use a single index") maxunits = map(lambda x: max(x[1]), groupby(units, key=lambda un: un.reg_name)) return { maxunit.reg_name: reg_type(maxunit.reg_name, maxunit.index[0] + 1) for maxunit in maxunits } def _parse_range(minval: int, maxval: int, maxwidth: int) -> tuple[str, int]: if maxwidth > 64: raise NotImplementedError("Register width exceeds maximum of 64.") REGMAX = (1 << maxwidth) - 1 if minval > REGMAX: raise NotImplementedError("Range's lower bound exceeds register capacity.") if minval > maxval: raise NotImplementedError("Range's lower bound exceeds upper bound.") maxval = min(maxval, REGMAX) if minval == maxval: return ("==", minval) if minval == 0: return ("<=", maxval) if maxval == REGMAX: return (">=", minval) raise NotImplementedError("Range can only be bounded on one side.") def _negate_comparator(comparator: str) -> str: if comparator == "==": return "!=" if comparator == "!=": return "==" if comparator == "<=": return ">" if comparator == ">": return "<=" if comparator == ">=": return "<" assert comparator == "<" return ">=" def _get_optype_and_params(op: Op) -> tuple[OpType, list[float | Expr] | None]: optype = op.type params = ( op.params if (optype in _tk_to_qasm_params) or (optype in _tk_to_qasm_extra_params) else None ) if optype == OpType.TK1: # convert to U3 optype = OpType.U3 params = [op.params[1], op.params[0] - 0.5, op.params[2] + 0.5] elif optype == OpType.CustomGate: params = op.params return optype, params def _get_gate_circuit( optype: OpType, qubits: list[Qubit], symbols: list[Symbol] | None = None ) -> Circuit: # create Circuit for constructing qasm from unitids = cast(list[UnitID], qubits) gate_circ = Circuit() for q in qubits: gate_circ.add_qubit(q) if symbols: exprs = [symbol.as_expr() for symbol in symbols] gate_circ.add_gate(optype, exprs, unitids) else: gate_circ.add_gate(optype, unitids) AutoRebase({OpType.CX, OpType.U3}).apply(gate_circ) RemoveRedundancies().apply(gate_circ) return gate_circ def hqs_header(header: str) -> bool: return header in ["hqslib1", "hqslib1_dev"] @dataclass class ConditionString: variable: str # variable, e.g. "c[1]" comparator: str # comparator, e.g. "==" value: int # value, e.g. "1" class LabelledStringList: """ Wrapper class for an ordered sequence of strings, where each string has a unique label, returned when the string is added, and a string may be removed from the sequence given its label. There is a method to retrieve the concatenation of all strings in order. The conditions (e.g. "if(c[0]==1)") for some strings are stored separately in `conditions`. These conditions will be converted to text when retrieving the full string. """ def __init__(self) -> None: self.strings: OrderedDict[int, str] = OrderedDict() self.conditions: dict[int, ConditionString] = dict() self.label = 0 def add_string(self, string: str) -> int: label = self.label self.strings[label] = string self.label += 1 return label def get_string(self, label: int) -> str | None: return self.strings.get(label, None) def del_string(self, label: int) -> None: self.strings.pop(label, None) def get_full_string(self) -> str: strings = [] for l, s in self.strings.items(): condition = self.conditions.get(l) if condition is not None: strings.append( f"if({condition.variable}{condition.comparator}{condition.value}) " + s ) else: strings.append(s) return "".join(strings) def make_params_str(params: list[float | Expr] | None) -> str: s = "" if params is not None: n_params = len(params) s += "(" for i in range(n_params): reduced = True try: p: float | Expr = float(params[i]) except TypeError: reduced = False p = params[i] if i < n_params - 1: if reduced: s += f"{p}*pi," else: s += f"({p})*pi," elif reduced: s += f"{p}*pi)" else: s += f"({p})*pi)" s += " " return s def make_args_str(args: Sequence[UnitID]) -> str: s = "" for i in range(len(args)): s += f"{args[i]}" if i < len(args) - 1: s += "," else: s += ";\n" return s @dataclass class ScratchPredicate: variable: str # variable, e.g. "c[1]" comparator: str # comparator, e.g. "==" value: int # value, e.g. "1" dest: str # destination bit, e.g. "tk_SCRATCH_BIT[0]" def _vars_overlap(v: str, w: str) -> bool: """check if two variables have overlapping bits""" v_split = v.split("[") w_split = w.split("[") if v_split[0] != w_split[0]: # different registers return False # e.g. (a[1], a), (a, a[1]), (a[1], a[1]), (a, a) return len(v_split) != len(w_split) or v == w def _var_appears(v: str, s: str) -> bool: """check if variable v appears in string s""" v_split = v.split("[") if len(v_split) == 1: # check if v appears in s and is not surrounded by word characters # e.g. a = a & b or a = a[1] & b[1] return bool(re.search(r"(?<!\w)" + re.escape(v) + r"(?![\w])", s)) if re.search(r"(?<!\w)" + re.escape(v), s): # check if v appears in s and is not proceeded by word characters # e.g. a[1] = a[1] return True # check the register of v appears in s # e.g. a[1] = a & b return bool(re.search(r"(?<!\w)" + re.escape(v_split[0]) + r"(?![\[\w])", s)) class QasmWriter: """ Helper class for converting a sequence of TKET Commands to QASM, and retrieving the final QASM string afterwards. """ def __init__( self, qubits: list[Qubit], bits: list[Bit], header: str = "qelib1", include_gate_defs: set[str] | None = None, maxwidth: int = 32, ): self.header = header self.maxwidth = maxwidth self.added_gate_definitions: set[str] = set() self.include_module_gates = {"measure", "reset", "barrier"} self.include_module_gates.update( _load_include_module(header, False, True).keys() ) self.prefix = "" self.gatedefs = "" self.strings = LabelledStringList() # Record of `RangePredicate` operations that set a "scratch" bit to 0 or 1 # depending on the value of the predicate. This map is consulted when we # encounter a `Conditional` operation to see if the condition bit is one of # these scratch bits, which we can then replace with the original. self.range_preds: dict[int, ScratchPredicate] = dict() if include_gate_defs is None: self.include_gate_defs = self.include_module_gates self.include_gate_defs.update(NOPARAM_EXTRA_COMMANDS.keys()) self.include_gate_defs.update(PARAM_EXTRA_COMMANDS.keys()) self.prefix = f'OPENQASM 2.0;\ninclude "{header}.inc";\n\n' self.qregs = _retrieve_registers(cast(list[UnitID], qubits), QubitRegister) self.cregs = _retrieve_registers(cast(list[UnitID], bits), BitRegister) for reg in self.qregs.values(): if regname_regex.match(reg.name) is None: raise QASMUnsupportedError( f"Invalid register name '{reg.name}'. QASM register names must " "begin with a lowercase letter and may only contain lowercase " "and uppercase letters, numbers, and underscores. " "Try renaming the register with `rename_units` first." ) for bit_reg in self.cregs.values(): if regname_regex.match(bit_reg.name) is None: raise QASMUnsupportedError( f"Invalid register name '{bit_reg.name}'. QASM register names " "must begin with a lowercase letter and may only contain " "lowercase and uppercase letters, numbers, and underscores. " "Try renaming the register with `rename_units` first." ) else: # gate definition, no header necessary for file self.include_gate_defs = include_gate_defs self.cregs = {} self.qregs = {} self.cregs_as_bitseqs = set(tuple(creg) for creg in self.cregs.values()) # for holding condition values when writing Conditional blocks # the size changes when adding and removing scratch bits self.scratch_reg = BitRegister( next( f"{_TEMP_BIT_REG_BASE}_{i}" for i in itertools.count() if f"{_TEMP_BIT_REG_BASE}_{i}" not in self.qregs ), 0, ) # if a string writes to some classical variables, the string label and # the affected variables will be recorded. self.variable_writes: dict[int, list[str]] = dict() def fresh_scratch_bit(self) -> Bit: self.scratch_reg = BitRegister(self.scratch_reg.name, self.scratch_reg.size + 1) return Bit(self.scratch_reg.name, self.scratch_reg.size - 1) def remove_last_scratch_bit(self) -> None: assert self.scratch_reg.size > 0 self.scratch_reg = BitRegister(self.scratch_reg.name, self.scratch_reg.size - 1) def write_params(self, params: list[float | Expr] | None) -> None: params_str = make_params_str(params) self.strings.add_string(params_str) def write_args(self, args: Sequence[UnitID]) -> None: args_str = make_args_str(args) self.strings.add_string(args_str) def make_gate_definition( self, n_qubits: int, opstr: str, optype: OpType, n_params: int | None = None, ) -> str: s = "gate " + opstr + " " symbols: list[Symbol] | None = None if n_params is not None: # need to add parameters to gate definition s += "(" symbols = [ Symbol("param" + str(index) + "/pi") for index in range(n_params) ] symbols_header = [Symbol("param" + str(index)) for index in range(n_params)] for symbol in symbols_header[:-1]: s += symbol.name + ", " s += symbols_header[-1].name + ") " # add qubits to gate definition qubit_args = [ Qubit(opstr + "q" + str(index)) for index in list(range(n_qubits)) ] for qb in qubit_args[:-1]: s += str(qb) + "," s += str(qubit_args[-1]) + " {\n" # get rebased circuit for constructing qasm gate_circ = _get_gate_circuit(optype, qubit_args, symbols) # write circuit to qasm s += circuit_to_qasm_str( gate_circ, self.header, self.include_gate_defs, self.maxwidth ) s += "}\n" return s def mark_as_written(self, label: int, written_variable: str) -> None: if label in self.variable_writes: self.variable_writes[label].append(written_variable) else: self.variable_writes[label] = [written_variable] def check_range_predicate(self, op: RangePredicateOp, args: list[Bit]) -> None: if (not hqs_header(self.header)) and op.lower != op.upper: raise QASMUnsupportedError( "OpenQASM conditions must be on a register's fixed value." ) variable = args[0].reg_name assert isinstance(variable, str) if op.n_inputs != self.cregs[variable].size: raise QASMUnsupportedError( "RangePredicate conditions must be an entire classical register" ) if args[:-1] != self.cregs[variable].to_list(): raise QASMUnsupportedError( "RangePredicate conditions must be a single classical register" ) def add_range_predicate(self, op: RangePredicateOp, args: list[Bit]) -> None: self.check_range_predicate(op, args) comparator, value = _parse_range(op.lower, op.upper, self.maxwidth) variable = args[0].reg_name dest_bit = str(args[-1]) label = self.strings.add_string( "".join( [ f"if({variable}{comparator}{value}) " + f"{dest_bit} = 1;\n", f"if({variable}{_negate_comparator(comparator)}{value}) " f"{dest_bit} = 0;\n", ] ) ) # Record this operation. # Later if we find a conditional based on dest_bit, we can replace dest_bit with # (variable, comparator, value), provided that variable hasn't been written to # in the mean time. (So we must watch for that, and remove the record from the # list if it is.) # Note that we only perform such rewrites for internal scratch bits. if dest_bit.startswith(_TEMP_BIT_NAME): self.range_preds[label] = ScratchPredicate( variable, comparator, value, dest_bit ) def replace_condition(self, pred_label: int) -> bool: """Given the label of a predicate p=(var, comp, value, dest, label) we scan the lines after p: 1.if dest is the condition of a conditional line we replace dest with the predicate and do 2 for the inner command. 2.if either the variable or the dest gets written, we stop. returns true if a replacement is made. """ assert pred_label in self.range_preds success = False pred = self.range_preds[pred_label] line_labels = [] for label in range(pred_label + 1, self.strings.label): string = self.strings.get_string(label) if string is None: continue line_labels.append(label) if "\n" not in string: continue written_variables: list[str] = [] # (label, condition) conditions: list[tuple[int, ConditionString]] = [] for l in line_labels: written_variables.extend(self.variable_writes.get(l, [])) cond = self.strings.conditions.get(l) if cond: conditions.append((l, cond)) if len(conditions) == 1 and pred.dest == conditions[0][1].variable: # if the condition is dest, replace the condition with pred success = True if conditions[0][1].value == 1: self.strings.conditions[conditions[0][0]] = ConditionString( pred.variable, pred.comparator, pred.value ) else: assert conditions[0][1].value == 0 self.strings.conditions[conditions[0][0]] = ConditionString( pred.variable, _negate_comparator(pred.comparator), pred.value, ) if any(_vars_overlap(pred.dest, v) for v in written_variables) or any( _vars_overlap(pred.variable, v) for v in written_variables ): return success line_labels.clear() conditions.clear() written_variables.clear() return success def remove_unused_predicate(self, pred_label: int) -> bool: """Given the label of a predicate p=(var, comp, value, dest, label), we remove p if dest never appears after p.""" assert pred_label in self.range_preds pred = self.range_preds[pred_label] for label in range(pred_label + 1, self.strings.label): string = self.strings.get_string(label) if string is None: continue if ( _var_appears(pred.dest, string) or label in self.strings.conditions and _vars_overlap(pred.dest, self.strings.conditions[label].variable) ): return False self.range_preds.pop(pred_label) self.strings.del_string(pred_label) return True def add_conditional(self, op: Conditional, args: Sequence[UnitID]) -> None: control_bits = args[: op.width] if op.width == 1 and hqs_header(self.header): variable = str(control_bits[0]) else: variable = control_bits[0].reg_name if ( hqs_header(self.header) and control_bits != self.cregs[variable].to_list() ): raise QASMUnsupportedError( "hqslib1 QASM conditions must be an entire classical " "register or a single bit" ) if not hqs_header(self.header): if op.width != self.cregs[variable].size: raise QASMUnsupportedError( "OpenQASM conditions must be an entire classical register" ) if control_bits != self.cregs[variable].to_list(): raise QASMUnsupportedError( "OpenQASM conditions must be a single classical register" ) if op.op.type == OpType.Phase: # Conditional phase is ignored. return if op.op.type == OpType.RangePredicate: # Special handling for nested ifs # if condition # if pred dest = 1 # if not pred dest = 0 # can be written as # if condition s0 = 1 # if pred s1 = 1 # s2 = s0 & s1 # s3 = s0 & ~s1 # if s2 dest = 1 # if s3 dest = 0 # where s0, s1, s2, and s3 are scratch bits s0 = self.fresh_scratch_bit() l = self.strings.add_string(f"{s0} = 1;\n") # we store the condition in self.strings.conditions # as it can be later replaced by `replace_condition` # if possible self.strings.conditions[l] = ConditionString(variable, "==", op.value) # output the RangePredicate to s1 s1 = self.fresh_scratch_bit() assert isinstance(op.op, RangePredicateOp) self.check_range_predicate(op.op, cast(list[Bit], args[op.width :])) pred_comparator, pred_value = _parse_range( op.op.lower, op.op.upper, self.maxwidth ) pred_variable = args[op.width :][0].reg_name self.strings.add_string( f"if({pred_variable}{pred_comparator}{pred_value}) {s1} = 1;\n" ) s2 = self.fresh_scratch_bit() self.strings.add_string(f"{s2} = {s0} & {s1};\n") s3 = self.fresh_scratch_bit() self.strings.add_string(f"{s3} = {s0} & (~ {s1});\n") self.strings.add_string(f"if({s2}==1) {args[-1]} = 1;\n") self.strings.add_string(f"if({s3}==1) {args[-1]} = 0;\n") return # we assign the condition to a scratch bit, which we will later remove # if the condition variable is unchanged. scratch_bit = self.fresh_scratch_bit() pred_label = self.strings.add_string( f"if({variable}=={op.value}) " + f"{scratch_bit} = 1;\n" ) self.range_preds[pred_label] = ScratchPredicate( variable, "==", op.value, str(scratch_bit) ) # we will later add condition to all lines starting from next_label next_label = self.strings.label self.add_op(op.op, args[op.width :]) # add conditions to the lines after the predicate is_new_line = True for label in range(next_label, self.strings.label): string = self.strings.get_string(label) assert string is not None if is_new_line and string != "\n": self.strings.conditions[label] = ConditionString( str(scratch_bit), "==", 1 ) is_new_line = "\n" in string if self.replace_condition(pred_label) and self.remove_unused_predicate( pred_label ): # remove the unused scratch bit self.remove_last_scratch_bit() def add_set_bits(self, op: SetBitsOp, args: list[Bit]) -> None: creg_name = args[0].reg_name bits, vals = zip(*sorted(zip(args, op.values))) # check if whole register can be set at once if bits == tuple(self.cregs[creg_name].to_list()): value = int("".join(map(str, map(int, vals[::-1]))), 2) label = self.strings.add_string(f"{creg_name} = {value};\n") self.mark_as_written(label, f"{creg_name}") else: for bit, value in zip(bits, vals): label = self.strings.add_string(f"{bit} = {int(value)};\n") self.mark_as_written(label, f"{bit}") def add_copy_bits(self, op: CopyBitsOp, args: list[Bit]) -> None: l_args = args[op.n_inputs :] r_args = args[: op.n_inputs] l_name = l_args[0].reg_name r_name = r_args[0].reg_name # check if whole register can be set at once if ( l_args == self.cregs[l_name].to_list() and r_args == self.cregs[r_name].to_list() ): label = self.strings.add_string(f"{l_name} = {r_name};\n") self.mark_as_written(label, f"{l_name}") else: for bit_l, bit_r in zip(l_args, r_args): label = self.strings.add_string(f"{bit_l} = {bit_r};\n") self.mark_as_written(label, f"{bit_l}") def add_multi_bit(self, op: MultiBitOp, args: list[Bit]) -> None: basic_op = op.basic_op basic_n = basic_op.n_inputs + basic_op.n_outputs + basic_op.n_input_outputs n_args = len(args) assert n_args % basic_n == 0 arity = n_args // basic_n # If the operation is register-aligned we can write it more succinctly. poss_regs = [ tuple(args[basic_n * i + j] for i in range(arity)) for j in range(basic_n) ] if all(poss_reg in self.cregs_as_bitseqs for poss_reg in poss_regs): # The operation is register-aligned. self.add_op(basic_op, [poss_regs[j][0].reg_name for j in range(basic_n)]) # type: ignore else: # The operation is not register-aligned. for i in range(arity): basic_args = args[basic_n * i : basic_n * (i + 1)] self.add_op(basic_op, basic_args) def add_explicit_op(self, op: Op, args: list[Bit]) -> None: # &, ^ and | gates opstr = str(op) if opstr not in _classical_gatestr_map: raise QASMUnsupportedError(f"Classical gate {opstr} not supported.") label = self.strings.add_string( f"{args[-1]} = {args[0]} {_classical_gatestr_map[opstr]} {args[1]};\n" ) self.mark_as_written(label, f"{args[-1]}") def add_classical_exp_box(self, op: ClassicalExpBox, args: list[Bit]) -> None: out_args = args[op.get_n_i() :] if len(out_args) == 1: label = self.strings.add_string(f"{out_args[0]} = {op.get_exp()!s};\n") self.mark_as_written(label, f"{out_args[0]}") elif ( out_args == self.cregs[out_args[0].reg_name].to_list()[ : op.get_n_io() + op.get_n_o() ] ): label = self.strings.add_string( f"{out_args[0].reg_name} = {op.get_exp()!s};\n" ) self.mark_as_written(label, f"{out_args[0].reg_name}") else: raise QASMUnsupportedError( "ClassicalExpBox only supported" " for writing to a single bit or whole registers." ) def add_wired_clexpr(self, op: ClExprOp, args: list[Bit]) -> None: wexpr: WiredClExpr = op.expr # 1. Determine the mappings from bit variables to bits and from register # variables to registers. expr: ClExpr = wexpr.expr bit_posn: dict[int, int] = wexpr.bit_posn reg_posn: dict[int, list[int]] = wexpr.reg_posn output_posn: list[int] = wexpr.output_posn input_bits: dict[int, Bit] = {i: args[j] for i, j in bit_posn.items()} input_regs: dict[int, BitRegister] = {} all_cregs = set(self.cregs.values()) for i, posns in reg_posn.items(): reg_args = [args[j] for j in posns] for creg in all_cregs: if creg.to_list() == reg_args: input_regs[i] = creg break else: assert ( not f"ClExprOp ({wexpr}) contains a register variable (r{i}) that " "is not wired to any BitRegister in the circuit." ) # 2. Write the left-hand side of the assignment. output_repr: str | None = None output_args: list[Bit] = [args[j] for j in output_posn] n_output_args = len(output_args) expect_reg_output = has_reg_output(expr.op) if n_output_args == 0: raise QASMUnsupportedError("Expression has no output.") if n_output_args == 1: output_arg = output_args[0] output_repr = output_arg.reg_name if expect_reg_output else str(output_arg) else: if not expect_reg_output: raise QASMUnsupportedError("Unexpected output for operation.") for creg in all_cregs: if creg.to_list() == output_args: output_repr = creg.name break assert output_repr is not None self.strings.add_string(f"{output_repr} = ") # 3. Write the right-hand side of the assignment. self.strings.add_string( expr.as_qasm(input_bits=input_bits, input_regs=input_regs) ) self.strings.add_string(";\n") def add_wasm(self, op: WASMOp, args: list[Bit]) -> None: inputs: list[str] = [] outputs: list[str] = [] for reglist, sizes in [(inputs, op.input_widths), (outputs, op.output_widths)]: for in_width in sizes: bits = args[:in_width] args = args[in_width:] regname = bits[0].reg_name if bits != list(self.cregs[regname]): QASMUnsupportedError("WASM ops must act on entire registers.") reglist.append(regname) if outputs: label = self.strings.add_string(f"{', '.join(outputs)} = ") self.strings.add_string(f"{op.func_name}({', '.join(inputs)});\n") for variable in outputs: self.mark_as_written(label, variable) def add_measure(self, args: Sequence[UnitID]) -> None: label = self.strings.add_string(f"measure {args[0]} -> {args[1]};\n") self.mark_as_written(label, f"{args[1]}") def add_zzphase(self, param: float | Expr, args: Sequence[UnitID]) -> None: # as op.params returns reduced parameters, we can assume # that 0 <= param < 4 if param > 1: # first get in to 0 <= param < 2 range param = Decimal(str(param)) % Decimal("2") # then flip 1 <= param < 2 range into # -1 <= param < 0 if param > 1: param = -2 + param self.strings.add_string("RZZ") self.write_params([param]) self.write_args(args) def add_data(self, op: BarrierOp, args: Sequence[UnitID]) -> None: if op.data == "": opstr = _tk_to_qasm_noparams[OpType.Barrier] else: opstr = op.data self.strings.add_string(opstr) self.strings.add_string(" ") self.write_args(args) def add_gate_noparams(self, op: Op, args: Sequence[UnitID]) -> None: self.strings.add_string(_tk_to_qasm_noparams[op.type]) self.strings.add_string(" ") self.write_args(args) def add_gate_params(self, op: Op, args: Sequence[UnitID]) -> None: optype, params = _get_optype_and_params(op) self.strings.add_string(_tk_to_qasm_params[optype]) self.write_params(params) self.write_args(args) def add_extra_noparams(self, op: Op, args: Sequence[UnitID]) -> tuple[str, str]: optype = op.type opstr = _tk_to_qasm_extra_noparams[optype] gatedefstr = "" if opstr not in self.added_gate_definitions: self.added_gate_definitions.add(opstr) gatedefstr = self.make_gate_definition(op.n_qubits, opstr, optype) mainstr = opstr + " " + make_args_str(args) return gatedefstr, mainstr def add_extra_params(self, op: Op, args: Sequence[UnitID]) -> tuple[str, str]: optype, params = _get_optype_and_params(op) assert params is not None opstr = _tk_to_qasm_extra_params[optype] gatedefstr = "" if opstr not in self.added_gate_definitions: self.added_gate_definitions.add(opstr) gatedefstr = self.make_gate_definition( op.n_qubits, opstr, optype, len(params) ) mainstr = opstr + make_params_str(params) + make_args_str(args) return gatedefstr, mainstr def add_op(self, op: Op, args: Sequence[UnitID]) -> None: optype, _params = _get_optype_and_params(op) if optype == OpType.RangePredicate: assert isinstance(op, RangePredicateOp) self.add_range_predicate(op, cast(list[Bit], args)) elif optype == OpType.Conditional: assert isinstance(op, Conditional) self.add_conditional(op, args) elif optype == OpType.Phase: # global phase is ignored in QASM pass elif optype == OpType.SetBits: assert isinstance(op, SetBitsOp) self.add_set_bits(op, cast(list[Bit], args)) elif optype == OpType.CopyBits: assert isinstance(op, CopyBitsOp) self.add_copy_bits(op, cast(list[Bit], args)) elif optype == OpType.MultiBit: assert isinstance(op, MultiBitOp) self.add_multi_bit(op, cast(list[Bit], args)) elif optype in (OpType.ExplicitPredicate, OpType.ExplicitModifier): self.add_explicit_op(op, cast(list[Bit], args)) elif optype == OpType.ClassicalExpBox: assert isinstance(op, ClassicalExpBox) self.add_classical_exp_box(op, cast(list[Bit], args)) elif optype == OpType.ClExpr: assert isinstance(op, ClExprOp) self.add_wired_clexpr(op, cast(list[Bit], args)) elif optype == OpType.WASM: assert isinstance(op, WASMOp) self.add_wasm(op, cast(list[Bit], args)) elif optype == OpType.Measure: self.add_measure(args) elif hqs_header(self.header) and optype == OpType.ZZPhase: # special handling for zzphase assert len(op.params) == 1 self.add_zzphase(op.params[0], args) elif optype == OpType.Barrier and self.header == "hqslib1_dev": assert isinstance(op, BarrierOp) self.add_data(op, args) elif ( optype in _tk_to_qasm_noparams and _tk_to_qasm_noparams[optype] in self.include_module_gates ): self.add_gate_noparams(op, args) elif ( optype in _tk_to_qasm_params and _tk_to_qasm_params[optype] in self.include_module_gates ): self.add_gate_params(op, args) elif optype in _tk_to_qasm_extra_noparams: gatedefstr, mainstr = self.add_extra_noparams(op, args) self.gatedefs += gatedefstr self.strings.add_string(mainstr) elif optype in _tk_to_qasm_extra_params: gatedefstr, mainstr = self.add_extra_params(op, args) self.gatedefs += gatedefstr self.strings.add_string(mainstr) else: raise QASMUnsupportedError(f"Cannot print command of type: {op.get_name()}") def finalize(self) -> str: # try removing unused predicates pred_labels = list(self.range_preds.keys()) for label in pred_labels: # try replacing conditions with a predicate self.replace_condition(label) # try removing the predicate self.remove_unused_predicate(label) reg_strings = LabelledStringList() for reg in self.qregs.values(): reg_strings.add_string(f"qreg {reg.name}[{reg.size}];\n") for bit_reg in self.cregs.values(): reg_strings.add_string(f"creg {bit_reg.name}[{bit_reg.size}];\n") if self.scratch_reg.size > 0: reg_strings.add_string( f"creg {self.scratch_reg.name}[{self.scratch_reg.size}];\n" ) return ( self.prefix + self.gatedefs + _filtered_qasm_str( reg_strings.get_full_string() + self.strings.get_full_string() ) )
[docs] def circuit_to_qasm_io( circ: Circuit, stream_out: TextIO, header: str = "qelib1", include_gate_defs: set[str] | None = None, maxwidth: int = 32, ) -> None: """Convert a Circuit to QASM and write to a text stream. Classical bits in the pytket circuit must be singly-indexed. Note that this will not account for implicit qubit permutations in the Circuit. :param circ: pytket circuit :param stream_out: text stream to be written to :param header: qasm header (default "qelib1") :param include_gate_defs: optional set of gates to include :param maxwidth: maximum allowed width of classical registers (default 32) """ stream_out.write( circuit_to_qasm_str( circ, header=header, include_gate_defs=include_gate_defs, maxwidth=maxwidth ) )