# Copyright 2019-2024 Cambridge Quantum Computing
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import os
import re
import uuid
import itertools
from collections import OrderedDict
from importlib import import_module
from itertools import chain, groupby
from decimal import Decimal
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
NewType,
Optional,
Sequence,
Set,
TextIO,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from sympy import Symbol, pi, Expr
from lark import Discard, Lark, Token, Transformer, Tree
from pytket._tket.circuit import (
ClassicalExpBox,
Command,
Conditional,
RangePredicateOp,
SetBitsOp,
CopyBitsOp,
MultiBitOp,
WASMOp,
BarrierOp,
)
from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE
from pytket.circuit import (
Bit,
BitRegister,
Circuit,
Op,
OpType,
Qubit,
QubitRegister,
UnitID,
)
from pytket.circuit.decompose_classical import int_to_bools
from pytket.circuit.logic_exp import (
BitLogicExp,
BitWiseOp,
PredicateExp,
LogicExp,
RegEq,
RegLogicExp,
RegNeg,
RegWiseOp,
create_predicate_exp,
create_logic_exp,
)
from pytket.qasm.grammar import grammar
from pytket.passes import AutoRebase, DecomposeBoxes, RemoveRedundancies
from pytket.wasm import WasmFileHandler
class QASMParseError(Exception):
    """Error while parsing QASM input."""

    def __init__(
        self, msg: str, line: Optional[int] = None, fname: Optional[str] = None
    ):
        # Keep the raw pieces so callers can re-raise with extra context
        # (e.g. attaching a file name).
        self.msg = msg
        self.line = line
        self.fname = fname
        pieces = [f"{msg}"]
        if fname is not None:
            pieces.append(f"\nFile:{fname}: ")
        if line is not None:
            pieces.append(f"\nLine:{line}. ")
        super().__init__("".join(pieces))
class QASMUnsupportedError(Exception):
    """Raised when a circuit or QASM feature cannot be converted (e.g. a classical
    register wider than the configured ``maxwidth``)."""
# Generic type aliases used by the parser.
Value = Union[int, float, str]
T = TypeVar("T")
# Operator symbols valid in bitwise expressions. "+" and "-" are accepted as
# well; _logic_exp maps both to XOR when the operands are bits.
_BITOPS = {op.value for op in BitWiseOp}
_BITOPS |= {"+", "-"}
# Operator symbols valid in register-wise expressions.
_REGOPS = {op.value for op in RegWiseOp}
# A parsed argument: either an indexed unit such as ["q", [0]] or a whole
# register name.
Arg = Union[List, str]
# QASM gate names (without parameters) accepted by the parser, mapped to the
# pytket OpType they are converted to. Several names may map to the same
# OpType; the reverse map _tk_to_qasm_noparams is built from this dict in
# order, so for shared OpTypes the LAST name listed here wins unless it is
# overridden explicitly. Do not reorder entries casually.
NOPARAM_COMMANDS = {
    "CX": OpType.CX,  # built-in gate equivalent to "cx"
    "cx": OpType.CX,
    "x": OpType.X,
    "y": OpType.Y,
    "z": OpType.Z,
    "h": OpType.H,
    "s": OpType.S,
    "sdg": OpType.Sdg,
    "t": OpType.T,
    "tdg": OpType.Tdg,
    "sx": OpType.SX,
    "sxdg": OpType.SXdg,
    "cz": OpType.CZ,
    "cy": OpType.CY,
    "ch": OpType.CH,
    "csx": OpType.CSX,
    "ccx": OpType.CCX,
    "c3x": OpType.CnX,
    "c4x": OpType.CnX,
    "ZZ": OpType.ZZMax,
    "measure": OpType.Measure,
    "reset": OpType.Reset,
    "id": OpType.noop,
    "barrier": OpType.Barrier,
    "swap": OpType.SWAP,
    "cswap": OpType.CSWAP,
}
# QASM gate names that take parameters, mapped to pytket OpTypes. As with
# NOPARAM_COMMANDS, entry order determines which alias the reverse map
# _tk_to_qasm_params keeps (last one wins) before its explicit overrides.
PARAM_COMMANDS = {
    "p": OpType.U1,  # alias. https://github.com/Qiskit/qiskit-terra/pull/4765
    "u": OpType.U3,  # alias. https://github.com/Qiskit/qiskit-terra/pull/4765
    "U": OpType.U3,  # built-in gate equivalent to "u3"
    "u3": OpType.U3,
    "u2": OpType.U2,
    "u1": OpType.U1,
    "rx": OpType.Rx,
    "rxx": OpType.XXPhase,
    "ry": OpType.Ry,
    "rz": OpType.Rz,
    "RZZ": OpType.ZZPhase,
    "rzz": OpType.ZZPhase,
    "Rz": OpType.Rz,
    "U1q": OpType.PhasedX,
    "crz": OpType.CRz,
    "crx": OpType.CRx,
    "cry": OpType.CRy,
    "cu1": OpType.CU1,
    "cu3": OpType.CU3,
    "Rxxyyzz": OpType.TK2,
}
# Parameterless pytket-specific gates that are not part of the standard QASM
# header. CircuitTransformer.gdef compares user gate definitions with these
# names against pytket's own decomposition to decide whether to keep them as
# native ops.
NOPARAM_EXTRA_COMMANDS = {
    "v": OpType.V,
    "vdg": OpType.Vdg,
    "cv": OpType.CV,
    "cvdg": OpType.CVdg,
    "csxdg": OpType.CSXdg,
    "bridge": OpType.BRIDGE,
    "iswapmax": OpType.ISWAPMax,
    "zzmax": OpType.ZZMax,
    "ecr": OpType.ECR,
    "cs": OpType.CS,
    "csdg": OpType.CSdg,
}
# Parametrised pytket-specific gates that are not part of the standard QASM
# header (see NOPARAM_EXTRA_COMMANDS for how gdef uses these).
PARAM_EXTRA_COMMANDS = {
    "tk2": OpType.TK2,
    "iswap": OpType.ISWAP,
    "phasediswap": OpType.PhasedISWAP,
    "yyphase": OpType.YYPhase,
    "xxphase3": OpType.XXPhase3,
    "eswap": OpType.ESWAP,
    "fsim": OpType.FSim,
}
# Reverse lookups (OpType -> QASM name) used when writing QASM. When several
# QASM names share an OpType, the one appearing last in the forward dict wins,
# unless overridden explicitly just below.
_tk_to_qasm_noparams = {optype: qasm_name for qasm_name, optype in NOPARAM_COMMANDS.items()}
_tk_to_qasm_noparams[OpType.CX] = "cx"  # prefer "cx" to "CX"
_tk_to_qasm_params = {optype: qasm_name for qasm_name, optype in PARAM_COMMANDS.items()}
_tk_to_qasm_params[OpType.U3] = "u3"  # prefer "u3" to "U"
_tk_to_qasm_params[OpType.Rz] = "rz"  # prefer "rz" to "Rz"
_tk_to_qasm_extra_noparams = {
    optype: qasm_name for qasm_name, optype in NOPARAM_EXTRA_COMMANDS.items()
}
_tk_to_qasm_extra_params = {
    optype: qasm_name for qasm_name, optype in PARAM_EXTRA_COMMANDS.items()
}
# Classical operation names mapped to their QASM operator symbols.
_classical_gatestr_map = {"AND": "&", "OR": "|", "XOR": "^"}
# Every QASM gate name the parser maps directly onto a pytket OpType.
_all_known_gates = set(
    chain(
        NOPARAM_COMMANDS.keys(),
        PARAM_COMMANDS.keys(),
        PARAM_EXTRA_COMMANDS.keys(),
        NOPARAM_EXTRA_COMMANDS.keys(),
    )
)
# QASM gate name -> OpType *name* (string), as used in command dicts.
_all_string_maps = {
    qasm_name: optype.name
    for qasm_name, optype in chain(
        PARAM_COMMANDS.items(),
        NOPARAM_COMMANDS.items(),
        PARAM_EXTRA_COMMANDS.items(),
        NOPARAM_EXTRA_COMMANDS.items(),
    )
}
# Matches an indexed unit / register declaration such as "q[5]": group(1) is the
# register name (must start with a lowercase letter) and group(2) the decimal
# number in brackets. It is used with re.match, so it is anchored at the start
# only; trailing text after the closing bracket is not rejected.
unit_regex = re.compile(r"([a-z][a-zA-Z0-9_]*)\[([\d]+)\]")
# Matches a string that is entirely a valid register name.
regname_regex = re.compile(r"^[a-z][a-zA-Z0-9_]*$")
def _extract_reg(var: Token) -> Tuple[str, int]:
    """Split a token such as ``q[5]`` into its register name and integer.

    :param var: lark token whose ``value`` is the text to parse
    :return: ``(name, number)``; for declarations the number is the size, for
        indexed arguments it is the index
    :raises QASMParseError: if the text does not match ``unit_regex``
    """
    found = unit_regex.match(var.value)
    if found:
        reg_name, number = found.groups()
        return reg_name, int(number)
    raise QASMParseError(
        f"Invalid register definition '{var.value}'. Register definitions "
        "must follow the pattern '<name> [<size in integer>]'. "
        "For example, 'q [5]'. QASM register names must begin with a "
        "lowercase letter and may only contain lowercase and uppercase "
        "letters, numbers, and underscores."
    )
def _load_include_module(
    header_name: str, flter: bool, decls_only: bool
) -> Dict[str, Dict]:
    """Load the gate table bundled with pytket for a QASM include header.

    :param header_name: header name without extension, e.g. ``qelib1``
    :param flter: if True, drop gates that the parser already knows natively
        (those in ``_all_known_gates``)
    :param decls_only: load the ``_decls`` module (``_INCLUDE_DECLS``) instead of
        the ``_defs`` module (``_INCLUDE_DEFS``)
    :return: mapping from gate name to its table entry
    :raises QASMParseError: if no bundled module exists for the header
    """
    if decls_only:
        module_path = f"pytket.qasm.includes._{header_name}_decls"
        table_attr = "_INCLUDE_DECLS"
    else:
        module_path = f"pytket.qasm.includes._{header_name}_defs"
        table_attr = "_INCLUDE_DEFS"
    try:
        module = import_module(module_path)
    except ModuleNotFoundError as e:
        raise QASMParseError(
            f"Header {header_name} is not known and cannot be loaded."
        ) from e
    include_def: Dict[str, Dict] = getattr(module, table_attr)
    return {
        gate_name: entry
        for gate_name, entry in include_def.items()
        if not flter or gate_name not in _all_known_gates
    }
