Source code for guppylang.decorator

import ast
import builtins
import inspect
from collections.abc import Callable, Sequence
from types import FrameType
from typing import Any, ParamSpec, TypeVar, cast

from guppylang_internals.ast_util import annotate_location
from guppylang_internals.compiler.core import (
    CompilerContext,
)
from guppylang_internals.decorator import (
    custom_function,
    custom_type,
    extend_type,
    hugr_op,
)
from guppylang_internals.definition.common import DefId
from guppylang_internals.definition.const import RawConstDef
from guppylang_internals.definition.custom import (
    CustomCallChecker,
    CustomInoutCallCompiler,
    RawCustomFunctionDef,
)
from guppylang_internals.definition.declaration import RawFunctionDecl
from guppylang_internals.definition.extern import RawExternDef
from guppylang_internals.definition.function import (
    RawFunctionDef,
)
from guppylang_internals.definition.overloaded import OverloadedFunctionDef
from guppylang_internals.definition.parameter import (
    ConstVarDef,
    RawConstVarDef,
    TypeVarDef,
)
from guppylang_internals.definition.pytket_circuits import (
    RawLoadPytketDef,
    RawPytketDef,
)
from guppylang_internals.definition.struct import RawStructDef
from guppylang_internals.definition.traced import RawTracedFunctionDef
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.dummy_decorator import _DummyGuppy, sphinx_running
from guppylang_internals.engine import DEF_STORE
from guppylang_internals.span import Loc, SourceMap, Span
from guppylang_internals.tys.arg import Argument
from guppylang_internals.tys.param import Parameter
from guppylang_internals.tys.subst import Inst
from guppylang_internals.tys.ty import (
    FunctionType,
    NoneType,
    NumericType,
)
from hugr import ops
from hugr import tys as ht
from hugr import val as hv
from hugr.package import ModulePointer
from typing_extensions import dataclass_transform, deprecated

from guppylang.defs import (
    GuppyDefinition,
    GuppyFunctionDefinition,
    GuppyTypeVarDefinition,
)

# Generic placeholders used in the signatures below.
S = TypeVar("S")
T = TypeVar("T")
# Bound used by `custom_guppy_decorator`, which accepts and returns a callable.
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")
# Alias for a one-argument callable mapping an `S` to a `T`.
Decorator = Callable[[S], T]

# Tuple of raw definition classes accepted by `isinstance` checks as "function
# definitions" (used by `_Guppy.overload` to validate its arguments).
AnyRawFunctionDef = (
    RawFunctionDef,
    RawCustomFunctionDef,
    RawFunctionDecl,
    RawPytketDef,
    RawLoadPytketDef,
    OverloadedFunctionDef,
)

# Names exported by `from guppylang.decorator import *`.
__all__ = ("guppy", "custom_guppy_decorator")


