import ast
import builtins
import inspect
from collections.abc import Callable, Sequence
from types import FrameType
from typing import Any, ParamSpec, TypeVar, cast
from guppylang_internals.ast_util import annotate_location
from guppylang_internals.compiler.core import (
CompilerContext,
)
from guppylang_internals.decorator import (
custom_function,
custom_type,
hugr_op,
)
from guppylang_internals.definition.common import DefId
from guppylang_internals.definition.const import RawConstDef
from guppylang_internals.definition.custom import (
CustomCallChecker,
CustomInoutCallCompiler,
RawCustomFunctionDef,
)
from guppylang_internals.definition.declaration import RawFunctionDecl
from guppylang_internals.definition.extern import RawExternDef
from guppylang_internals.definition.function import (
RawFunctionDef,
)
from guppylang_internals.definition.overloaded import OverloadedFunctionDef
from guppylang_internals.definition.parameter import (
ConstVarDef,
RawConstVarDef,
TypeVarDef,
)
from guppylang_internals.definition.pytket_circuits import (
RawLoadPytketDef,
RawPytketDef,
)
from guppylang_internals.definition.struct import RawStructDef
from guppylang_internals.definition.traced import RawTracedFunctionDef
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.dummy_decorator import _DummyGuppy, sphinx_running
from guppylang_internals.engine import DEF_STORE
from guppylang_internals.span import Loc, SourceMap, Span
from guppylang_internals.tys.arg import Argument
from guppylang_internals.tys.param import Parameter
from guppylang_internals.tys.subst import Inst
from guppylang_internals.tys.ty import (
FunctionType,
NoneType,
NumericType,
)
from hugr import ops
from hugr import tys as ht
from hugr import val as hv
from hugr.package import ModulePointer
from typing_extensions import dataclass_transform, deprecated
from guppylang.defs import (
GuppyDefinition,
GuppyFunctionDefinition,
GuppyTypeVarDefinition,
)
# Public API of this module.
__all__ = ("guppy", "custom_guppy_decorator")

# Generic variables shared by the decorator signatures in this module.
P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")
# Bound to callables so that `custom_guppy_decorator` preserves the exact type of the
# decorator it marks.
F = TypeVar("F", bound=Callable[..., Any])

# Shape of a decorator: turns an `S` into a `T`.
Decorator = Callable[[S], T]