def _bin_par_exp(op: "str") -> Callable[["CircuitTransformer", List[str]], str]:
def f(self: "CircuitTransformer", vals: List[str]) -> str:
return f"({vals[0]} {op} {vals[1]})"
return f
def _un_par_exp(op: "str") -> Callable[["CircuitTransformer", List[str]], str]:
def f(self: "CircuitTransformer", vals: List[str]) -> str:
return f"({op}{vals[0]})"
return f
def _un_call_exp(op: "str") -> Callable[["CircuitTransformer", List[str]], str]:
def f(self: "CircuitTransformer", vals: List[str]) -> str:
return f"{op}({vals[0]})"
return f
def _hashable_uid(arg: List) -> Tuple[str, int]:
return arg[0], arg[1][0]
# Name of a whole register (as opposed to an indexed unit like ["q", [0]]).
Reg = NewType("Reg", str)
# Dict describing a single command, with keys such as "args" and "op"; lists
# of these form the "commands" entry passed to Circuit.from_dict.
CommandDict = Dict[str, Any]
@dataclass
class ParsMap:
    """Wrapper marking a parsed gate-parameter list.

    CircuitTransformer uses ``isinstance(..., ParsMap)`` to tell an optional
    parameter list apart from the argument list that follows it.
    """

    pars: Iterable[str]

    def __iter__(self) -> Iterator[str]:
        # __iter__ must return an iterator. ``pars`` may be a plain iterable
        # (e.g. a list), so wrap it with iter(). For a one-shot iterator such as
        # a map object this returns the same object, as before.
        return iter(self.pars)