[docs] class _Guppy: """Class for the `@guppy` decorator.""" def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn)
[docs] def comptime(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: defn = RawTracedFunctionDef(DefId.fresh(), f.__name__, None, f) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn)
[docs] @deprecated("Use @guppylang_internal.decorator.extend_type instead.") def extend_type(self, defn: TypeDef) -> Callable[[type], type]: return extend_type(defn)
@deprecated("Use @guppylang_internal.decorator.custom_type instead.") def type( self, hugr_ty: ht.Type | Callable[[Sequence[Argument], CompilerContext], ht.Type], name: str = "", copyable: bool = True, droppable: bool = True, bound: ht.TypeBound | None = None, params: Sequence[Parameter] | None = None, ) -> Callable[[type[T]], type[T]]: return custom_type(hugr_ty, name, copyable, droppable, bound, params)
[docs] @dataclass_transform() def struct(self, cls: builtins.type[T]) -> builtins.type[T]: defn = RawStructDef(DefId.fresh(), cls.__name__, None, cls) frame = get_calling_frame() DEF_STORE.register_def(defn, frame) for val in cls.__dict__.values(): if isinstance(val, GuppyDefinition): DEF_STORE.register_impl(defn.id, val.wrapped.name, val.id) # Prior to Python 3.13, the `__firstlineno__` attribute on classes is not set. # However, we need this information to precisely look up the source for the # class later. If it's not there, we can set it from the calling frame: if not hasattr(cls, "__firstlineno__"): cls.__firstlineno__ = frame.f_lineno # type: ignore[attr-defined] # We're pretending to return the class unchanged, but in fact we return # a `GuppyDefinition` that handles the comptime logic return GuppyDefinition(defn) # type: ignore[return-value]
[docs] def type_var( self, name: str, copyable: bool = True, droppable: bool = True, ) -> TypeVar: """Creates a new type variable in a module.""" defn = TypeVarDef(DefId.fresh(), name, None, copyable, droppable) DEF_STORE.register_def(defn, get_calling_frame()) # We're pretending to return a `typing.TypeVar`, but in fact we return a special # `GuppyDefinition` that pretends to be a TypeVar at runtime return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value]
[docs] def nat_var(self, name: str) -> TypeVar: """Creates a new const nat variable in a module.""" defn = ConstVarDef(DefId.fresh(), name, None, NumericType(NumericType.Kind.Nat)) DEF_STORE.register_def(defn, get_calling_frame()) # We're pretending to return a `typing.TypeVar`, but in fact we return a special # `GuppyDefinition` that pretends to be a TypeVar at runtime return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value]
[docs] def const_var(self, name: str, ty: str) -> TypeVar: """Creates a new const type variable.""" type_ast = _parse_expr_string( ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources ) defn = RawConstVarDef(DefId.fresh(), name, None, type_ast) DEF_STORE.register_def(defn, get_calling_frame()) # We're pretending to return a `typing.TypeVar`, but in fact we return a special # `GuppyDefinition` that pretends to be a TypeVar at runtime return GuppyTypeVarDefinition(defn, TypeVar(name)) # type: ignore[return-value]
[docs] @deprecated("Use @guppylang_internal.decorator.custom_function instead.") def custom( self, compiler: CustomInoutCallCompiler | None = None, checker: CustomCallChecker | None = None, higher_order_value: bool = True, name: str = "", signature: FunctionType | None = None, ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: return custom_function(compiler, checker, higher_order_value, name, signature)
[docs] @deprecated("Use @guppylang_internal.decorator.hugr_op instead.") def hugr_op( self, op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp], checker: CustomCallChecker | None = None, higher_order_value: bool = True, name: str = "", signature: FunctionType | None = None, ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: return hugr_op(op, checker, higher_order_value, name, signature)
[docs] def declare(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: defn = RawFunctionDecl(DefId.fresh(), f.__name__, None, f) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn)
[docs] def overload( self, *funcs: Any ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: """Collects multiple function definitions into one overloaded function. Consider the following example: .. code-block:: python @guppy.declare def variant1(x: int, y: int) -> int: ... @guppy.declare def variant2(x: float) -> int: ... @guppy.overload(variant1, variant2) def combined(): ... Now, `combined` may be called with either one `float` or two `int` arguments, delegating to the implementation with the matching signature: .. code-block:: python combined(4.2) # Calls `variant1` combined(42, 43) # Calls `variant2` Note that the compiler will pick the *first* implementation with matching signature and ignore all following ones, even if they would also match. For example, if we added a third variant .. code-block:: python @guppy.declare def variant3(x: int) -> int: ... @guppy.overload(variant1, variant2, variant3) def combined_new(): ... then a call `combined_new(42)` will still select the `variant1` implementation `42` is a valid argument for `variant1` and `variant1` comes before `variant3` in the `@guppy.overload` annotation. """ funcs = list(funcs) if len(funcs) < 2: raise ValueError("Overload requires at least two functions") func_ids = [] for func in funcs: if not isinstance(func, GuppyDefinition): raise TypeError(f"Not a Guppy definition: {func}") if not isinstance(func.wrapped, AnyRawFunctionDef): raise TypeError( f"Not a Guppy function definition: {func.wrapped.description} " f"`{func.wrapped.name}`" ) func_ids.append(func.id) def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: dummy_sig = FunctionType([], NoneType()) defn = OverloadedFunctionDef( DefId.fresh(), f.__name__, None, dummy_sig, func_ids ) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) return dec
[docs] def constant(self, name: str, ty: str, value: hv.Value) -> T: # type: ignore[type-var] # Since we're returning a free type variable """Adds a constant to a module, backed by a `hugr.val.Value`.""" type_ast = _parse_expr_string( ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources ) defn = RawConstDef(DefId.fresh(), name, None, type_ast, value) DEF_STORE.register_def(defn, get_calling_frame()) # We're pretending to return a free type variable, but in fact we return # a `GuppyDefinition` that handles the comptime logic return GuppyDefinition(defn) # type: ignore[return-value]
def _extern( self, name: str, ty: str, symbol: str | None = None, constant: bool = True, ) -> T: # type: ignore[type-var] # Since we're returning a free type variable """Adds an extern symbol to a module.""" type_ast = _parse_expr_string( ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources ) defn = RawExternDef( DefId.fresh(), name, None, symbol or name, constant, type_ast ) DEF_STORE.register_def(defn, get_calling_frame()) # We're pretending to return a free type variable, but in fact we return # a `GuppyDefinition` that handles the comptime logic return GuppyDefinition(defn) # type: ignore[return-value] @deprecated( "guppy.compile(foo) is deprecated and will be removed in a future version:" " use foo.compile() instead." ) def compile(self, obj: Any) -> ModulePointer: """Compiles a Guppy definition to Hugr.""" if not isinstance(obj, GuppyDefinition): raise TypeError(f"Object is not a Guppy definition: {obj}") return ModulePointer(obj.compile(), 0)
[docs] def pytket( self, input_circuit: Any ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]: """Adds a pytket circuit function definition with explicit signature.""" err_msg = "Only pytket circuits can be passed to guppy.pytket" try: import pytket if not isinstance(input_circuit, pytket.circuit.Circuit): raise TypeError(err_msg) from None except ImportError: raise TypeError(err_msg) from None def func(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]: defn = RawPytketDef(DefId.fresh(), f.__name__, None, f, input_circuit) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn) return func
[docs] def load_pytket( self, name: str, input_circuit: Any, *, use_arrays: bool = True, ) -> GuppyFunctionDefinition[..., Any]: """Adds a pytket circuit function definition with implicit signature.""" err_msg = "Only pytket circuits can be passed to guppy.load_pytket" try: import pytket if not isinstance(input_circuit, pytket.circuit.Circuit): raise TypeError(err_msg) from None except ImportError: raise TypeError(err_msg) from None span = _find_load_call(DEF_STORE.sources) defn = RawLoadPytketDef( DefId.fresh(), name, None, span, input_circuit, use_arrays ) DEF_STORE.register_def(defn, get_calling_frame()) return GuppyFunctionDefinition(defn)
def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr: """Helper function to parse expressions that are provided as strings. Tries to infer the source location were the given string was defined by inspecting the call stack. """ try: expr_ast = ast.parse(ty_str, mode="eval").body except SyntaxError: raise SyntaxError(parse_err) from None # Try to annotate the type AST with source information. This requires us to # inspect the stack frame of the caller if caller_frame := get_calling_frame(): info = inspect.getframeinfo(caller_frame) if caller_module := inspect.getmodule(caller_frame): sources.add_file(info.filename) source_lines, _ = inspect.getsourcelines(caller_module) source = "".join(source_lines) annotate_location(expr_ast, source, info.filename, 1) # Modify the AST so that all sub-nodes span the entire line. We # can't give a better location since we don't know the column # offset of the `ty` argument for node in [expr_ast, *ast.walk(expr_ast)]: node.lineno = node.end_lineno = info.lineno node.col_offset = 0 node.end_col_offset = len(source_lines[info.lineno - 1]) - 1 return expr_ast def _find_load_call(sources: SourceMap) -> Span | None: """Helper function to find location where pytket circuit was loaded. Tries to define a source code span by inspecting the call stack. """ # Go back as first frame outside of compiler modules is 'pretty_errors_wrapped'. if load_frame := get_calling_frame(): info = inspect.getframeinfo(load_frame) filename = info.filename lineno = info.lineno sources.add_file(filename) # If we don't support python <= 3.10, this can be done better with # info.positions which gives you exact offsets. # For now over approximate and make the span cover the entire line. 
if load_module := inspect.getmodule(load_frame): source_lines, _ = inspect.getsourcelines(load_module) max_offset = len(source_lines[lineno - 1]) - 1 start = Loc(filename, lineno, 0) end = Loc(filename, lineno, max_offset) return Span(start, end) return None
def custom_guppy_decorator(f: F) -> F:
    """Marks a user-defined decorator that wraps one of the builtin `guppy` decorators.

    Example:

    .. code-block:: python

        @custom_guppy_decorator
        def my_guppy(f):
            # Some custom logic here ...
            return guppy(f)

        @my_guppy
        def main() -> int: ...

    Without `custom_guppy_decorator`, the `@my_guppy` annotation would not produce
    a valid guppy definition.

    The marking is done by renaming the function's code object to
    `__custom_guppy_decorator__`; `f` itself is modified in place and returned.
    """
    marked_code = f.__code__.replace(co_name="__custom_guppy_decorator__")
    f.__code__ = marked_code
    return f
