import ast
import builtins
import inspect
from collections.abc import Callable, Sequence
from types import FrameType
from typing import Any, ParamSpec, TypeVar, cast
from guppylang_internals.ast_util import annotate_location
from guppylang_internals.compiler.core import (
CompilerContext,
)
from guppylang_internals.decorator import (
custom_function,
custom_type,
extend_type,
hugr_op,
)
from guppylang_internals.definition.common import DefId
from guppylang_internals.definition.const import RawConstDef
from guppylang_internals.definition.custom import (
CustomCallChecker,
CustomInoutCallCompiler,
RawCustomFunctionDef,
)
from guppylang_internals.definition.declaration import RawFunctionDecl
from guppylang_internals.definition.extern import RawExternDef
from guppylang_internals.definition.function import (
RawFunctionDef,
)
from guppylang_internals.definition.overloaded import OverloadedFunctionDef
from guppylang_internals.definition.parameter import (
ConstVarDef,
RawConstVarDef,
TypeVarDef,
)
from guppylang_internals.definition.pytket_circuits import (
RawLoadPytketDef,
RawPytketDef,
)
from guppylang_internals.definition.struct import RawStructDef
from guppylang_internals.definition.traced import RawTracedFunctionDef
from guppylang_internals.definition.ty import TypeDef
from guppylang_internals.dummy_decorator import _DummyGuppy, sphinx_running
from guppylang_internals.engine import DEF_STORE
from guppylang_internals.span import Loc, SourceMap, Span
from guppylang_internals.tys.arg import Argument
from guppylang_internals.tys.param import Parameter
from guppylang_internals.tys.subst import Inst
from guppylang_internals.tys.ty import (
FunctionType,
NoneType,
NumericType,
)
from hugr import ops
from hugr import tys as ht
from hugr import val as hv
from hugr.package import ModulePointer
from typing_extensions import dataclass_transform, deprecated
from guppylang.defs import (
GuppyDefinition,
GuppyFunctionDefinition,
GuppyTypeVarDefinition,
)
# Generic type variables used in the signatures of this module.
S = TypeVar("S")
T = TypeVar("T")
# Type variable bound to arbitrary callables; used by `custom_guppy_decorator`
# so that the decorated function keeps its exact type.
F = TypeVar("F", bound=Callable[..., Any])
# Parameter specification capturing the parameters of decorated functions.
P = ParamSpec("P")
# Type of a decorator that maps an `S` to a `T`.
Decorator = Callable[[S], T]
# Raw definition classes that `_Guppy.overload` accepts as overload variants.
# The tuple is passed directly to `isinstance`.
AnyRawFunctionDef = (
RawFunctionDef,
RawCustomFunctionDef,
RawFunctionDecl,
RawPytketDef,
RawLoadPytketDef,
OverloadedFunctionDef,
)
# Names exported by `from <this module> import *`.
__all__ = ("guppy", "custom_guppy_decorator")
class _Guppy:
    """Class for the `@guppy` decorator.

    Calling an instance registers a function definition. The methods register
    other kinds of definitions (structs, type variables, constants, ...). Every
    definition is stored in `DEF_STORE` together with the stack frame returned by
    `get_calling_frame()`.
    """

    def __call__(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        """Registers `f` as a Guppy function definition.

        A `RawFunctionDef` with a fresh id is registered in `DEF_STORE` and
        returned wrapped in a `GuppyFunctionDefinition`.
        """
        defn = RawFunctionDef(DefId.fresh(), f.__name__, None, f)
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)

    def comptime(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        """Registers `f` as a traced function definition.

        Works like `__call__` but creates a `RawTracedFunctionDef`.
        """
        defn = RawTracedFunctionDef(DefId.fresh(), f.__name__, None, f)
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)

    @deprecated("Use @guppylang_internals.decorator.extend_type instead.")
    def extend_type(self, defn: TypeDef) -> Callable[[type], type]:
        """Deprecated: forwards to `guppylang_internals.decorator.extend_type`."""
        return extend_type(defn)

    @deprecated("Use @guppylang_internals.decorator.custom_type instead.")
    def type(
        self,
        hugr_ty: ht.Type | Callable[[Sequence[Argument], CompilerContext], ht.Type],
        name: str = "",
        copyable: bool = True,
        droppable: bool = True,
        bound: ht.TypeBound | None = None,
        params: Sequence[Parameter] | None = None,
    ) -> Callable[[type[T]], type[T]]:
        """Deprecated: forwards to `guppylang_internals.decorator.custom_type`."""
        return custom_type(hugr_ty, name, copyable, droppable, bound, params)

    @dataclass_transform()
    def struct(self, cls: builtins.type[T]) -> builtins.type[T]:
        """Registers the class `cls` as a Guppy struct definition.

        Every `GuppyDefinition` found in the class `__dict__` is additionally
        registered as an implementation attached to the struct's definition id,
        under the name of the wrapped definition.
        """
        defn = RawStructDef(DefId.fresh(), cls.__name__, None, cls)
        frame = get_calling_frame()
        DEF_STORE.register_def(defn, frame)
        for val in cls.__dict__.values():
            if isinstance(val, GuppyDefinition):
                DEF_STORE.register_impl(defn.id, val.wrapped.name, val.id)
        # Prior to Python 3.13, the `__firstlineno__` attribute on classes is not set.
        # However, we need this information to precisely look up the source for the
        # class later. If it's not there, we can set it from the calling frame:
        if not hasattr(cls, "__firstlineno__"):
            cls.__firstlineno__ = frame.f_lineno  # type: ignore[attr-defined]
        # We're pretending to return the class unchanged, but in fact we return
        # a `GuppyDefinition` that handles the comptime logic
        return GuppyDefinition(defn)  # type: ignore[return-value]

    def type_var(
        self,
        name: str,
        copyable: bool = True,
        droppable: bool = True,
    ) -> TypeVar:
        """Creates a new type variable in a module."""
        defn = TypeVarDef(DefId.fresh(), name, None, copyable, droppable)
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a `typing.TypeVar`, but in fact we return a special
        # `GuppyDefinition` that pretends to be a TypeVar at runtime
        return GuppyTypeVarDefinition(defn, TypeVar(name))  # type: ignore[return-value]

    def nat_var(self, name: str) -> TypeVar:
        """Creates a new const nat variable in a module."""
        defn = ConstVarDef(DefId.fresh(), name, None, NumericType(NumericType.Kind.Nat))
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a `typing.TypeVar`, but in fact we return a special
        # `GuppyDefinition` that pretends to be a TypeVar at runtime
        return GuppyTypeVarDefinition(defn, TypeVar(name))  # type: ignore[return-value]

    def const_var(self, name: str, ty: str) -> TypeVar:
        """Creates a new const type variable.

        The type is given as the string `ty`, which must parse as a Python
        expression; otherwise a `SyntaxError` is raised.
        """
        type_ast = _parse_expr_string(
            ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources
        )
        defn = RawConstVarDef(DefId.fresh(), name, None, type_ast)
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a `typing.TypeVar`, but in fact we return a special
        # `GuppyDefinition` that pretends to be a TypeVar at runtime
        return GuppyTypeVarDefinition(defn, TypeVar(name))  # type: ignore[return-value]

    @deprecated("Use @guppylang_internals.decorator.custom_function instead.")
    def custom(
        self,
        compiler: CustomInoutCallCompiler | None = None,
        checker: CustomCallChecker | None = None,
        higher_order_value: bool = True,
        name: str = "",
        signature: FunctionType | None = None,
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Deprecated: forwards to `guppylang_internals.decorator.custom_function`."""
        return custom_function(compiler, checker, higher_order_value, name, signature)

    @deprecated("Use @guppylang_internals.decorator.hugr_op instead.")
    def hugr_op(
        self,
        op: Callable[[ht.FunctionType, Inst, CompilerContext], ops.DataflowOp],
        checker: CustomCallChecker | None = None,
        higher_order_value: bool = True,
        name: str = "",
        signature: FunctionType | None = None,
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Deprecated: forwards to `guppylang_internals.decorator.hugr_op`."""
        return hugr_op(op, checker, higher_order_value, name, signature)

    def declare(self, f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
        """Registers `f` as a Guppy function declaration.

        Works like `__call__` but creates a `RawFunctionDecl`.
        """
        defn = RawFunctionDecl(DefId.fresh(), f.__name__, None, f)
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)

    def overload(
        self, *funcs: Any
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Collects multiple function definitions into one overloaded function.

        Consider the following example:

        .. code-block:: python

            @guppy.declare
            def variant1(x: int, y: int) -> int: ...

            @guppy.declare
            def variant2(x: float) -> int: ...

            @guppy.overload(variant1, variant2)
            def combined(): ...

        Now, `combined` may be called with either one `float` or two `int` arguments,
        delegating to the implementation with the matching signature:

        .. code-block:: python

            combined(4.2)  # Calls `variant2`
            combined(42, 43)  # Calls `variant1`

        Note that the compiler will pick the *first* implementation with matching
        signature and ignore all following ones, even if they would also match. For
        example, if we added a third variant

        .. code-block:: python

            @guppy.declare
            def variant3(x: int) -> int: ...

            @guppy.overload(variant1, variant2, variant3)
            def combined_new(): ...

        then a call `combined_new(42)` will still select the `variant2`
        implementation, provided `42` is accepted as an argument for `variant2`,
        because `variant2` comes before `variant3` in the `@guppy.overload`
        annotation.

        Raises:
            ValueError: If fewer than two functions are given.
            TypeError: If one of the arguments is not a Guppy function definition.
        """
        if len(funcs) < 2:
            raise ValueError("Overload requires at least two functions")
        func_ids = []
        for func in funcs:
            if not isinstance(func, GuppyDefinition):
                raise TypeError(f"Not a Guppy definition: {func}")
            if not isinstance(func.wrapped, AnyRawFunctionDef):
                raise TypeError(
                    f"Not a Guppy function definition: {func.wrapped.description} "
                    f"`{func.wrapped.name}`"
                )
            func_ids.append(func.id)

        def dec(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
            # Only the name of `f` is used; the definition is created with a
            # placeholder signature.
            dummy_sig = FunctionType([], NoneType())
            defn = OverloadedFunctionDef(
                DefId.fresh(), f.__name__, None, dummy_sig, func_ids
            )
            DEF_STORE.register_def(defn, get_calling_frame())
            return GuppyFunctionDefinition(defn)

        return dec

    def constant(self, name: str, ty: str, value: hv.Value) -> T:  # type: ignore[type-var]  # Since we're returning a free type variable
        """Adds a constant to a module, backed by a `hugr.val.Value`."""
        type_ast = _parse_expr_string(
            ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources
        )
        defn = RawConstDef(DefId.fresh(), name, None, type_ast, value)
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a free type variable, but in fact we return
        # a `GuppyDefinition` that handles the comptime logic
        return GuppyDefinition(defn)  # type: ignore[return-value]

    def _extern(
        self,
        name: str,
        ty: str,
        symbol: str | None = None,
        constant: bool = True,
    ) -> T:  # type: ignore[type-var]  # Since we're returning a free type variable
        """Adds an extern symbol to a module.

        If `symbol` is not given (or empty), `name` is used as the symbol.
        """
        type_ast = _parse_expr_string(
            ty, f"Not a valid Guppy type: `{ty}`", DEF_STORE.sources
        )
        defn = RawExternDef(
            DefId.fresh(), name, None, symbol or name, constant, type_ast
        )
        DEF_STORE.register_def(defn, get_calling_frame())
        # We're pretending to return a free type variable, but in fact we return
        # a `GuppyDefinition` that handles the comptime logic
        return GuppyDefinition(defn)  # type: ignore[return-value]

    @deprecated(
        "guppy.compile(foo) is deprecated and will be removed in a future version:"
        " use foo.compile() instead."
    )
    def compile(self, obj: Any) -> ModulePointer:
        """Compiles a Guppy definition to Hugr.

        Raises:
            TypeError: If `obj` is not a `GuppyDefinition`.
        """
        if not isinstance(obj, GuppyDefinition):
            raise TypeError(f"Object is not a Guppy definition: {obj}")
        return ModulePointer(obj.compile(), 0)

    @staticmethod
    def _require_pytket_circuit(input_circuit: Any, err_msg: str) -> None:
        """Raises `TypeError(err_msg)` unless `input_circuit` is a pytket circuit.

        The same error is raised when `pytket` cannot be imported.
        """
        try:
            import pytket
        except ImportError:
            raise TypeError(err_msg) from None
        if not isinstance(input_circuit, pytket.circuit.Circuit):
            raise TypeError(err_msg)

    def pytket(
        self, input_circuit: Any
    ) -> Callable[[Callable[P, T]], GuppyFunctionDefinition[P, T]]:
        """Adds a pytket circuit function definition with explicit signature.

        Raises:
            TypeError: If `input_circuit` is not a pytket circuit or `pytket` is
                not importable.
        """
        self._require_pytket_circuit(
            input_circuit, "Only pytket circuits can be passed to guppy.pytket"
        )

        def func(f: Callable[P, T]) -> GuppyFunctionDefinition[P, T]:
            defn = RawPytketDef(DefId.fresh(), f.__name__, None, f, input_circuit)
            DEF_STORE.register_def(defn, get_calling_frame())
            return GuppyFunctionDefinition(defn)

        return func

    def load_pytket(
        self,
        name: str,
        input_circuit: Any,
        *,
        use_arrays: bool = True,
    ) -> GuppyFunctionDefinition[..., Any]:
        """Adds a pytket circuit function definition with implicit signature.

        Raises:
            TypeError: If `input_circuit` is not a pytket circuit or `pytket` is
                not importable.
        """
        self._require_pytket_circuit(
            input_circuit, "Only pytket circuits can be passed to guppy.load_pytket"
        )
        # Source span of the line containing the call (may be `None`)
        span = _find_load_call(DEF_STORE.sources)
        defn = RawLoadPytketDef(
            DefId.fresh(), name, None, span, input_circuit, use_arrays
        )
        DEF_STORE.register_def(defn, get_calling_frame())
        return GuppyFunctionDefinition(defn)
def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.expr:
    """Helper function to parse expressions that are provided as strings.

    Tries to infer the source location where the given string was defined by
    inspecting the call stack.

    Args:
        ty_str: The expression source to parse.
        parse_err: Message of the `SyntaxError` raised if parsing fails.
        sources: Source map in which the caller's file is registered.

    Returns:
        The parsed expression. If the caller's module could be determined, every
        node is annotated with a location spanning the caller's entire line.

    Raises:
        SyntaxError: If `ty_str` is not a valid Python expression.
    """
    try:
        expr_ast = ast.parse(ty_str, mode="eval").body
    except SyntaxError:
        raise SyntaxError(parse_err) from None

    # Try to annotate the type AST with source information. This requires us to
    # inspect the stack frame of the caller
    if caller_frame := get_calling_frame():
        info = inspect.getframeinfo(caller_frame)
        if caller_module := inspect.getmodule(caller_frame):
            sources.add_file(info.filename)
            source_lines, _ = inspect.getsourcelines(caller_module)
            source = "".join(source_lines)
            annotate_location(expr_ast, source, info.filename, 1)
            # Modify the AST so that all sub-nodes span the entire line. We
            # can't give a better location since we don't know the column
            # offset of the `ty` argument
            end_col_offset = len(source_lines[info.lineno - 1]) - 1
            # `ast.walk` already yields `expr_ast` itself, so the root needs no
            # separate handling
            for node in ast.walk(expr_ast):
                node.lineno = node.end_lineno = info.lineno
                node.col_offset = 0
                node.end_col_offset = end_col_offset
    return expr_ast
def _find_load_call(sources: SourceMap) -> Span | None:
    """Helper function to find location where pytket circuit was loaded.

    Tries to define a source code span by inspecting the call stack.

    Args:
        sources: Source map in which the caller's file is registered.

    Returns:
        A span covering the entire line of the calling frame, or `None` if the
        module of that frame could not be determined.
    """
    # NOTE(review): an earlier comment here said to "go back" because the first
    # frame outside of the compiler modules is 'pretty_errors_wrapped'. No such
    # extra step is performed; the frame from `get_calling_frame()` is used as
    # is. Confirm which is intended.
    if load_frame := get_calling_frame():
        info = inspect.getframeinfo(load_frame)
        filename = info.filename
        lineno = info.lineno
        sources.add_file(filename)
        # If we don't support python <= 3.10, this can be done better with
        # info.positions which gives you exact offsets.
        # For now over approximate and make the span cover the entire line.
        if load_module := inspect.getmodule(load_frame):
            source_lines, _ = inspect.getsourcelines(load_module)
            max_offset = len(source_lines[lineno - 1]) - 1
            start = Loc(filename, lineno, 0)
            end = Loc(filename, lineno, max_offset)
            return Span(start, end)
    return None
def custom_guppy_decorator(f: F) -> F:
    """Decorator to mark user-defined decorators that wrap builtin `guppy` decorators.

    Example:

    .. code-block:: python

        @custom_guppy_decorator
        def my_guppy(f):
            # Some custom logic here ...
            return guppy(f)

        @my_guppy
        def main() -> int: ...

    If the `custom_guppy_decorator` were missing, then the `@my_guppy` annotation would
    not produce a valid guppy definition.

    The marking is done in place: the code object of `f` is replaced by a copy
    whose `co_name` is "__custom_guppy_decorator__", which is the name that
    `get_calling_frame` skips. The same function object is returned.
    """
    f.__code__ = f.__code__.replace(co_name="__custom_guppy_decorator__")
    return f
def get_calling_frame() -> FrameType:
    """Finds the first frame that called this function outside the compiler modules.

    Walks up the stack starting at this function's own frame and returns the
    first frame whose module cannot be determined by `inspect.getmodule` or whose
    module `__file__` differs from this file's `__file__`.

    Raises:
        RuntimeError: If the stack is exhausted without finding such a frame.
    """
    frame = inspect.currentframe()
    while frame:
        # Skip frame if we're inside a user-defined decorator that wraps the `guppy`
        # decorator. Those are functions with a special `__code__.co_name` of
        # "__custom_guppy_decorator__".
        if frame.f_code.co_name == "__custom_guppy_decorator__":
            frame = frame.f_back
            continue
        module = inspect.getmodule(frame)
        if module is None:
            return frame
        if module.__file__ != __file__:
            return frame
        frame = frame.f_back
    raise RuntimeError("Couldn't obtain stack frame for definition")
# The `guppy` decorator object exported by this module. When `sphinx_running()`
# is true, a `_DummyGuppy` instance is used instead (cast to `_Guppy` for typing).
guppy = cast(_Guppy, _DummyGuppy()) if sphinx_running() else _Guppy()