class CircuitTransformer(Transformer):
    """Lark transformer that turns a parsed OpenQASM tree into a circuit dict.

    Lark calls each method named after a grammar rule (see
    ``pytket.qasm.grammar``) with that rule's already-transformed children.
    Command-producing rules return iterables of command dicts. :meth:`prog`
    flattens them into the ``"commands"`` entry of a dict that is passed to
    ``Circuit.from_dict``.

    State collected while transforming:

    * ``q_registers`` / ``c_registers``: register name -> size
    * ``gate_dict``: user/include gate name -> definition
    * ``wasm``: optional WASM module used to resolve extern calls
    """

    def __init__(self, return_gate_dict: bool = False, maxwidth: int = 32) -> None:
        """
        :param return_gate_dict: if True, :meth:`prog` returns ``gate_dict``
            instead of a circuit dict
        :param maxwidth: maximum allowed classical register width
        """
        super().__init__()
        self.q_registers: Dict[str, int] = {}
        self.c_registers: Dict[str, int] = {}
        self.gate_dict: Dict[str, Dict] = {}
        self.wasm: Optional[WasmFileHandler] = None
        self.include = ""
        self.return_gate_dict = return_gate_dict
        self.maxwidth = maxwidth

    def _fresh_temp_bit(self) -> List:
        """Allocate the next bit of the scratch register ``_TEMP_BIT_NAME``,
        growing its recorded size by one, and return it as a uid list."""
        if _TEMP_BIT_NAME in self.c_registers:
            idx = self.c_registers[_TEMP_BIT_NAME]
        else:
            idx = 0
        self.c_registers[_TEMP_BIT_NAME] = idx + 1
        return [_TEMP_BIT_NAME, [idx]]

    def _reset_context(self, reset_wasm: bool = True) -> None:
        """Clear registers, gate definitions and include; optionally the wasm module."""
        self.q_registers = {}
        self.c_registers = {}
        self.gate_dict = {}
        self.include = ""
        if reset_wasm:
            self.wasm = None

    def _get_reg(self, name: str) -> Reg:
        return Reg(name)

    def _get_uid(self, iarg: Token) -> List:
        """Convert an indexed-argument token such as ``q[2]`` to ``["q", [2]]``."""
        name, idx = _extract_reg(iarg)
        return [name, [idx]]

    def _get_arg(self, arg: Token) -> Arg:
        """Convert an argument token to a uid list (IARG) or a register name."""
        if arg.type == "IARG":
            return self._get_uid(arg)
        else:
            return self._get_reg(arg.value)

    def unroll_all_args(self, args: Iterable[Arg]) -> Iterator[List[Any]]:
        """For each argument, yield the list of unit uids it denotes.

        A register name expands to all of its indices (quantum registers are
        checked first, then classical). An indexed uid yields a one-element list.
        """
        for arg in args:
            if isinstance(arg, str):
                size = (
                    self.q_registers[arg]
                    if arg in self.q_registers
                    else self.c_registers[arg]
                )
                yield [[arg, [idx]] for idx in range(size)]
            else:
                yield [arg]

    def margs(self, tree: Iterable[Token]) -> Iterator[Arg]:
        return map(self._get_arg, tree)

    def iargs(self, tree: Iterable[Token]) -> Iterator[List]:
        return map(self._get_uid, tree)

    def args(self, tree: Iterable[Token]) -> Iterator[List]:
        # Bare names (e.g. gate-definition arguments) become index-0 uids.
        return ([tok.value, [0]] for tok in tree)

    def creg(self, tree: List[Token]) -> None:
        """Record a classical register declaration, enforcing ``maxwidth``."""
        name, size = _extract_reg(tree[0])
        if size > self.maxwidth:
            raise QASMUnsupportedError(
                f"Circuit contains classical register {name} of size {size} > "
                f"{self.maxwidth}: try setting the `maxwidth` parameter to a larger "
                "value."
            )
        self.c_registers[Reg(name)] = size

    def qreg(self, tree: List[Token]) -> None:
        """Record a quantum register declaration."""
        name, size = _extract_reg(tree[0])
        self.q_registers[Reg(name)] = size

    def meas(self, tree: List[Token]) -> Iterable[CommandDict]:
        """Emit Measure commands, pairing the unrolled arguments element-wise.

        zip() stops at the shorter unrolled argument.
        """
        for args in zip(*self.unroll_all_args(self.margs(tree))):
            yield {"args": list(args), "op": {"type": "Measure"}}

    def barr(self, tree: List[Arg]) -> Iterable[CommandDict]:
        """Emit one Barrier over all listed units, with a per-unit "C"/"Q" signature."""
        args = [q for qs in self.unroll_all_args(tree[0]) for q in qs]
        signature: List[str] = []
        for arg in args:
            if arg[0] in self.c_registers:
                signature.append("C")
            elif arg[0] in self.q_registers:
                signature.append("Q")
            else:
                raise QASMParseError(
                    "UnitID " + str(arg) + " in Barrier arguments is not declared."
                )
        yield {
            "args": args,
            "op": {"signature": signature, "type": "Barrier"},
        }

    def reset(self, tree: List[Token]) -> Iterable[CommandDict]:
        """Emit one Reset per unit of the first argument (only the first is used)."""
        for qb in next(self.unroll_all_args(self.margs(tree))):
            yield {"args": [qb], "op": {"type": "Reset"}}

    def pars(self, vals: Iterable[str]) -> ParsMap:
        return ParsMap(map(str, vals))

    def mixedcall(self, tree: List) -> Iterator[CommandDict]:
        """Emit commands for a gate call ``name(params) args;``.

        Resolution order:

        * Gates found in ``gate_dict`` whose name is in ``treat_as_barrier``
          become Barriers carrying the call text in ``"data"``.
        * Other gates found in ``gate_dict`` become CustomGate boxes.
        * Anything else is looked up in ``_all_string_maps``.

        Register arguments are broadcast element-wise.
        """
        child_iter = iter(tree)
        optoken = next(child_iter)
        opstr = optoken.value
        next_tree = next(child_iter)
        try:
            # Three children: name, parameter list (ParsMap), arguments.
            args = next(child_iter)
            pars = cast(ParsMap, next_tree).pars
        except StopIteration:
            # Two children: name and arguments, no parameters.
            args = next_tree
            pars = []
        # Opaque gates converted to Barriers: "sleep", "order2".."order20" and
        # "group2".."group20".
        treat_as_barrier = [
            "sleep",
            *(f"order{i}" for i in range(2, 21)),
            *(f"group{i}" for i in range(2, 21)),
        ]
        # other opaque gates, which are not handled as barrier
        # ["RZZ", "Rxxyyzz", "Rxxyyzz_zphase", "cu", "cp", "rccx", "rc3x", "c3sqrtx"]
        args = list(args)
        if opstr in treat_as_barrier:
            params = [f"{par}" for par in pars]
        else:
            # Divide by pi so the parameters match the "/pi" form of the
            # symbol substitution done in gdef.
            params = [f"({par})/pi" for par in pars]
        if opstr in self.gate_dict:
            op: Dict[str, Any] = {}
            if opstr in treat_as_barrier:
                op["type"] = "Barrier"
                param_sorted = ",".join(params)
                op["data"] = f"{opstr}({param_sorted})"
                op["signature"] = [arg[0] for arg in args]
            else:
                gdef = self.gate_dict[opstr]
                op["type"] = "CustomGate"
                box = {
                    "type": "CustomGate",
                    "id": str(uuid.uuid4()),
                    "gate": gdef,
                }
                box["params"] = params
                op["box"] = box
                params = []  # to stop duplication in to op
        else:
            try:
                optype = _all_string_maps[opstr]
            except KeyError as e:
                raise QASMParseError(
                    "Cannot parse gate of type: {}".format(opstr), optoken.line
                ) from e
            op = {"type": optype}
            if params:
                op["params"] = params
            # Operations needing special handling:
            if optype.startswith("Cn"):
                # n-controlled rotations have variable signature
                op["n_qb"] = len(args)
            elif optype == "Barrier":
                op["signature"] = ["Q"] * len(args)
        for arg in zip(*self.unroll_all_args(args)):
            yield {"args": list(arg), "op": op}

    def gatecall(self, tree: List) -> Iterable[CommandDict]:
        return self.mixedcall(tree)

    def exp_args(self, tree: Iterable[Token]) -> Iterable[Reg]:
        """Yield register names for an extern-call argument list (registers only)."""
        for arg in tree:
            if arg.type == "ARG":
                yield self._get_reg(arg.value)
            else:
                raise QASMParseError(
                    "Non register arguments not supported for extern call.", arg.line
                )

    def _logic_exp(self, tree: List, opstr: str) -> LogicExp:
        """Build a classical LogicExp for binary/unary operator ``opstr``.

        Operators unique to one family pick that family. Otherwise the operand
        types decide: all ints, or mixed registers, give register-wise ops;
        bits and bit expressions (with only 0/1 literals) give bitwise ops.
        In the bitwise family, "+" and "-" are mapped to XOR.
        """
        args, line = self._get_logic_args(tree)
        openum: Union[Type[BitWiseOp], Type[RegWiseOp]]
        if opstr in _BITOPS and opstr not in _REGOPS:
            openum = BitWiseOp
        elif opstr in _REGOPS and opstr not in _BITOPS:
            openum = RegWiseOp
        elif all(isinstance(arg, int) for arg in args):
            openum = RegWiseOp
        elif all(isinstance(arg, (Bit, BitLogicExp, int)) for arg in args):
            if all(arg in (0, 1) for arg in args if isinstance(arg, int)):
                openum = BitWiseOp
            else:
                raise QASMParseError(
                    "Bits can only be operated with (0, 1) literals."
                    f" Incompatible arguments {args}",
                    line,
                )
        else:
            openum = RegWiseOp
        if openum is BitWiseOp and opstr in ("+", "-"):
            op: Union[BitWiseOp, RegWiseOp] = BitWiseOp.XOR
        else:
            op = openum(opstr)
        return create_logic_exp(op, args)

    def _get_logic_args(
        self, tree: Sequence[Union[Token, LogicExp]]
    ) -> Tuple[List[Union[LogicExp, Bit, BitRegister, int]], Optional[int]]:
        """Convert expression operands to LogicExp/Bit/BitRegister/int values.

        :return: the operands and the line of the last token seen (or None)
        :raises QASMParseError: on an operand of an unexpected kind
        """
        args: List[Union[LogicExp, Bit, BitRegister, int]] = []
        line = None
        for tok in tree:
            if isinstance(tok, LogicExp):
                args.append(tok)
            elif isinstance(tok, Token):
                line = tok.line
                if tok.type == "INT":
                    args.append(int(tok.value))
                elif tok.type == "IARG":
                    args.append(Bit(*_extract_reg(tok)))
                elif tok.type == "ARG":
                    args.append(BitRegister(tok.value, self.c_registers[tok.value]))
                else:
                    raise QASMParseError(f"Could not parse argument {tok}")
            else:
                raise QASMParseError(f"Could not parse argument {tok}")
        return args, line

    # Parameter (angle) expressions are rendered back to strings.
    par_add = _bin_par_exp("+")
    par_sub = _bin_par_exp("-")
    par_mul = _bin_par_exp("*")
    par_div = _bin_par_exp("/")
    par_pow = _bin_par_exp("**")
    par_neg = _un_par_exp("-")
    sqrt = _un_call_exp("sqrt")
    sin = _un_call_exp("sin")
    cos = _un_call_exp("cos")
    tan = _un_call_exp("tan")
    ln = _un_call_exp("ln")
    # Classical expressions are built into LogicExp objects.
    b_and = lambda self, tree: self._logic_exp(tree, "&")
    b_not = lambda self, tree: self._logic_exp(tree, "~")
    b_or = lambda self, tree: self._logic_exp(tree, "|")
    xor = lambda self, tree: self._logic_exp(tree, "^")
    lshift = lambda self, tree: self._logic_exp(tree, "<<")
    rshift = lambda self, tree: self._logic_exp(tree, ">>")
    add = lambda self, tree: self._logic_exp(tree, "+")
    sub = lambda self, tree: self._logic_exp(tree, "-")
    mul = lambda self, tree: self._logic_exp(tree, "*")
    div = lambda self, tree: self._logic_exp(tree, "/")
    ipow = lambda self, tree: self._logic_exp(tree, "**")

    def neg(self, tree: List[Union[Token, LogicExp]]) -> RegNeg:
        """Build a register negation from the single operand."""
        arg = self._get_logic_args(tree)[0][0]
        assert isinstance(arg, (RegLogicExp, BitRegister, int))
        return RegNeg(arg)

    def cond(self, tree: List[Token]) -> PredicateExp:
        """Build the predicate of an ``if(var <op> value)`` condition.

        ``tree[1]`` is the bit or register, ``tree[2]`` the comparator and
        ``tree[3]`` the integer value.
        """
        op: Union[BitWiseOp, RegWiseOp]
        arg: Union[Bit, BitRegister]
        if tree[1].type == "IARG":
            arg = Bit(*_extract_reg(tree[1]))
            op = BitWiseOp(str(tree[2]))
        else:
            arg = BitRegister(tree[1].value, self.c_registers[tree[1].value])
            op = RegWiseOp(str(tree[2]))
        return create_predicate_exp(op, [arg, int(tree[3].value)])

    def ifc(self, tree: Sequence) -> Iterable[CommandDict]:
        """Wrap the commands of ``tree[1]`` in Conditionals on predicate ``tree[0]``.

        * Bit conditions and register equality (RegEq) condition directly on
          the bit(s).
        * Other register comparisons first emit a RangePredicate that writes a
          fresh scratch bit. Commands are then conditioned on that bit being 1,
          or 0 for "!=" (whose range is the single value).
        """
        condition = cast(PredicateExp, tree[0])
        var, val = condition.args
        condition_bits = []
        if isinstance(var, Bit):
            assert condition.op in (BitWiseOp.EQ, BitWiseOp.NEQ)
            assert val in (0, 1)
            condition_bits = [var.to_list()]
        else:
            assert isinstance(var, BitRegister)
            reg_bits = next(self.unroll_all_args([var.name]))
            if isinstance(condition, RegEq):
                # special case for base qasm
                condition_bits = reg_bits
            else:
                # Translate the comparison into an inclusive [minval, maxval]
                # range over the register value.
                pred_val = cast(int, val)
                minval = 0
                maxval = (1 << self.maxwidth) - 1
                if condition.op == RegWiseOp.LT:
                    maxval = pred_val - 1
                elif condition.op == RegWiseOp.GT:
                    minval = pred_val + 1
                if condition.op in (RegWiseOp.LEQ, RegWiseOp.EQ, RegWiseOp.NEQ):
                    maxval = pred_val
                if condition.op in (RegWiseOp.GEQ, RegWiseOp.EQ, RegWiseOp.NEQ):
                    minval = pred_val
                condition_bit = self._fresh_temp_bit()
                yield {
                    "args": reg_bits + [condition_bit],
                    "op": {
                        "classical": {
                            "lower": minval,
                            "n_i": len(reg_bits),
                            "upper": maxval,
                        },
                        "type": "RangePredicate",
                    },
                }
                condition_bits = [condition_bit]
                val = int(condition.op != RegWiseOp.NEQ)
        for com in filter(lambda x: x is not None and x is not Discard, tree[1]):
            com["args"] = condition_bits + com["args"]
            com["op"] = {
                "conditional": {
                    "op": com["op"],
                    "value": val,
                    "width": len(condition_bits),
                },
                "type": "Conditional",
            }
            yield com

    def cop(self, tree: Sequence[Iterable[CommandDict]]) -> Iterable[CommandDict]:
        return tree[0]

    def _calc_exp_io(
        self, exp: LogicExp, out_args: List
    ) -> Tuple[List[List], Dict[str, Any]]:
        """Order the units of a classical expression box as inputs, in-outs, outputs.

        Output bits that also appear among the expression's inputs are moved
        from the input list into the in-out list.

        :return: the ordered uid lists and the ``n_i``/``n_io``/``n_o`` counts
        """
        all_inps: list[Tuple[str, int]] = []
        for inp in exp.all_inputs_ordered():
            if isinstance(inp, Bit):
                all_inps.append((inp.reg_name, inp.index[0]))
            else:
                assert isinstance(inp, BitRegister)
                for bit in inp:
                    all_inps.append((bit.reg_name, bit.index[0]))
        outs = (_hashable_uid(arg) for arg in out_args)
        o = []
        io = []
        for out in outs:
            if out in all_inps:
                all_inps.remove(out)
                io.append(out)
            else:
                o.append(out)
        exp_args = list(
            map(lambda x: [x[0], [x[1]]], chain.from_iterable((all_inps, io, o)))
        )
        numbers_dict = {
            "n_i": len(all_inps),
            "n_io": len(io),
            "n_o": len(o),
        }
        return exp_args, numbers_dict

    def _cexpbox_dict(self, exp: LogicExp, args: List[List]) -> CommandDict:
        """Build a ClassicalExpBox command computing ``exp`` into ``args``."""
        box = {
            "exp": exp.to_dict(),
            "id": str(uuid.uuid4()),
            "type": "ClassicalExpBox",
        }
        args, numbers = self._calc_exp_io(exp, args)
        box.update(numbers)
        return {
            "args": args,
            "op": {
                "box": box,
                "type": "ClassicalExpBox",
            },
        }

    def assign(self, tree: List) -> Iterable[CommandDict]:
        """Emit commands for a classical assignment ``out = expr``.

        The right-hand side may be:

        * a WASM extern call (a generator from :meth:`cce_call`), which is
          patched to write into the outputs;
        * an integer literal (SetBits);
        * a bit or register (CopyBits);
        * a logic expression (ClassicalExpBox).
        """
        child_iter = iter(tree)
        out_args = list(next(child_iter))
        args_uids = list(self.unroll_all_args(out_args))
        exp_tree = next(child_iter)
        exp: Union[str, List, LogicExp, int] = ""
        line = None
        if isinstance(exp_tree, Token):
            if exp_tree.type == "INT":
                exp = int(exp_tree.value)
            elif exp_tree.type in ("ARG", "IARG"):
                exp = self._get_arg(exp_tree)
            line = exp_tree.line
        elif isinstance(exp_tree, Generator):
            # assume to be extern (wasm) call
            chained_uids = list(chain.from_iterable(args_uids))
            com = next(exp_tree)
            com["args"].pop()  # remove the wasmstate from the args
            com["args"] += chained_uids
            com["args"].append(["_w", [0]])
            com["op"]["wasm"]["n"] += len(chained_uids)
            com["op"]["wasm"]["width_o_parameter"] = [
                self.c_registers[reg] for reg in out_args
            ]
            yield com
            return
        else:
            exp = exp_tree
        assert len(out_args) == 1
        out_arg = out_args[0]
        args = args_uids[0]
        if isinstance(out_arg, List):
            # Single-bit target.
            if isinstance(exp, LogicExp):
                yield self._cexpbox_dict(exp, args)
            elif isinstance(exp, (int, bool)):
                assert exp in (0, 1, True, False)
                yield {
                    "args": args,
                    "op": {"classical": {"values": [bool(exp)]}, "type": "SetBits"},
                }
            elif isinstance(exp, List):
                yield {
                    "args": [exp] + args,
                    "op": {"classical": {"n_i": 1}, "type": "CopyBits"},
                }
            else:
                raise QASMParseError(f"Unexpected expression in assignment {exp}", line)
        else:
            # Whole-register target.
            reg = out_arg
            if isinstance(exp, RegLogicExp):
                yield self._cexpbox_dict(exp, args)
            elif isinstance(exp, BitLogicExp):
                yield self._cexpbox_dict(exp, args[:1])
            elif isinstance(exp, int):
                yield {
                    "args": args,
                    "op": {
                        "classical": {
                            "values": int_to_bools(exp, self.c_registers[reg])
                        },
                        "type": "SetBits",
                    },
                }
            elif isinstance(exp, str):
                # Register copy: only the overlapping low bits are copied.
                width = min(self.c_registers[exp], len(args))
                yield {
                    "args": [[exp, [i]] for i in range(width)] + args[:width],
                    "op": {"classical": {"n_i": width}, "type": "CopyBits"},
                }
            else:
                raise QASMParseError(f"Unexpected expression in assignment {exp}", line)

    def extern(self, tree: List[Any]) -> Any:
        # TODO parse extern defs
        return Discard

    def ccall(self, tree: List) -> Iterable[CommandDict]:
        return self.cce_call(tree)

    def cce_call(self, tree: List) -> Iterable[CommandDict]:
        """Emit a WASM command for an extern call ``name(regs)``.

        This is a generator, so the missing-wasm check only runs when it is
        iterated. Output widths are filled in later by :meth:`assign`.
        """
        nam = tree[0].value
        params = list(tree[1])
        if self.wasm is None:
            raise QASMParseError(
                "Cannot include extern calls without a wasm module specified.",
                tree[0].line,
            )
        n_i_vec = [self.c_registers[reg] for reg in params]
        wasm_args = list(chain.from_iterable(self.unroll_all_args(params)))
        wasm_args.append(["_w", [0]])
        yield {
            "args": wasm_args,
            "op": {
                "type": "WASM",
                "wasm": {
                    "func_name": nam,
                    "ww_n": 1,
                    "n": sum(n_i_vec),
                    "width_i_parameter": n_i_vec,
                    "width_o_parameter": [],  # this will be set in the assign function
                    "wasm_file_uid": str(self.wasm),
                },
            },
        }

    def transform(self, tree: Tree) -> Dict[str, Any]:
        """Reset all state (including wasm), then transform ``tree``."""
        self._reset_context()
        return cast(Dict[str, Any], super().transform(tree))

    def gdef(self, tree: List) -> None:
        """Process a ``gate``/``opaque`` definition.

        The body is parsed with a fresh transformer into a Circuit. If the name
        is one of pytket's extra ops and the body matches pytket's own
        decomposition, the definition is dropped, so calls resolve to the native
        OpType. Otherwise it is stored in ``gate_dict``, with parameter symbols
        substituted by ``sym * pi`` and qubits renamed to ``q[i]``.
        """
        child_iter = iter(tree)
        gate = next(child_iter).value
        next_tree = next(child_iter)
        symbols, args = [], []
        if isinstance(next_tree, ParsMap):
            symbols = list(next_tree.pars)
            args = list(next(child_iter))
        else:
            args = list(next_tree)
        symbol_map = {sym: sym * pi for sym in map(Symbol, symbols)}
        rename_map = {Qubit.from_list(qb): Qubit("q", i) for i, qb in enumerate(args)}
        new = CircuitTransformer(maxwidth=self.maxwidth)
        circ_dict = new.prog(child_iter)
        circ_dict["qubits"] = args
        gate_circ = Circuit.from_dict(circ_dict)
        # check to see whether gate definition was generated by pytket converter
        # if true, add op as pytket Op
        existing_op: bool = False
        if gate in NOPARAM_EXTRA_COMMANDS:
            qubit_args = [
                Qubit(gate + "q" + str(index), 0) for index in list(range(len(args)))
            ]
            comparison_circ = _get_gate_circuit(
                NOPARAM_EXTRA_COMMANDS[gate], qubit_args
            )
            if circuit_to_qasm_str(
                comparison_circ, maxwidth=self.maxwidth
            ) == circuit_to_qasm_str(gate_circ, maxwidth=self.maxwidth):
                existing_op = True
        elif gate in PARAM_EXTRA_COMMANDS:
            qubit_args = [
                Qubit(gate + "q" + str(index), 0) for index in list(range(len(args)))
            ]
            comparison_circ = _get_gate_circuit(
                PARAM_EXTRA_COMMANDS[gate],
                qubit_args,
                [Symbol("param" + str(index) + "/pi") for index in range(len(symbols))],
            )
            # checks that each command has same string
            # NOTE(review): zip() stops at the shorter command list, so bodies
            # with a different number of commands are not told apart here.
            existing_op = all(
                str(g) == str(c)
                for g, c in zip(
                    gate_circ.get_commands(), comparison_circ.get_commands()
                )
            )
        if not existing_op:
            gate_circ.symbol_substitution(symbol_map)
            gate_circ.rename_units(cast(Dict[UnitID, UnitID], rename_map))
            self.gate_dict[gate] = {
                "definition": gate_circ.to_dict(),
                "args": symbols,
                "name": gate,
            }

    opaq = gdef

    def oqasm(self, tree: List) -> Any:
        return Discard

    def incl(self, tree: List[Token]) -> None:
        """Handle ``include``: load the header's gate definitions into gate_dict."""
        self.include = str(tree[0].value).split(".")[0]
        self.gate_dict.update(_load_include_module(self.include, True, False))

    def prog(self, tree: Iterable) -> Dict[str, Any]:
        """Assemble the final circuit dict from all statement results.

        If ``return_gate_dict`` is set, ``gate_dict`` is returned instead and no
        reset happens. Otherwise the context (including wasm) is reset before
        returning.
        """
        outdict: Dict[str, Any] = {
            "commands": list(
                chain.from_iterable(
                    filter(lambda x: x is not None and x is not Discard, tree)
                )
            )
        }
        if self.return_gate_dict:
            return self.gate_dict
        outdict["qubits"] = [
            [reg, [i]] for reg, size in self.q_registers.items() for i in range(size)
        ]
        outdict["bits"] = [
            [reg, [i]] for reg, size in self.c_registers.items() for i in range(size)
        ]
        outdict["implicit_permutation"] = [[q, q] for q in outdict["qubits"]]
        outdict["phase"] = "0.0"
        self._reset_context()
        return outdict