def get_calling_frame() -> FrameType:
    """Returns the nearest stack frame that does not belong to this module.

    Frames whose code object is named `__custom_guppy_decorator__` are skipped;
    that name marks user-defined decorators wrapping the `guppy` decorator. A
    frame for which no module can be determined is returned as-is.

    Raises:
        RuntimeError: If the stack is exhausted without finding such a frame.
    """
    candidate = inspect.currentframe()
    while candidate is not None:
        in_custom_decorator = (
            candidate.f_code.co_name == "__custom_guppy_decorator__"
        )
        if not in_custom_decorator:
            owner = inspect.getmodule(candidate)
            # Either the module is unknown or it lives in a different file
            # than this one: in both cases this is the frame to report.
            if owner is None or owner.__file__ != __file__:
                return candidate
        candidate = candidate.f_back
    raise RuntimeError("Couldn't obtain stack frame for definition")
# Module-level decorator instance exported via `__all__`. When `sphinx_running()`
# is true, a `_DummyGuppy` instance (typed as `_Guppy` through `cast`) is exposed
# instead of a real `_Guppy` -- presumably so documentation builds do not
# register definitions; confirm against `guppylang_internals.dummy_decorator`.
guppy = cast(_Guppy, _DummyGuppy()) if sphinx_running() else _Guppy()