# Raw definition kinds that `guppy.overload` accepts as overload variants (used as
# the second argument of an `isinstance` check).
AnyRawFunctionDef = (
    RawFunctionDef,
    RawCustomFunctionDef,
    RawFunctionDecl,
    RawPytketDef,
    RawLoadPytketDef,
    OverloadedFunctionDef,
)
class _Guppy:
    """Class for the `@guppy` decorator.

    The module-level `guppy` object is an instance of this class. Calling it on a
    function registers a Guppy function; the methods below provide the remaining
    decorators and helpers (`guppy.struct`, `guppy.declare`, `guppy.comptime`, ...).
    Every definition is registered with `DEF_STORE` together with the stack frame of
    the user code that created it (see `get_calling_frame`).
    """

    def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        """Registers a Python function as a Guppy function definition.

        .. code-block:: python

            from guppylang import guppy

            @guppy
            def add(x: int, y: int) -> int:
                return x + y
        """
        defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f)
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)

    def comptime(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        """Registers a function to be executed at compile-time during Guppy
        compilation, enabling the use of arbitrary Python features as long as they
        don't depend on runtime values.

        .. code-block:: python

            from guppylang import guppy
            from guppylang.std.builtins import array

            @guppy.comptime
            def print_arrays(arr1: array[str, 10], arr2: array[str, 10]) -> None:
                for s1, s2 in zip(arr1, arr2):
                    print(f"({s1}, {s2})")
        """
        defn = RawTracedFunctionDef(DefId.fresh(), f.__name__, None, f)
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)

    @deprecated("Use @guppylang_internals.decorator.extend_type instead.")
    def extend_type(self, defn: TypeDef) -> Callable[[type], type]:
        """Deprecated: forwards to `guppylang_internals.decorator.extend_type`."""
        # Set `return_class=True` to match the old behaviour until this deprecated
        # method is removed
        import guppylang_internals.decorator

        return guppylang_internals.decorator.extend_type(defn, return_class=True)

    # NOTE: Defining this method shadows the builtin `type` for the rest of the class
    # body. This is why the `struct` annotations below use `builtins.type`
    # explicitly. Keep `extend_type` above this method for the same reason.
    @deprecated("Use @guppylang_internals.decorator.custom_type instead.")
    def type(
        self,
        hugr_ty: ht.Type | Callable[[Sequence[Argument], CompilerContext], ht.Type],
        name: str = "",
        copyable: bool = True,
        droppable: bool = True,
        bound: ht.TypeBound | None = None,
        params: Sequence[Parameter] | None = None,
    ) -> Callable[[type[T]], type[T]]:
        """Deprecated: forwards to `guppylang_internals.decorator.custom_type`."""
        return custom_type(hugr_ty, name, copyable, droppable, bound, params)

    @dataclass_transform()
    def struct(self, cls: builtins.type[T]) -> builtins.type[T]:
        """Registers a class as a Guppy struct.

        Any Guppy definitions found in the class body are registered as methods
        (impls) of the struct.

        .. code-block:: python

            from guppylang import guppy

            @guppy.struct
            class MyStruct:
                field1: int
                field2: int

                @guppy
                def add_fields(self: "MyStruct") -> int:
                    return self.field1 + self.field2
        """
        defn = RawStructDef(DefId.fresh(), cls.__name__, None, cls)
        frame = get_calling_frame()
        DEF_STORE.register_def(defn, frame)
        for val in cls.__dict__.values():
            if isinstance(val, GuppyDefinition):
                DEF_STORE.register_impl(defn.id, val.wrapped.name, val.id)
        # Prior to Python 3.13, the `__firstlineno__` attribute on classes is not set.
        # However, we need this information to precisely look up the source for the
        # class later. If it's not there, we can set it from the calling frame:
        if not hasattr(cls, "__firstlineno__"):
            cls.__firstlineno__ = frame.f_lineno  # type: ignore[attr-defined]
        # We're pretending to return the class unchanged, but in fact we return
        # a `GuppyDefinition` that handles the comptime logic
        return GuppyDefinition(defn)  # type: ignore[return-value]

    def type_var(
        self,
        name: str,
        copyable: bool = True,
        droppable: bool = True,
    ) -> TypeVar:
        """Creates a new type variable.

        .. code-block:: python

            from guppylang import guppy

            T = guppy.type_var("T")

            @guppy
            def identity(x: T) -> T:
                return x
        """
        defn = TypeVarDef(DefId.fresh(), name, None, copyable, droppable)
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a `typing.TypeVar`, but in fact we return a special
        # `GuppyDefinition` that pretends to be a TypeVar at runtime
        return GuppyTypeVarDefinition(defn, TypeVar(name))  # type: ignore[return-value]

    def nat_var(self, name: str) -> TypeVar:
        """Creates a new nat variable, i.e. a const variable of Guppy type `nat`."""
        defn = ConstVarDef(DefId.fresh(), name, None, NumericType(NumericType.Kind.Nat))
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a `typing.TypeVar`, but in fact we return a special
        # `GuppyDefinition` that pretends to be a TypeVar at runtime
        return GuppyTypeVarDefinition(defn, TypeVar(name))  # type: ignore[return-value]

    def const_var(self, name: str, ty: str) -> TypeVar:
        """Creates a new const type variable whose type is given by the Guppy type
        string `ty`.

        Raises:
            SyntaxError: If `ty` cannot be parsed as a Python expression.
        """
        type_ast = _parse_expr_string(
            ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources
        )
        defn = RawConstVarDef(DefId.fresh(), name, None, type_ast)
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a `typing.TypeVar`, but in fact we return a special
        # `GuppyDefinition` that pretends to be a TypeVar at runtime
        return GuppyTypeVarDefinition(defn, TypeVar(name))  # type: ignore[return-value]

    @deprecated("Use @guppylang_internals.decorator.custom_function instead.")
    def custom(
        self,
        compiler: CustomInoutCallCompiler | None = None,
        checker: CustomCallChecker | None = None,
        higher_order_value: bool = True,
        name: str = "",
        signature: FunctionType | None = None,
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Deprecated: forwards to `guppylang_internals.decorator.custom_function`."""
        return custom_function(compiler, checker, higher_order_value, name, signature)

    @deprecated("Use @guppylang_internals.decorator.hugr_op instead.")
    def hugr_op(
        self,
        op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp],
        checker: CustomCallChecker | None = None,
        higher_order_value: bool = True,
        name: str = "",
        signature: FunctionType | None = None,
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Deprecated: forwards to `guppylang_internals.decorator.hugr_op`."""
        # Inside a method body, `hugr_op` resolves to the module-level function
        # imported from `guppylang_internals.decorator`, not to this method.
        return hugr_op(op, checker, higher_order_value, name, signature)

    def declare(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        """Declares a Guppy function without defining it.

        Only the signature of `f` is used; the declaration is registered as a
        `RawFunctionDecl`.
        """
        defn = RawFunctionDecl(DefId.fresh(), f.__name__, None, f)
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)

    def overload(
        self, *funcs: Any
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Collects multiple function definitions into one overloaded function.

        Consider the following example:

        .. code-block:: python

            @guppy.declare
            def variant1(x: int, y: int) -> int: ...

            @guppy.declare
            def variant2(x: float) -> int: ...

            @guppy.overload(variant1, variant2)
            def combined(): ...

        Now, `combined` may be called with either one `float` or two `int` arguments,
        delegating to the implementation with the matching signature:

        .. code-block:: python

            combined(4.2)  # Calls `variant2`
            combined(42, 43)  # Calls `variant1`

        Note that the compiler will pick the *first* implementation with matching
        signature and ignore all following ones, even if they would also match. For
        example, if we added a third variant

        .. code-block:: python

            @guppy.declare
            def variant3(x: int) -> int: ...

            @guppy.overload(variant1, variant2, variant3)
            def combined_new(): ...

        then a call `combined_new(42)` will still select the `variant2` implementation,
        since `42` is also a valid argument for `variant2` and `variant2` comes before
        `variant3` in the `@guppy.overload` annotation.

        Raises:
            ValueError: If fewer than two functions are given.
            TypeError: If any argument is not a Guppy function definition.
        """
        if len(funcs) < 2:
            raise ValueError("Overload requires at least two functions")
        func_ids = []
        for func in funcs:
            if not isinstance(func, GuppyDefinition):
                raise TypeError(f"Not a Guppy definition: {func}")
            if not isinstance(func.wrapped, AnyRawFunctionDef):
                raise TypeError(
                    f"Not a Guppy function definition: {func.wrapped.description} "
                    f"`{func.wrapped.name}`"
                )
            func_ids.append(func.id)

        def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
            # `dummy_sig` is only a placeholder; the variants themselves are passed
            # to the definition via `func_ids`.
            dummy_sig = FunctionType([], NoneType())
            defn = OverloadedFunctionDef(
                DefId.fresh(), f.__name__, None, dummy_sig, func_ids
            )
            DEF_STORE.register_def(defn, get_calling_frame())
            return GuppyFunctionDefinition(defn)

        return dec

    def constant(self, name: str, ty: str, value: hv.Value) -> T:  # type: ignore[type-var] # Since we're returning a free type variable
        """Adds a constant to a module, backed by a `hugr.val.Value`.

        Args:
            name: Name of the constant.
            ty: Guppy type of the constant, given as a string.
            value: The Hugr value backing the constant.

        Raises:
            SyntaxError: If `ty` cannot be parsed as a Python expression.
        """
        type_ast = _parse_expr_string(
            ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources
        )
        defn = RawConstDef(DefId.fresh(), name, None, type_ast, value)
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a free type variable, but in fact we return
        # a `GuppyDefinition` that handles the comptime logic
        return GuppyDefinition(defn)  # type: ignore[return-value]

    def _extern(
        self,
        name: str,
        ty: str,
        symbol: str | None = None,
        constant: bool = True,
    ) -> T:  # type: ignore[type-var] # Since we're returning a free type variable
        """Adds an extern symbol to a module.

        Args:
            name: Name of the definition.
            ty: Guppy type of the symbol, given as a string.
            symbol: Name of the external symbol; falls back to `name` if not given
                (or empty).
            constant: Forwarded to `RawExternDef`.

        Raises:
            SyntaxError: If `ty` cannot be parsed as a Python expression.
        """
        type_ast = _parse_expr_string(
            ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources
        )
        defn = RawExternDef(
            DefId.fresh(), name, None, symbol or name, constant, type_ast
        )
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a free type variable, but in fact we return
        # a `GuppyDefinition` that handles the comptime logic
        return GuppyDefinition(defn)  # type: ignore[return-value]

    @deprecated(
        "guppy.compile(foo) is deprecated and will be removed in a future version:"
        " use foo.compile() instead."
    )
    def compile(self, obj: Any) -> ModulePointer:
        """Compiles a Guppy definition to Hugr.

        Wraps the result of `obj.compile()` in a `ModulePointer` to module index 0.

        Raises:
            TypeError: If `obj` is not a Guppy definition.
        """
        if not isinstance(obj, GuppyDefinition):
            raise TypeError(f"Object is not a Guppy definition: {obj}")
        return ModulePointer(obj.compile(), 0)

    def pytket(
        self, input_circuit: Any
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Backs a function declaration by the given pytket circuit. The declaration
        signature needs to match the circuit definition in terms of number of qubit
        inputs and measurement outputs.

        There is no linearity checking inside pytket circuit functions. Any
        measurements inside the circuit get returned as bools, but the qubits do not
        get consumed and the pytket circuit function does not require ownership. You
        should either make sure you discard all qubits you know are measured during
        the circuit, or avoid measurements in the circuit and measure in Guppy
        afterwards.

        Note this decorator doesn't support passing inputs as arrays (use
        `load_pytket` instead).

        .. code-block:: python

            from pytket import Circuit
            from guppylang import guppy

            circ = Circuit(1)
            circ.H(0)
            circ.measure_all()

            @guppy.pytket(circ)
            def guppy_circ(q: qubit) -> bool: ...

            @guppy
            def foo(q: qubit) -> bool:
                return guppy_circ(q)

        Raises:
            TypeError: If `input_circuit` is not a pytket circuit, or pytket is not
                installed.
        """
        self._check_pytket_circuit(
            input_circuit, "Only pytket circuits can be passed to guppy.pytket"
        )

        def func(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
            defn = RawPytketDef(DefId.fresh(), f.__name__, None, f, input_circuit)
            DEF_STORE.register_def(defn, get_calling_frame())
            return GuppyFunctionDefinition(defn)

        return func

    def load_pytket(
        self,
        name: str,
        input_circuit: Any,
        *,
        use_arrays: bool = True,
    ) -> GuppyFunctionDefinition[..., Any]:
        """Load a pytket :py:class:`~pytket.circuit.Circuit` as a Guppy function. By
        default, each qubit register is represented by an array input (and each bit
        register as an array output), with the order being determined
        lexicographically. The default registers are 'q' and 'c' respectively. You can
        disable array usage and pass individual qubits by passing `use_arrays=False`.

        .. code-block:: python

            from pytket import Circuit
            from guppylang import guppy

            circ = Circuit(2)
            reg = circ.add_q_register("extra_reg", 3)
            circ.measure_register(reg, "extra_bits")

            guppy_circ = guppy.load_pytket("guppy_circ", circ)

            @guppy
            def foo(default_reg: array[qubit, 2],
                    extra_reg: array[qubit, 3]) -> array[bool, 3]:
                # Note that the default_reg name is 'q' so it has to come after 'e...'
                # lexicographically.
                return guppy_circ(extra_reg, default_reg)

        Any symbolic parameters in the circuit need to be passed as a lexicographically
        sorted array (if arrays are enabled, else individually in that order) as values
        of type `angle`.

        The function name is determined by the variable you bind the `load_pytket`
        method call to; however, the `name` string passed to the method should match
        this variable for error reporting purposes.

        There is no linearity checking inside pytket circuit functions. Any
        measurements inside the circuit get returned as bools, but the qubits do not
        get consumed and the pytket circuit function does not require ownership. You
        should either make sure you discard all qubits you know are measured during
        the circuit, or avoid measurements in the circuit and measure in Guppy
        afterwards.

        Raises:
            TypeError: If `input_circuit` is not a pytket circuit, or pytket is not
                installed.
        """
        self._check_pytket_circuit(
            input_circuit, "Only pytket circuits can be passed to guppy.load_pytket"
        )
        span = _find_load_call(DEF_STORE.sources)
        defn = RawLoadPytketDef(
            DefId.fresh(), name, None, span, input_circuit, use_arrays
        )
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)

    @staticmethod
    def _check_pytket_circuit(input_circuit: Any, err_msg: str) -> None:
        """Raises `TypeError(err_msg)` unless `input_circuit` is a pytket `Circuit`.

        If pytket cannot be imported, the same `TypeError` is raised, since no object
        can be a pytket circuit in that case.
        """
        try:
            import pytket
        except ImportError:
            raise TypeError(err_msg) from None
        if not isinstance(input_circuit, pytket.circuit.Circuit):
            raise TypeError(err_msg) from None
def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr:
    """Helper function to parse expressions that are provided as strings.

    Tries to infer the source location where the given string was defined by
    inspecting the call stack.

    Args:
        ty_str: The expression source to parse.
        parse_err: Message of the `SyntaxError` raised if `ty_str` doesn't parse.
        sources: Source map with which the caller's file is registered.

    Returns:
        The parsed expression. If the calling module can be determined, every node
        is annotated so that it spans the whole line of the call.

    Raises:
        SyntaxError: If `ty_str` is not a valid Python expression.
    """
    try:
        expr_ast = ast.parse(ty_str, mode="eval").body
    except SyntaxError:
        raise SyntaxError(parse_err) from None

    # Try to annotate the type AST with source information. This requires us to
    # inspect the stack frame of the caller
    if caller_frame := get_calling_frame():
        info = inspect.getframeinfo(caller_frame)
        if caller_module := inspect.getmodule(caller_frame):
            sources.add_file(info.filename)
            source_lines, _ = inspect.getsourcelines(caller_module)
            source = "".join(source_lines)
            annotate_location(expr_ast, source, info.filename, 1)
            # Modify the AST so that all sub-nodes span the entire line. We can't
            # give a better location since we don't know the column offset of the
            # `ty` argument. Strip the line terminator explicitly instead of
            # assuming one is present: the last line of a file may lack it.
            line_end = len(source_lines[info.lineno - 1].rstrip("\r\n"))
            # `ast.walk` yields `expr_ast` itself first, so it is covered here.
            for node in ast.walk(expr_ast):
                node.lineno = node.end_lineno = info.lineno
                node.col_offset = 0
                node.end_col_offset = line_end
    return expr_ast
def _find_load_call(sources: SourceMap) -> Span | None:
    """Helper function to find the location where a pytket circuit was loaded.

    Tries to define a source code span by inspecting the call stack. The caller's
    file is registered with `sources`, and the returned span covers the entire line
    of the call.

    Returns:
        The span of the calling line, or `None` if the calling module cannot be
        determined.
    """
    # `get_calling_frame` skips frames in this module (and in decorators marked with
    # `custom_guppy_decorator`), so this is the user's `load_pytket` call site.
    if load_frame := get_calling_frame():
        info = inspect.getframeinfo(load_frame)
        filename = info.filename
        lineno = info.lineno
        sources.add_file(filename)
        # Once Python <= 3.10 is no longer supported, this can be done better with
        # `info.positions`, which gives exact offsets. For now, over-approximate and
        # make the span cover the entire line.
        if load_module := inspect.getmodule(load_frame):
            source_lines, _ = inspect.getsourcelines(load_module)
            # Strip the line terminator explicitly instead of assuming one is
            # present: the last line of a file may lack it.
            max_offset = len(source_lines[lineno - 1].rstrip("\r\n"))

            start = Loc(filename, lineno, 0)
            end = Loc(filename, lineno, max_offset)
            return Span(start, end)
    return None
def custom_guppy_decorator(f: F) -> F:
    """Mark a user-defined decorator that wraps one of the builtin `guppy` decorators.

    Example:

    .. code-block:: python

        @custom_guppy_decorator
        def my_guppy(f):
            # Some custom logic here ...
            return guppy(f)

        @my_guppy
        def main() -> int: ...

    The marker renames the function's code object to `__custom_guppy_decorator__`.
    `get_calling_frame` skips frames running such code. As a result, definitions
    created through `@my_guppy` are attributed to the user's call site rather than to
    the body of `my_guppy`. If the `custom_guppy_decorator` were missing, then the
    `@my_guppy` annotation would not produce a valid guppy definition.

    The function object is modified in place and returned.
    """
    marked_code = f.__code__.replace(co_name="__custom_guppy_decorator__")
    f.__code__ = marked_code
    return f
def get_calling_frame() -> FrameType:
    """Return the innermost stack frame that does not belong to this module.

    The search starts at this function's own frame and walks outwards. It skips:

    * frames whose code is named ``"__custom_guppy_decorator__"``, i.e. bodies of
      user decorators marked with `custom_guppy_decorator`, and
    * frames whose module file is this module's file.

    A frame whose module cannot be determined is returned as-is.

    Raises:
        RuntimeError: If the stack is exhausted without finding such a frame.
    """
    current = inspect.currentframe()
    while current is not None:
        in_custom_wrapper = current.f_code.co_name == "__custom_guppy_decorator__"
        if not in_custom_wrapper:
            owner = inspect.getmodule(current)
            if owner is None or owner.__file__ != __file__:
                return current
        current = current.f_back
    raise RuntimeError("Couldn't obtain stack frame for definition")
guppy = cast(_Guppy, _DummyGuppy()) if sphinx_running() else _Guppy()