def parser(maxwidth: int) -> Lark:
    """Create an LALR Lark parser for QASM that transforms while parsing.

    :param maxwidth: maximum allowed classical register width, forwarded to the
        CircuitTransformer attached to the parser
    :return: configured Lark parser with start rule ``prog``
    """
    transformer = CircuitTransformer(maxwidth=maxwidth)
    return Lark(
        grammar,
        start="prog",
        parser="lalr",
        debug=False,
        cache=True,
        transformer=transformer,
    )
# Module-wide cached parser and the maxwidth it was built for.
g_parser = None
g_maxwidth = 32


def set_parser(maxwidth: int) -> None:
    """Ensure the cached parser exists and was built for ``maxwidth``.

    The parser is rebuilt only if none exists yet or the width changed.
    """
    global g_parser, g_maxwidth
    if g_parser is not None and g_maxwidth == maxwidth:  # type: ignore
        return
    g_parser = parser(maxwidth=maxwidth)
    g_maxwidth = maxwidth
[docs]
def circuit_from_qasm(
    input_file: Union[str, "os.PathLike[Any]"],
    encoding: str = "utf-8",
    maxwidth: int = 32,
) -> Circuit:
    """A method to generate a tket Circuit from a qasm file.

    :param input_file: path to qasm file; filename must have ``.qasm`` extension
    :param encoding: file encoding (default utf-8)
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    :raises TypeError: if the file does not have a ``.qasm`` extension
    :raises QASMParseError: on parse errors, with the file name attached
    """
    ext = os.path.splitext(input_file)[-1]
    if ext != ".qasm":
        raise TypeError("Can only convert .qasm files")
    with open(input_file, "r", encoding=encoding) as f:
        try:
            return circuit_from_qasm_io(f, maxwidth=maxwidth)
        except QASMParseError as e:
            # Re-raise with the file name attached, and chain the original
            # error so its traceback is kept as the explicit cause.
            raise QASMParseError(e.msg, e.line, str(input_file)) from e
[docs]
def circuit_from_qasm_str(qasm_str: str, maxwidth: int = 32) -> Circuit:
    """A method to generate a tket Circuit from a qasm string.

    The shared module-level parser is (re)built for ``maxwidth`` if needed.
    Its transformer state is cleared first, except for any attached WASM module.

    :param qasm_str: qasm string
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    """
    set_parser(maxwidth=maxwidth)
    active_parser = g_parser
    assert active_parser is not None
    transformer = cast(CircuitTransformer, active_parser.options.transformer)
    transformer._reset_context(reset_wasm=False)
    circ_dict = active_parser.parse(qasm_str)
    return Circuit.from_dict(circ_dict)  # type: ignore[arg-type]
[docs]
def circuit_from_qasm_io(stream_in: TextIO, maxwidth: int = 32) -> Circuit:
    """A method to generate a tket Circuit from a qasm text stream.

    :param stream_in: text stream; it is read to the end
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    """
    contents = stream_in.read()
    return circuit_from_qasm_str(contents, maxwidth=maxwidth)
