# Source code for pytket.wasm.wasm

# Copyright 2019-2024 Cambridge Quantum Computing
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import hashlib
from functools import cached_property
from os.path import exists
from typing import Any

from qwasm import (  # type: ignore
    LANG_TYPE_EMPTY,
    LANG_TYPE_F32,
    LANG_TYPE_F64,
    LANG_TYPE_I32,
    LANG_TYPE_I64,
    SEC_EXPORT,
    SEC_FUNCTION,
    SEC_TYPE,
    decode_module,
)
from typing_extensions import deprecated


class WasmModuleHandler:
    """Construct and optionally check a wasm module for use in wasm Ops."""

    checked: bool
    _int_size: int
    _wasm_module: bytes
    _functions: dict[str, tuple[int, int]]
    _unsupported_functions: list[str]

    # Maps qwasm value-type codes to their textual wasm names;
    # ``LANG_TYPE_EMPTY`` maps to ``None``.
    type_lookup = {
        LANG_TYPE_I32: "i32",
        LANG_TYPE_I64: "i64",
        LANG_TYPE_F32: "f32",
        LANG_TYPE_F64: "f64",
        LANG_TYPE_EMPTY: None,
    }

    def __init__(
        self, wasm_module: bytes, check: bool = True, int_size: int = 32
    ) -> None:
        """
        Construct a wasm module handler

        :param wasm_module: A wasm module in binary format.
        :type wasm_module: bytes
        :param check: If ``True`` checks file for compatibility with wasm
          standards. If ``False`` checks are skipped.
        :type check: bool
        :param int_size: length of the integer that is used in the wasm file
        :type int_size: int
        :raises ValueError: if ``int_size`` is neither 32 nor 64. When
          ``check`` is ``True``, any error raised by :meth:`check` propagates.
        """
        self._int_size = int_size
        if int_size == 32:
            self._int_type = self.type_lookup[LANG_TYPE_I32]
        elif int_size == 64:
            self._int_type = self.type_lookup[LANG_TYPE_I64]
        else:
            raise ValueError(
                "given integer length not valid, only 32 and 64 are allowed"
            )

        # stores the names of the functions mapped
        #  to the number of parameters and the number of return values
        self._functions = {}

        # contains the list of functions that are not allowed
        # to use in pytket (because of types that are not integers
        # of the supplied int_size.)
        self._unsupported_functions = []

        self._wasm_module = wasm_module
        self.checked = False

        if check:
            self.check()

    def check(self) -> None:
        """Collect functions from the module that can be used with pytket.

        Populates the internal list of supported and unsupported functions
        and marks the module as checked so that subsequent checking is not
        required.

        The collected functions are stored only once the whole module has
        been validated. A failed check therefore leaves the handler's
        function lists untouched and can be retried without producing
        duplicate entries.

        :raises ValueError: if an exported function refers to a missing
          function or type entry, if a multi-value return type cannot be
          interpreted, or if the module lacks a supported function named
          ``init`` that has no parameters and no results.
        """
        if self.checked:
            return

        function_signatures: list[dict[str, list]] = []
        # exported function name -> function index taken from the export entry
        exported_functions: dict[str, int] = {}
        # type-section index of each entry of the function section; stays
        # empty when the module has no function section, so any export is
        # then reported as an invalid wasm file instead of an AttributeError
        function_types: list = []

        mod_iter = iter(decode_module(self._wasm_module))
        # the first item yielded by decode_module is skipped (module header)
        _, _ = next(mod_iter)

        for _, cur_sec_data in mod_iter:
            if cur_sec_data.id == SEC_TYPE:
                # read in list of function signatures
                function_signatures.extend(
                    self._parse_signature(entry)
                    for entry in cur_sec_data.payload.entries
                )
            elif cur_sec_data.id == SEC_EXPORT:
                # read in names of exported functions (export kind 0)
                for entry in cur_sec_data.payload.entries:
                    if entry.kind == 0:
                        f_name = entry.field_str.tobytes().decode()
                        exported_functions[f_name] = entry.index
            elif cur_sec_data.id == SEC_FUNCTION:
                # read in map of function signatures to functions
                function_types = cur_sec_data.payload.types

        # NOTE(review): per the wasm spec, the function index space lists
        # imported functions before module-defined ones, while
        # ``function_types`` only covers the function section. Export
        # indices are used directly here, which presumably assumes the
        # module imports no functions -- confirm.
        functions: dict[str, tuple[int, int]] = {}
        unsupported_functions: list[str] = []
        for f_name, func_idx in exported_functions.items():
            if func_idx >= len(function_types):
                raise ValueError("invalid wasm file")
            type_idx = function_types[func_idx]
            if type_idx >= len(function_signatures):
                raise ValueError("invalid wasm file")
            signature = function_signatures[type_idx]

            if self._is_supported_signature(signature):
                functions[f_name] = (
                    len(signature["parameter_types"]),
                    len(signature["return_types"]),
                )
            else:
                unsupported_functions.append(f_name)

        if "init" not in functions:
            raise ValueError("wasm file needs to contain a function named 'init'")

        if functions["init"][0] != 0:
            raise ValueError("init function should not have any parameter")

        if functions["init"][1] != 0:
            raise ValueError("init function should not have any results")

        # Commit the results only now that validation succeeded, then mark
        # the module as checked: function signatures are available and the
        # module does not need to be checked again.
        self._function_types = function_types
        self._functions = functions
        self._unsupported_functions = unsupported_functions
        self.checked = True

    def _parse_signature(self, entry: Any) -> dict[str, list]:
        """Translate one type-section entry into textual type names.

        :param entry: a function signature entry from the module's type
          section, as decoded by ``qwasm``.
        :return: a dict with keys ``"parameter_types"`` and
          ``"return_types"``, each a list of values from :attr:`type_lookup`.
        :raises ValueError: if a multi-value return type has an
          unexpected shape.
        """
        parameter_types = [self.type_lookup[pt] for pt in entry.param_types]

        if entry.return_count > 1:
            # multiple results are accepted either as a list with one type
            # code per result or as a single type code shared by all results
            if (
                isinstance(entry.return_type, list)
                and len(entry.return_type) == entry.return_count
            ):
                return_types = [self.type_lookup[rt] for rt in entry.return_type]
            elif isinstance(entry.return_type, int):
                return_types = [
                    self.type_lookup[entry.return_type]
                ] * entry.return_count
            else:
                raise ValueError(
                    "Only parameter and return values of "
                    f"i{self._int_size} types are"
                    f" allowed, found type: {entry.return_type}"
                )
        elif entry.return_count == 1:
            return_types = [self.type_lookup[entry.return_type]]
        else:
            return_types = []

        return {"parameter_types": parameter_types, "return_types": return_types}

    def _is_supported_signature(self, signature: dict[str, list]) -> bool:
        """Return whether pytket can use a function with this signature.

        A signature is supported when every parameter and result has the
        handler's integer type and there is at most one result.
        """
        return_types = signature["return_types"]
        if len(return_types) > 1:
            return False
        return all(
            t == self._int_type
            for t in signature["parameter_types"] + return_types
        )

    @property
    @deprecated("Use public property `checked` instead.")
    def _check_file(self) -> bool:
        return self.checked

    def __str__(self) -> str:
        """str representation of the wasm module"""
        return self.uid

    def __repr__(self) -> str:
        """str representation of the contents of the wasm file."""
        if not self.checked:
            return f"Unchecked wasm module file with the uid {self.uid}"

        lines = [f"Functions in wasm file with the uid {self.uid}:\n"]
        for name, (n_params, n_results) in self.functions.items():
            lines.append(
                f"function '{name}' with "
                f"{n_params} i{self._int_size} parameter(s)"
                f" and {n_results} i{self._int_size} return value(s)\n"
            )
        lines.extend(
            f"unsupported function with invalid "
            f"parameter or result type: '{name}' \n"
            for name in self.unsupported_functions
        )
        return "".join(lines)

    def bytecode(self) -> bytes:
        """The wasm content as bytecode"""
        return self._wasm_module

    @cached_property
    def bytecode_base64(self) -> bytes:
        """The wasm content as base64 encoded bytecode."""
        return base64.b64encode(self._wasm_module)

    @property
    @deprecated("Use public property `bytecode_base64` instead.")
    def _wasm_file_encoded(self) -> bytes:
        return self.bytecode_base64

    @cached_property
    def uid(self) -> str:
        """A unique identifier for the module calculated from its checksum.

        The identifier is the SHA-256 hex digest of the base64-encoded
        bytecode.
        """
        return hashlib.sha256(self.bytecode_base64).hexdigest()

    @property
    @deprecated("Use public property `uid` instead.")
    def _wasmfileuid(self) -> str:
        return self.uid

    def check_function(
        self, function_name: str, number_of_parameters: int, number_of_returns: int
    ) -> bool:
        """
        Checks a given function name and signature if it is included and the
        module has previously been checked.

        If the module has not been checked this will raise a ValueError.

        :param function_name: name of the function that is checked
        :type function_name: str
        :param number_of_parameters: number of integer parameters of the function
        :type number_of_parameters: int
        :param number_of_returns: number of integer return values of the function
        :type number_of_returns: int
        :return: true if the signature and the name of the function is correct"""
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )

        return (
            (function_name in self._functions)
            and (self._functions[function_name][0] == number_of_parameters)
            and (self._functions[function_name][1] == number_of_returns)
        )

    @property
    def functions(self) -> dict[str, tuple[int, int]]:
        """Retrieve the names of functions with the number of input and out arguments.

        If the module has not been checked this property will raise a
        ValueError.
        """
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )
        return self._functions

    @property
    def unsupported_functions(self) -> list[str]:
        """Retrieve the names of unsupported functions as a list of strings.

        If the module has not been checked this property will raise a
        ValueError.
        """
        if not self.checked:
            raise ValueError(
                "Cannot retrieve functions from an unchecked wasm module."
                " Please call .check() first."
            )
        return self._unsupported_functions


class WasmFileHandler(WasmModuleHandler):
    """Construct and optionally check a wasm module from a file for use in wasm Ops."""

    def __init__(
        self, filepath: str, check_file: bool = True, int_size: int = 32
    ) -> None:
        """
        Construct a wasm file handler using a filepath to read a wasm module into
        memory.

        :param filepath: Path to the wasm file
        :type filepath: str
        :param check_file: If ``True`` checks file for compatibility with wasm
          standards. If ``False`` checks are skipped.
        :type check_file: bool
        :param int_size: length of the integer that is used in the wasm file
        :type int_size: int
        :raises ValueError: if nothing exists at ``filepath``. Errors raised by
          the ``WasmModuleHandler`` constructor (for example an invalid
          ``int_size``) propagate unchanged.
        """
        if not exists(filepath):
            raise ValueError("wasm file not found at given path")
        # Read the whole module into memory; the raw bytes are also kept on
        # ``_wasm_file`` before being handed to the base class.
        with open(filepath, "rb") as file:
            self._wasm_file: bytes = file.read()
        super().__init__(self._wasm_file, check_file, int_size)