[docs]
def circuit_from_qasm_wasm(
    input_file: Union[str, "os.PathLike[Any]"],
    wasm_file: Union[str, "os.PathLike[Any]"],
    encoding: str = "utf-8",
    maxwidth: int = 32,
) -> Circuit:
    """A method to generate a tket Circuit from a qasm string and external WASM module.

    The WASM module is attached to the shared parser only for the duration of
    this call.

    :param input_file: path to qasm file; filename must have ``.qasm`` extension
    :param wasm_file: path to WASM file containing functions used in qasm
    :param encoding: encoding of qasm file (default utf-8)
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: pytket circuit
    """
    wasm_module = WasmFileHandler(str(wasm_file))
    set_parser(maxwidth=maxwidth)
    assert g_parser is not None
    transformer = cast(CircuitTransformer, g_parser.options.transformer)
    transformer.wasm = wasm_module
    try:
        return circuit_from_qasm(input_file, encoding=encoding, maxwidth=maxwidth)
    finally:
        # A successful parse already clears the module (prog -> _reset_context).
        # Clearing it here as well means a failed read/parse cannot leave it on
        # the shared parser, where later plain parses would otherwise accept
        # extern calls against a stale module.
        transformer.wasm = None
[docs]
def circuit_to_qasm(
    circ: Circuit,
    output_file: str,
    header: str = "qelib1",
    maxwidth: int = 32,
    encoding: str = "utf-8",
) -> None:
    """Convert a Circuit to QASM and write it to a file.

    Classical bits in the pytket circuit must be singly-indexed.

    Note that this will not account for implicit qubit permutations in the Circuit.

    :param circ: pytket circuit
    :param output_file: path to output qasm file
    :param header: qasm header (default "qelib1")
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :param encoding: output file encoding (default utf-8, matching the default
        used by :func:`circuit_from_qasm` so files round-trip regardless of the
        platform's locale encoding)
    """
    with open(output_file, "w", encoding=encoding) as out:
        circuit_to_qasm_io(circ, out, header=header, maxwidth=maxwidth)
def _filtered_qasm_str(qasm: str) -> str:
    """Drop ``creg`` declarations of scratch registers that are never used.

    Scratch registers are those whose names start with ``_TEMP_BIT_NAME``. A
    declaration is kept if a later line mentions an indexed bit of that register.
    """
    lines = qasm.split("\n")
    decl_re = re.compile(r"creg ({}\_*\d*)\[\d+\]".format(_TEMP_BIT_NAME))
    use_re = re.compile(r"({}\_*\d*)\[\d+\]".format(_TEMP_BIT_NAME))
    # register name -> index of its (so far unused) declaration line
    pending: Dict[str, int] = {}
    for line_no, text in enumerate(lines):
        decl = decl_re.match(text)
        if decl:
            pending[decl.group(1)] = line_no
            continue
        for used_name in use_re.findall(text):
            pending.pop(used_name, None)
    drop = set(pending.values())
    return "\n".join(text for line_no, text in enumerate(lines) if line_no not in drop)
def is_empty_customgate(op: Op) -> bool:
    """Return True if ``op`` is a CustomGate whose definition contains no gates."""
    if op.type != OpType.CustomGate:
        return False
    return op.get_circuit().n_gates == 0  # type: ignore
def check_can_convert_circuit(circ: Circuit, header: str, maxwidth: int) -> None:
    """Raise QASMUnsupportedError if ``circ`` cannot be written as QASM.

    Rejected circuits are those that:

    * use complex classical ops with a non-hqslib header;
    * have a bit index at or above ``maxwidth``;
    * contain empty CustomGates, including ones wrapped in a Conditional.
    """
    complex_classical_types = (
        OpType.RangePredicate,
        OpType.MultiBit,
        OpType.ExplicitPredicate,
        OpType.ExplicitModifier,
        OpType.SetBits,
        OpType.CopyBits,
        OpType.ClassicalExpBox,
    )
    has_complex_classical = any(
        circ.n_gates_of_type(typ) for typ in complex_classical_types
    )
    if has_complex_classical and not hqs_header(header):
        raise QASMUnsupportedError(
            "Complex classical gates not supported with qelib1: try converting with "
            "`header=hqslib1`"
        )
    if any(bit.index[0] >= maxwidth for bit in circ.bits):
        raise QASMUnsupportedError(
            f"Circuit contains a classical register larger than {maxwidth}: try "
            "setting the `maxwidth` parameter to a higher value."
        )
    for cmd in circ:
        current = cmd.op
        if is_empty_customgate(current):
            empty = True
        else:
            empty = isinstance(current, Conditional) and is_empty_customgate(
                current.op
            )
        if empty:
            raise QASMUnsupportedError(
                "Empty CustomGates and opaque gates are not supported."
            )
[docs]
def circuit_to_qasm_str(
    circ: Circuit,
    header: str = "qelib1",
    include_gate_defs: Optional[Set[str]] = None,
    maxwidth: int = 32,
) -> str:
    """Convert a Circuit to QASM and return the string.

    Classical bits in the pytket circuit must be singly-indexed.

    Note that this will not account for implicit qubit permutations in the Circuit.

    The circuit is validated first, then a copy with its boxes decomposed is
    written command by command. The original circuit is not modified.

    :param circ: pytket circuit
    :param header: qasm header (default "qelib1")
    :param include_gate_defs: optional set of gates to include
    :param maxwidth: maximum allowed width of classical registers (default 32)
    :return: qasm string
    """
    check_can_convert_circuit(circ, header, maxwidth)
    writer = QasmWriter(circ.qubits, circ.bits, header, include_gate_defs, maxwidth)
    decomposed = circ.copy()
    DecomposeBoxes().apply(decomposed)
    for cmd in decomposed:
        assert isinstance(cmd, Command)
        writer.add_op(cmd.op, cmd.args)
    return writer.finalize()
TypeReg = TypeVar("TypeReg", BitRegister, QubitRegister)


def _retrieve_registers(
    units: List[UnitID], reg_type: Type[TypeReg]
) -> Dict[str, TypeReg]:
    """Infer register objects from a list of singly-indexed units.

    Each register's size is its largest index seen plus one. Units of the same
    register do not need to be adjacent in ``units``. Registers are returned in
    order of first appearance.

    :param units: units (qubits or bits), each with exactly one index
    :param reg_type: register class to construct, e.g. BitRegister
    :return: register name -> register of the inferred size
    :raises NotImplementedError: if any unit has other than one index
    """
    if any(len(unit.index) != 1 for unit in units):
        raise NotImplementedError("OPENQASM registers must use a single index")
    # itertools.groupby only groups *consecutive* items, so with unsorted input
    # it could split a register and keep a smaller maximum. Track the max
    # index per register name directly instead.
    max_index: Dict[str, int] = {}
    for unit in units:
        name = unit.reg_name
        idx = unit.index[0]
        if name not in max_index or idx > max_index[name]:
            max_index[name] = idx
    return {name: reg_type(name, idx + 1) for name, idx in max_index.items()}
def _parse_range(minval: int, maxval: int, maxwidth: int) -> Tuple[str, int]:
if maxwidth > 64:
raise NotImplementedError("Register width exceeds maximum of 64.")
REGMAX = (1 << maxwidth) - 1
if minval > REGMAX:
raise NotImplementedError("Range's lower bound exceeds register capacity.")
elif minval > maxval:
raise NotImplementedError("Range's lower bound exceeds upper bound.")
elif maxval > REGMAX:
maxval = REGMAX
if minval == maxval:
return ("==", minval)
elif minval == 0:
return ("<=", maxval)
elif maxval == REGMAX:
return (">=", minval)
else:
raise NotImplementedError("Range can only be bounded on one side.")
def _negate_comparator(comparator: str) -> str:
if comparator == "==":
return "!="
elif comparator == "!=":
return "=="
elif comparator == "<=":
return ">"
elif comparator == ">":
return "<="
elif comparator == ">=":
return "<"
else:
assert comparator == "<"
return ">="
def _get_optype_and_params(op: Op) -> Tuple[OpType, Optional[List[Union[float, Expr]]]]:
    """Return the OpType to emit for ``op`` and its parameters (or None).

    * TK1 is rewritten as U3 with reordered and shifted parameters.
    * CustomGates always keep their parameters.
    * Other ops have parameters only if they appear in a parametrised reverse map.
    """
    optype = op.type
    if optype == OpType.TK1:
        # convert to U3
        tk1_params = op.params
        return OpType.U3, [tk1_params[1], tk1_params[0] - 0.5, tk1_params[2] + 0.5]
    if optype == OpType.CustomGate:
        return optype, op.params
    if optype in _tk_to_qasm_params or optype in _tk_to_qasm_extra_params:
        return optype, op.params
    return optype, None
def _get_gate_circuit(
    optype: OpType, qubits: List[Qubit], symbols: Optional[List[Symbol]] = None
) -> Circuit:
    """Build a one-gate circuit for ``optype`` and rebase it to CX/U3.

    Used to produce the reference decomposition that QASM gate definitions
    are compared with and generated from.

    :param optype: gate type to place
    :param qubits: qubits to create and apply the gate to
    :param symbols: optional symbolic parameters for the gate
    :return: the rebased circuit, with redundancies removed
    """
    gate_circ = Circuit()
    for qb in qubits:
        gate_circ.add_qubit(qb)
    targets = cast(List[UnitID], qubits)
    if symbols:
        gate_circ.add_gate(optype, [sym.as_expr() for sym in symbols], targets)
    else:
        gate_circ.add_gate(optype, targets)
    AutoRebase({OpType.CX, OpType.U3}).apply(gate_circ)
    RemoveRedundancies().apply(gate_circ)
    return gate_circ
def hqs_header(header: str) -> bool:
    """Return True if ``header`` names one of the hqslib headers."""
    return header in ("hqslib1", "hqslib1_dev")
@dataclass
class ConditionString:
    """A condition ``if(<variable><comparator><value>)`` attached to an output line."""

    variable: str  # e.g. "c[1]"
    comparator: str  # e.g. "=="
    value: int  # e.g. 1
class LabelledStringList:
    """
    Ordered collection of strings, each identified by a unique integer label.

    ``add_string`` returns the label assigned to a new string, and that label
    can later be used to look the string up or remove it. Optional conditions
    (e.g. ``if(c[0]==1)``) are kept separately in ``conditions``, keyed by
    label. They are prefixed to their string only when the full text is built
    by ``get_full_string``.
    """

    def __init__(self) -> None:
        self.strings: OrderedDict[int, str] = OrderedDict()
        self.conditions: Dict[int, ConditionString] = dict()
        # next label to hand out; labels are never reused
        self.label = 0

    def add_string(self, string: str) -> int:
        """Append ``string`` and return its label."""
        new_label = self.label
        self.strings[new_label] = string
        self.label = new_label + 1
        return new_label

    def get_string(self, label: int) -> Optional[str]:
        """Return the string stored under ``label``, or None if absent."""
        return self.strings.get(label)

    def del_string(self, label: int) -> None:
        """Remove the string under ``label``; a missing label is ignored."""
        if label in self.strings:
            del self.strings[label]

    def get_full_string(self) -> str:
        """Concatenate all strings in order, prefixing any attached condition."""
        pieces = []
        for key, text in self.strings.items():
            cond = self.conditions.get(key)
            if cond is None:
                pieces.append(text)
                continue
            prefix = f"if({cond.variable}{cond.comparator}{cond.value}) "
            pieces.append(prefix + text)
        return "".join(pieces)
def make_params_str(params: Optional[List[Union[float, Expr]]]) -> str:
    """Format gate parameters as a QASM parameter list.

    Each value is emitted as a multiple of ``pi``. A value accepted by ``float()``
    is written as ``<float>*pi``. A value for which ``float()`` raises
    ``TypeError`` (e.g. a symbolic expression) is written as ``(<value>)*pi``.

    :param params: parameter values, or None for a gate without a parameter list
    :return: ``""`` when ``params`` is None. Otherwise ``"(p0,p1,...) "``, whose
        trailing space separates it from the argument list that follows. An empty
        list yields the balanced ``"() "`` rather than ``"( "``.
    """
    if params is None:
        return ""
    terms: List[str] = []
    for param in params:
        try:
            terms.append(f"{float(param)}*pi")
        except TypeError:
            # Not numerically evaluable: keep the expression, parenthesised so the
            # multiplication by pi applies to the whole expression.
            terms.append(f"({param})*pi")
    return "(" + ",".join(terms) + ") "
def make_args_str(args: Sequence[UnitID]) -> str:
    """Format unit IDs as a comma-separated QASM argument list ending in ``;\\n``.

    An empty ``args`` yields the empty string, with no terminator.
    """
    if len(args) == 0:
        return ""
    return ",".join(f"{arg}" for arg in args) + ";\n"
@dataclass
class ScratchPredicate:
    """Record that bit ``dest`` was assigned from ``<variable><comparator><value>``.

    The writer keeps these records so that later conditions on ``dest`` can be
    replaced by the original predicate.
    """

    variable: str  # classical variable tested, e.g. "c[1]"
    comparator: str  # comparison operator, e.g. "=="
    value: int  # integer compared against, e.g. 1
    dest: str  # bit receiving the result, e.g. "tk_SCRATCH_BIT[0]"
def _vars_overlap(v: str, w: str) -> bool:
"""check if two variables have overlapping bits"""
v_split = v.split("[")
w_split = w.split("[")
if v_split[0] != w_split[0]:
# different registers
return False
# e.g. (a[1], a), (a, a[1]), (a[1], a[1]), (a, a)
return len(v_split) != len(w_split) or v == w
def _var_appears(v: str, s: str) -> bool:
"""check if variable v appears in string s"""
v_split = v.split("[")
if len(v_split) == 1:
# check if v appears in s and is not surrounded by word characters
# e.g. a = a & b or a = a[1] & b[1]
return bool(re.search(r"(?<!\w)" + re.escape(v) + r"(?![\w])", s))
else:
if re.search(r"(?<!\w)" + re.escape(v), s):
# check if v appears in s and is not proceeded by word characters
# e.g. a[1] = a[1]
return True
# check the register of v appears in s
# e.g. a[1] = a & b
return bool(re.search(r"(?<!\w)" + re.escape(v_split[0]) + r"(?![\[\w])", s))
class QasmWriter:
    """
    Helper class for converting a sequence of TKET Commands to QASM, and retrieving the
    final QASM string afterwards.

    Each call to :meth:`add_op` appends labelled text fragments to ``self.strings``,
    with gate definitions for non-included gates accumulated in ``self.gatedefs``.
    :meth:`finalize` folds scratch-bit predicates back into ``if(...)`` conditions
    where possible. It then prepends the header, gate definitions and register
    declarations, and returns the full QASM text.
    """

    def __init__(
        self,
        qubits: List[Qubit],
        bits: List[Bit],
        header: str = "qelib1",
        include_gate_defs: Optional[Set[str]] = None,
        maxwidth: int = 32,
    ):
        """
        :param qubits: qubits of the circuit; their registers become ``qreg``s
        :param bits: classical bits of the circuit; their registers become ``creg``s
        :param header: name of the include module, e.g. "qelib1" or "hqslib1"
        :param include_gate_defs: gates that may be written without a definition.
            When None, a full program (``OPENQASM`` prefix plus register
            declarations) is produced. Otherwise only a gate-definition body is
            written and no registers are declared.
        :param maxwidth: maximum allowed width of classical registers
        :raises QASMUnsupportedError: if a register name is not a valid QASM name
        """
        self.header = header
        self.maxwidth = maxwidth
        # Extra gates whose `gate ...` definition has already been emitted.
        self.added_gate_definitions: Set[str] = set()
        # Gates supplied by the include module (plus measure/reset/barrier), which
        # can be written directly.
        self.include_module_gates = {"measure", "reset", "barrier"}
        self.include_module_gates.update(
            _load_include_module(header, False, True).keys()
        )
        self.prefix = ""
        self.gatedefs = ""
        self.strings = LabelledStringList()
        # Record of `RangePredicate` operations that set a "scratch" bit to 0 or 1
        # depending on the value of the predicate. This map is consulted when we
        # encounter a `Conditional` operation to see if the condition bit is one of
        # these scratch bits, which we can then replace with the original.
        self.range_preds: Dict[int, ScratchPredicate] = dict()
        if include_gate_defs is None:
            # NOTE(review): this aliases `include_module_gates`, so the extra-command
            # names added below also end up in `include_module_gates`, which
            # `add_op` consults. Confirm this is intended before changing it.
            self.include_gate_defs = self.include_module_gates
            self.include_gate_defs.update(NOPARAM_EXTRA_COMMANDS.keys())
            self.include_gate_defs.update(PARAM_EXTRA_COMMANDS.keys())
            self.prefix = 'OPENQASM 2.0;\ninclude "{}.inc";\n\n'.format(header)
            self.qregs = _retrieve_registers(cast(list[UnitID], qubits), QubitRegister)
            self.cregs = _retrieve_registers(cast(list[UnitID], bits), BitRegister)
            for reg in self.qregs.values():
                if regname_regex.match(reg.name) is None:
                    raise QASMUnsupportedError(
                        f"Invalid register name '{reg.name}'. QASM register names must "
                        "begin with a lowercase letter and may only contain lowercase "
                        "and uppercase letters, numbers, and underscores. "
                        "Try renaming the register with `rename_units` first."
                    )
            for bit_reg in self.cregs.values():
                if regname_regex.match(bit_reg.name) is None:
                    raise QASMUnsupportedError(
                        f"Invalid register name '{bit_reg.name}'. QASM register names "
                        "must begin with a lowercase letter and may only contain "
                        "lowercase and uppercase letters, numbers, and underscores. "
                        "Try renaming the register with `rename_units` first."
                    )
        else:
            # gate definition, no header necessary for file
            self.include_gate_defs = include_gate_defs
            self.cregs = {}
            self.qregs = {}
        # Bit tuples of every classical register, used to detect register-aligned
        # MultiBit operations.
        self.cregs_as_bitseqs = set(tuple(creg) for creg in self.cregs.values())
        # Register for holding condition values when writing Conditional blocks;
        # its size changes as scratch bits are added and removed. Its name must not
        # clash with any declared register, quantum or classical. A classical one
        # of this name can exist, e.g. when re-exporting a circuit parsed from QASM
        # previously produced by this writer, since `finalize` declares it as a creg.
        taken_names = set(self.qregs) | set(self.cregs)
        self.scratch_reg = BitRegister(
            next(
                f"{_TEMP_BIT_REG_BASE}_{i}"
                for i in itertools.count()
                if f"{_TEMP_BIT_REG_BASE}_{i}" not in taken_names
            ),
            0,
        )
        # if a string writes to some classical variables, the string label and
        # the affected variables will be recorded.
        self.variable_writes: Dict[int, List[str]] = dict()

    def fresh_scratch_bit(self) -> Bit:
        """Grow the scratch register by one bit and return that new (last) bit."""
        self.scratch_reg = BitRegister(self.scratch_reg.name, self.scratch_reg.size + 1)
        return Bit(self.scratch_reg.name, self.scratch_reg.size - 1)

    def remove_last_scratch_bit(self) -> None:
        """Shrink the scratch register by one bit (it must be non-empty)."""
        assert self.scratch_reg.size > 0
        self.scratch_reg = BitRegister(self.scratch_reg.name, self.scratch_reg.size - 1)

    def write_params(self, params: Optional[List[Union[float, Expr]]]) -> None:
        """Append the formatted parameter list as a new fragment."""
        params_str = make_params_str(params)
        self.strings.add_string(params_str)

    def write_args(self, args: Sequence[UnitID]) -> None:
        """Append the formatted argument list (ending the line) as a new fragment."""
        args_str = make_args_str(args)
        self.strings.add_string(args_str)

    def make_gate_definition(
        self,
        n_qubits: int,
        opstr: str,
        optype: OpType,
        n_params: Optional[int] = None,
    ) -> str:
        """Return a QASM ``gate`` definition for ``opstr``.

        The body is the rebased circuit from ``_get_gate_circuit``, converted to QASM
        with this writer's header and included gates.

        :param n_qubits: number of qubits the gate acts on
        :param opstr: QASM name of the gate
        :param optype: TKET type of the gate
        :param n_params: number of parameters, or None for an unparameterised gate.
            Zero gives an empty ``()`` parameter list.
        """
        s = "gate " + opstr + " "
        symbols: Optional[List[Symbol]] = None
        if n_params is not None:
            # need to add parameters to gate definition
            # The circuit uses symbols named `param<i>/pi`. When the body is written,
            # `make_params_str` multiplies each parameter by pi, so the body is
            # expressed in terms of the header's `param<i>`.
            symbols = [
                Symbol("param" + str(index) + "/pi") for index in range(n_params)
            ]
            symbols_header = [Symbol("param" + str(index)) for index in range(n_params)]
            s += "(" + ", ".join(symbol.name for symbol in symbols_header) + ") "
        # add qubits to gate definition
        qubit_args = [
            Qubit(opstr + "q" + str(index)) for index in list(range(n_qubits))
        ]
        for qb in qubit_args[:-1]:
            s += str(qb) + ","
        s += str(qubit_args[-1]) + " {\n"
        # get rebased circuit for constructing qasm
        gate_circ = _get_gate_circuit(optype, qubit_args, symbols)
        # write circuit to qasm
        s += circuit_to_qasm_str(
            gate_circ, self.header, self.include_gate_defs, self.maxwidth
        )
        s += "}\n"
        return s

    def mark_as_written(self, label: int, written_variable: str) -> None:
        """Record that the fragment ``label`` writes to ``written_variable``."""
        if label in self.variable_writes:
            self.variable_writes[label].append(written_variable)
        else:
            self.variable_writes[label] = [written_variable]

    def add_range_predicate(self, op: RangePredicateOp, args: List[Bit]) -> None:
        """Write a RangePredicate as two conditional assignments to its target bit.

        ``args[:-1]`` are the input bits and ``args[-1]`` the destination bit. The
        destination is set to 1 when the predicate holds and to 0 otherwise.

        :raises QASMUnsupportedError: for a non-``==`` comparison with a non-hqslib1
            header, or (non-hqslib1 only) when the inputs are not exactly one whole
            classical register.
        """
        comparator, value = _parse_range(op.lower, op.upper, self.maxwidth)
        if (not hqs_header(self.header)) and comparator != "==":
            raise QASMUnsupportedError(
                "OpenQASM conditions must be on a register's fixed value."
            )
        bits = args[:-1]
        variable = args[0].reg_name
        dest_bit = str(args[-1])
        if not hqs_header(self.header):
            assert isinstance(variable, str)
            if op.n_inputs != self.cregs[variable].size:
                raise QASMUnsupportedError(
                    "OpenQASM conditions must be an entire classical register"
                )
            if bits != self.cregs[variable].to_list():
                raise QASMUnsupportedError(
                    "OpenQASM conditions must be a single classical register"
                )
        label = self.strings.add_string(
            "".join(
                [
                    f"if({variable}{comparator}{value}) " + f"{dest_bit} = 1;\n",
                    f"if({variable}{_negate_comparator(comparator)}{value}) "
                    + f"{dest_bit} = 0;\n",
                ]
            )
        )
        # Record this operation.
        # Later if we find a conditional based on dest_bit, we can replace dest_bit with
        # (variable, comparator, value), provided that variable hasn't been written to
        # in the mean time. (So we must watch for that, and remove the record from the
        # list if it is.)
        # Note that we only perform such rewrites for internal scratch bits.
        if dest_bit.startswith(_TEMP_BIT_NAME):
            self.range_preds[label] = ScratchPredicate(
                variable, comparator, value, dest_bit
            )

    def replace_condition(self, pred_label: int) -> bool:
        """Given the label of a predicate p=(var, comp, value, dest, label)
        we scan the lines after p:
        1.if dest is the condition of a conditional line we replace dest with
          the predicate and do 2 for the inner command.
        2.if either the variable or the dest gets written, we stop.
        returns true if a replacement is made.

        A "line" is a run of consecutive live fragments ending with a fragment that
        contains a newline. A condition ``dest==0`` is replaced by the predicate with
        its comparator negated.
        """
        assert pred_label in self.range_preds
        success = False
        pred = self.range_preds[pred_label]
        # Labels of the fragments making up the line currently being scanned.
        line_labels = []
        for label in range(pred_label + 1, self.strings.label):
            string = self.strings.get_string(label)
            if string is None:
                # Fragment has been deleted.
                continue
            line_labels.append(label)
            if "\n" not in string:
                # The line continues in the next fragment.
                continue
            written_variables: List[str] = []
            # (label, condition)
            conditions: List[Tuple[int, ConditionString]] = []
            for line_label in line_labels:
                written_variables.extend(self.variable_writes.get(line_label, []))
                cond = self.strings.conditions.get(line_label)
                if cond:
                    conditions.append((line_label, cond))
            if len(conditions) == 1 and pred.dest == conditions[0][1].variable:
                # if the condition is dest, replace the condition with pred
                success = True
                if conditions[0][1].value == 1:
                    self.strings.conditions[conditions[0][0]] = ConditionString(
                        pred.variable, pred.comparator, pred.value
                    )
                else:
                    assert conditions[0][1].value == 0
                    self.strings.conditions[conditions[0][0]] = ConditionString(
                        pred.variable,
                        _negate_comparator(pred.comparator),
                        pred.value,
                    )
            if any(_vars_overlap(pred.dest, v) for v in written_variables) or any(
                _vars_overlap(pred.variable, v) for v in written_variables
            ):
                return success
            line_labels.clear()
        return success

    def remove_unused_predicate(self, pred_label: int) -> bool:
        """Given the label of a predicate p=(var, comp, value, dest, label),
        we remove p if dest never appears after p.

        "Appears" means ``dest`` occurs in the text of a later live fragment, or
        overlaps the variable of a later fragment's condition. Returns True if the
        predicate (both its record and its text) was removed.
        """
        assert pred_label in self.range_preds
        pred = self.range_preds[pred_label]
        for label in range(pred_label + 1, self.strings.label):
            string = self.strings.get_string(label)
            if string is None:
                continue
            if _var_appears(pred.dest, string) or (
                label in self.strings.conditions
                and _vars_overlap(pred.dest, self.strings.conditions[label].variable)
            ):
                return False
        self.range_preds.pop(pred_label)
        self.strings.del_string(pred_label)
        return True

    def add_conditional(self, op: Conditional, args: Sequence[UnitID]) -> None:
        """Write a Conditional op.

        The condition is first written as an assignment to a fresh scratch bit, and
        every line of the inner op is conditioned on that bit. The predicate may then
        be substituted directly into those lines. If the scratch assignment is then
        unused, it is deleted and the scratch bit released.

        :raises QASMUnsupportedError: if the control bits are not a supported
            register/bit for this header, or the inner op is a RangePredicate.
        """
        control_bits = args[: op.width]
        if op.width == 1 and hqs_header(self.header):
            variable = str(control_bits[0])
        else:
            variable = control_bits[0].reg_name
            if (
                hqs_header(self.header)
                and control_bits != self.cregs[variable].to_list()
            ):
                raise QASMUnsupportedError(
                    "hqslib1 QASM conditions must be an entire classical "
                    "register or a single bit"
                )
        if not hqs_header(self.header):
            if op.width != self.cregs[variable].size:
                raise QASMUnsupportedError(
                    "OpenQASM conditions must be an entire classical register"
                )
            if control_bits != self.cregs[variable].to_list():
                raise QASMUnsupportedError(
                    "OpenQASM conditions must be a single classical register"
                )
        if op.op.type == OpType.Phase:
            # Conditional phase is ignored.
            return
        if op.op.type == OpType.RangePredicate:
            raise QASMUnsupportedError(
                "Conditional RangePredicate is currently unsupported."
            )
        # we assign the condition to a scratch bit, which we will later remove
        # if the condition variable is unchanged.
        scratch_bit = self.fresh_scratch_bit()
        pred_label = self.strings.add_string(
            f"if({variable}=={op.value}) " + f"{scratch_bit} = 1;\n"
        )
        self.range_preds[pred_label] = ScratchPredicate(
            variable, "==", op.value, str(scratch_bit)
        )
        # we will later add condition to all lines starting from next_label
        next_label = self.strings.label
        self.add_op(op.op, args[op.width :])
        # add conditions to the lines after the predicate: only the first fragment
        # of each line gets the condition, and bare newlines are skipped.
        is_new_line = True
        for label in range(next_label, self.strings.label):
            string = self.strings.get_string(label)
            assert string is not None
            if is_new_line and string != "\n":
                self.strings.conditions[label] = ConditionString(
                    str(scratch_bit), "==", 1
                )
            is_new_line = "\n" in string
        if self.replace_condition(pred_label) and self.remove_unused_predicate(
            pred_label
        ):
            # remove the unused scratch bit
            self.remove_last_scratch_bit()

    def add_set_bits(self, op: SetBitsOp, args: List[Bit]) -> None:
        """Write a SetBits op, as one register assignment when it covers a register."""
        creg_name = args[0].reg_name
        bits, vals = zip(*sorted(zip(args, op.values)))
        # check if whole register can be set at once
        if bits == tuple(self.cregs[creg_name].to_list()):
            # Bit i of the register is the i-th least significant bit of the value.
            value = int("".join(map(str, map(int, vals[::-1]))), 2)
            label = self.strings.add_string(f"{creg_name} = {value};\n")
            self.mark_as_written(label, f"{creg_name}")
        else:
            for bit, value in zip(bits, vals):
                label = self.strings.add_string(f"{bit} = {int(value)};\n")
                self.mark_as_written(label, f"{bit}")

    def add_copy_bits(self, op: CopyBitsOp, args: List[Bit]) -> None:
        """Write a CopyBits op (inputs first in ``args``, then outputs)."""
        l_args = args[op.n_inputs :]
        r_args = args[: op.n_inputs]
        l_name = l_args[0].reg_name
        r_name = r_args[0].reg_name
        # check if whole register can be set at once
        if (
            l_args == self.cregs[l_name].to_list()
            and r_args == self.cregs[r_name].to_list()
        ):
            label = self.strings.add_string(f"{l_name} = {r_name};\n")
            self.mark_as_written(label, f"{l_name}")
        else:
            for bit_l, bit_r in zip(l_args, r_args):
                label = self.strings.add_string(f"{bit_l} = {bit_r};\n")
                self.mark_as_written(label, f"{bit_l}")

    def add_multi_bit(self, op: MultiBitOp, args: List[Bit]) -> None:
        """Write a MultiBit op, either register-wise or as repeated basic ops."""
        basic_op = op.basic_op
        basic_n = basic_op.n_inputs + basic_op.n_outputs + basic_op.n_input_outputs
        n_args = len(args)
        assert n_args % basic_n == 0
        arity = n_args // basic_n
        # If the operation is register-aligned we can write it more succinctly.
        poss_regs = [
            tuple(args[basic_n * i + j] for i in range(arity)) for j in range(basic_n)
        ]
        if all(poss_reg in self.cregs_as_bitseqs for poss_reg in poss_regs):
            # The operation is register-aligned: pass register names as arguments.
            reg_names = [poss_regs[j][0].reg_name for j in range(basic_n)]
            self.add_op(basic_op, reg_names)  # type: ignore
        else:
            # The operation is not register-aligned.
            for i in range(arity):
                basic_args = args[basic_n * i : basic_n * (i + 1)]
                self.add_op(basic_op, basic_args)

    def add_explicit_op(self, op: Op, args: List[Bit]) -> None:
        """Write an explicit binary classical op as ``out = in0 <op> in1;``."""
        # &, ^ and | gates
        opstr = str(op)
        if opstr not in _classical_gatestr_map:
            raise QASMUnsupportedError(f"Classical gate {opstr} not supported.")
        label = self.strings.add_string(
            f"{args[-1]} = {args[0]} {_classical_gatestr_map[opstr]} {args[1]};\n"
        )
        self.mark_as_written(label, f"{args[-1]}")

    def add_classical_exp_box(self, op: ClassicalExpBox, args: List[Bit]) -> None:
        """Write a ClassicalExpBox to a single bit or to a register prefix.

        :raises QASMUnsupportedError: if the outputs are neither a single bit nor the
            leading bits of one register.
        """
        out_args = args[op.get_n_i() :]
        if len(out_args) == 1:
            label = self.strings.add_string(f"{out_args[0]} = {str(op.get_exp())};\n")
            self.mark_as_written(label, f"{out_args[0]}")
        elif (
            out_args
            == self.cregs[out_args[0].reg_name].to_list()[
                : op.get_n_io() + op.get_n_o()
            ]
        ):
            label = self.strings.add_string(
                f"{out_args[0].reg_name} = {str(op.get_exp())};\n"
            )
            self.mark_as_written(label, f"{out_args[0].reg_name}")
        else:
            raise QASMUnsupportedError(
                "ClassicalExpBox only supported"
                " for writing to a single bit or whole registers."
            )

    def add_wasm(self, op: WASMOp, args: List[Bit]) -> None:
        """Write a WASM call as ``outs = func(ins);``.

        :raises QASMUnsupportedError: if an input or output does not cover an
            entire classical register.
        """
        inputs: List[str] = []
        outputs: List[str] = []
        for reglist, sizes in [(inputs, op.input_widths), (outputs, op.output_widths)]:
            for in_width in sizes:
                bits = args[:in_width]
                args = args[in_width:]
                regname = bits[0].reg_name
                if bits != list(self.cregs[regname]):
                    # Previously the exception was constructed but never raised.
                    raise QASMUnsupportedError("WASM ops must act on entire registers.")
                reglist.append(regname)
        if outputs:
            label = self.strings.add_string(f"{', '.join(outputs)} = ")
        self.strings.add_string(f"{op.func_name}({', '.join(inputs)});\n")
        # `label` is only bound when `outputs` is non-empty, i.e. when this loop runs.
        for variable in outputs:
            self.mark_as_written(label, variable)

    def add_measure(self, args: Sequence[UnitID]) -> None:
        """Write ``measure <qubit> -> <bit>;``."""
        label = self.strings.add_string(f"measure {args[0]} -> {args[1]};\n")
        self.mark_as_written(label, f"{args[1]}")

    def add_zzphase(self, param: Union[float, Expr], args: Sequence[UnitID]) -> None:
        """Write a ZZPhase gate as hqslib1's ``RZZ``, mapping the angle into (-1, 1]."""
        # as op.params returns reduced parameters, we can assume
        # that 0 <= param < 4
        # NOTE(review): `param > 1` presumably fails for symbolic parameters — confirm.
        if param > 1:
            # first get in to 0 <= param < 2 range
            param = Decimal(str(param)) % Decimal("2")
            # then flip 1 <= param < 2 range into
            # -1 <= param < 0
            if param > 1:
                param = -2 + param
        self.strings.add_string("RZZ")
        self.write_params([param])
        self.write_args(args)

    def add_data(self, op: BarrierOp, args: Sequence[UnitID]) -> None:
        """Write a barrier, using its data string as the instruction name if set."""
        if op.data == "":
            opstr = _tk_to_qasm_noparams[OpType.Barrier]
        else:
            opstr = op.data
        self.strings.add_string(opstr)
        self.strings.add_string(" ")
        self.write_args(args)

    def add_gate_noparams(self, op: Op, args: Sequence[UnitID]) -> None:
        """Write an included gate that takes no parameters."""
        self.strings.add_string(_tk_to_qasm_noparams[op.type])
        self.strings.add_string(" ")
        self.write_args(args)

    def add_gate_params(self, op: Op, args: Sequence[UnitID]) -> None:
        """Write an included gate together with its parameter list."""
        optype, params = _get_optype_and_params(op)
        self.strings.add_string(_tk_to_qasm_params[optype])
        self.write_params(params)
        self.write_args(args)

    def add_extra_noparams(self, op: Op, args: Sequence[UnitID]) -> Tuple[str, str]:
        """Return (gate definition or "", call text) for a non-included gate."""
        optype = op.type
        opstr = _tk_to_qasm_extra_noparams[optype]
        gatedefstr = ""
        if opstr not in self.added_gate_definitions:
            self.added_gate_definitions.add(opstr)
            gatedefstr = self.make_gate_definition(op.n_qubits, opstr, optype)
        mainstr = opstr + " " + make_args_str(args)
        return gatedefstr, mainstr

    def add_extra_params(self, op: Op, args: Sequence[UnitID]) -> Tuple[str, str]:
        """Return (gate definition or "", call text) for a non-included param gate."""
        optype, params = _get_optype_and_params(op)
        assert params is not None
        opstr = _tk_to_qasm_extra_params[optype]
        gatedefstr = ""
        if opstr not in self.added_gate_definitions:
            self.added_gate_definitions.add(opstr)
            gatedefstr = self.make_gate_definition(
                op.n_qubits, opstr, optype, len(params)
            )
        mainstr = opstr + make_params_str(params) + make_args_str(args)
        return gatedefstr, mainstr

    def add_op(self, op: Op, args: Sequence[UnitID]) -> None:
        """Dispatch ``op`` on ``args`` to the matching writer method.

        :raises QASMUnsupportedError: if the op cannot be expressed in QASM.
        """
        optype, _params = _get_optype_and_params(op)
        if optype == OpType.RangePredicate:
            assert isinstance(op, RangePredicateOp)
            self.add_range_predicate(op, cast(List[Bit], args))
        elif optype == OpType.Conditional:
            assert isinstance(op, Conditional)
            self.add_conditional(op, args)
        elif optype == OpType.Phase:
            # global phase is ignored in QASM
            pass
        elif optype == OpType.SetBits:
            assert isinstance(op, SetBitsOp)
            self.add_set_bits(op, cast(List[Bit], args))
        elif optype == OpType.CopyBits:
            assert isinstance(op, CopyBitsOp)
            self.add_copy_bits(op, cast(List[Bit], args))
        elif optype == OpType.MultiBit:
            assert isinstance(op, MultiBitOp)
            self.add_multi_bit(op, cast(List[Bit], args))
        elif optype in (OpType.ExplicitPredicate, OpType.ExplicitModifier):
            self.add_explicit_op(op, cast(List[Bit], args))
        elif optype == OpType.ClassicalExpBox:
            assert isinstance(op, ClassicalExpBox)
            self.add_classical_exp_box(op, cast(List[Bit], args))
        elif optype == OpType.WASM:
            assert isinstance(op, WASMOp)
            self.add_wasm(op, cast(List[Bit], args))
        elif optype == OpType.Measure:
            self.add_measure(args)
        elif hqs_header(self.header) and optype == OpType.ZZPhase:
            # special handling for zzphase
            assert len(op.params) == 1
            self.add_zzphase(op.params[0], args)
        elif optype == OpType.Barrier and self.header == "hqslib1_dev":
            assert isinstance(op, BarrierOp)
            self.add_data(op, args)
        elif (
            optype in _tk_to_qasm_noparams
            and _tk_to_qasm_noparams[optype] in self.include_module_gates
        ):
            self.add_gate_noparams(op, args)
        elif (
            optype in _tk_to_qasm_params
            and _tk_to_qasm_params[optype] in self.include_module_gates
        ):
            self.add_gate_params(op, args)
        elif optype in _tk_to_qasm_extra_noparams:
            gatedefstr, mainstr = self.add_extra_noparams(op, args)
            self.gatedefs += gatedefstr
            self.strings.add_string(mainstr)
        elif optype in _tk_to_qasm_extra_params:
            gatedefstr, mainstr = self.add_extra_params(op, args)
            self.gatedefs += gatedefstr
            self.strings.add_string(mainstr)
        else:
            raise QASMUnsupportedError(
                "Cannot print command of type: {}".format(op.get_name())
            )

    def finalize(self) -> str:
        """Simplify remaining predicates and return the complete QASM text.

        The output is the prefix, then the gate definitions, then the filtered
        register declarations (qregs, cregs, and the scratch register if non-empty)
        followed by the body.
        """
        # try removing unused predicates
        pred_labels = list(self.range_preds.keys())
        for label in pred_labels:
            # try replacing conditions with a predicate
            self.replace_condition(label)
            # try removing the predicate
            self.remove_unused_predicate(label)
        reg_strings = LabelledStringList()
        for reg in self.qregs.values():
            reg_strings.add_string(f"qreg {reg.name}[{reg.size}];\n")
        for bit_reg in self.cregs.values():
            reg_strings.add_string(f"creg {bit_reg.name}[{bit_reg.size}];\n")
        if self.scratch_reg.size > 0:
            reg_strings.add_string(
                f"creg {self.scratch_reg.name}[{self.scratch_reg.size}];\n"
            )
        return (
            self.prefix
            + self.gatedefs
            + _filtered_qasm_str(
                reg_strings.get_full_string() + self.strings.get_full_string()
            )
        )
def circuit_to_qasm_io(
    circ: Circuit,
    stream_out: TextIO,
    header: str = "qelib1",
    include_gate_defs: Optional[Set[str]] = None,
    maxwidth: int = 32,
) -> None:
    """Convert a Circuit to QASM and write the result to a text stream.

    Classical bits in the pytket circuit must be singly-indexed.
    Implicit qubit permutations in the Circuit are not accounted for.

    :param circ: pytket circuit to convert
    :param stream_out: text stream receiving the QASM
    :param header: qasm header (default "qelib1")
    :param include_gate_defs: optional set of gates to include
    :param maxwidth: maximum allowed width of classical registers (default 32)
    """
    qasm_text = circuit_to_qasm_str(
        circ, header=header, include_gate_defs=include_gate_defs, maxwidth=maxwidth
    )
    stream_out.write(qasm_text)