lambeq.training

class lambeq.training.BinaryCrossEntropyLoss(sparse: bool = False, use_jax: bool = False, epsilon: float = 1e-09)

Bases: CrossEntropyLoss

Binary cross-entropy loss function.

Parameters:
y_pred: np.ndarray or jnp.ndarray

Predicted probabilities from the model. When sparse is False, expected to be of shape [batch_size, 2], where each row is a probability distribution. When sparse is True, expected to be of shape [batch_size, ], where each element indicates P(1).

y_true: np.ndarray or jnp.ndarray

Ground truth labels. When sparse is False, expected to be of shape [batch_size, 2], where each row is a one-hot vector. When sparse is True, expected to be of shape [batch_size, ] where each element is an integer indicating class label.

__call__(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Call self as a function.

__init__(sparse: bool = False, use_jax: bool = False, epsilon: float = 1e-09) → None

Initialise a binary cross-entropy loss function.

Parameters:
sparse : bool, default: False

If True, each input element indicates P(1); otherwise, a probability distribution over classes is expected.

use_jax : bool, default: False

Whether to use the JAX variant of NumPy (jax.numpy).

epsilon : float, default: 1e-9

Smoothing constant used to prevent calculating log(0).

calculate_loss(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Calculate value of BCE loss function.
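For the sparse layout, the standard binary cross-entropy is L = -(1/B) Σ [y log(p + ε) + (1 - y) log(1 - p + ε)]. The NumPy sketch below computes that formula. Two details are assumptions and do not come from lambeq's source: the loss is averaged over the batch, and epsilon is added inside both logarithms. The function name `bce_sparse` is hypothetical.

```python
import numpy as np


def bce_sparse(y_pred, y_true, epsilon=1e-9):
    """Binary cross-entropy for sparse inputs.

    y_pred: shape [batch_size], each element is P(1).
    y_true: shape [batch_size], each element is 0 or 1.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    # epsilon keeps both logarithms finite when a prediction is exactly 0 or 1
    per_example = (y_true * np.log(y_pred + epsilon)
                   + (1 - y_true) * np.log(1 - y_pred + epsilon))
    # Assumption: the loss is averaged over the batch
    return float(-np.mean(per_example))


# Confident, correct predictions give a small loss
print(bce_sparse([0.9, 0.2], [1, 0]))
```

Without epsilon, a prediction of exactly 0 for a positive example would produce log(0) = -inf.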

class lambeq.training.Checkpoint

Bases: Mapping

Checkpoint class.

Attributes:
entries : dict

All data, stored as part of the checkpoint.

__init__() → None

Initialise a Checkpoint.

add_many(values: Mapping[str, Any]) → None

Add several values to the checkpoint.

Parameters:
values : Mapping from str to Any

The values to be added to the checkpoint.

classmethod from_file(path: str | PathLike[str]) → Checkpoint

Load the checkpoint contents from the file.

Parameters:
path : str or PathLike

Path to the checkpoint file.

Raises:
FileNotFoundError

If no file is found at the given path.

get(k[, d]) → D[k] if k in D, else d. d defaults to None.
items() → a set-like object providing a view on D's items
keys() → a set-like object providing a view on D's keys
to_file(path: str | PathLike[str]) → None

Save entries to a file and delete the in-memory copy.

Parameters:
path : str or PathLike

Path to the checkpoint file.

values() → an object providing a view on D's values
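Checkpoint is a read-only Mapping backed by an entries dict, so a file round trip needs only the standard library. The class below is a hypothetical stand-in (`CheckpointSketch`) that mirrors the documented methods. Using pickle for serialisation is an assumption about the file format and does not describe lambeq's implementation.

```python
import os
import pickle
import tempfile
from collections.abc import Mapping
from typing import Any


class CheckpointSketch(Mapping):
    """Hypothetical stand-in mirroring the documented Checkpoint API."""

    def __init__(self) -> None:
        self.entries: dict[str, Any] = {}

    # Mapping derives get(), items(), keys() and values() from these three
    def __getitem__(self, key: str) -> Any:
        return self.entries[key]

    def __iter__(self):
        return iter(self.entries)

    def __len__(self) -> int:
        return len(self.entries)

    def add_many(self, values: Mapping[str, Any]) -> None:
        self.entries.update(values)

    def to_file(self, path: 'str | os.PathLike[str]') -> None:
        with open(path, 'wb') as f:
            pickle.dump(self.entries, f)
        # As documented, the in-memory copy is dropped after saving
        self.entries = {}

    @classmethod
    def from_file(cls, path: 'str | os.PathLike[str]') -> 'CheckpointSketch':
        # open() raises FileNotFoundError for a missing path
        with open(path, 'rb') as f:
            checkpoint = cls()
            checkpoint.entries = pickle.load(f)
        return checkpoint


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'example.lt')
    ckpt = CheckpointSketch()
    ckpt.add_many({'epoch': 3, 'model_weights': [0.1, 0.2]})
    ckpt.to_file(path)
    restored = CheckpointSketch.from_file(path)

print(restored['epoch'], len(ckpt))  # 3 0
```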
class lambeq.training.CrossEntropyLoss(use_jax: bool = False, epsilon: float = 1e-09)

Bases: LossFunction

Multiclass cross-entropy loss function.

Parameters:
y_pred: np.ndarray or jnp.ndarray

Predicted probabilities from the model. Expected to be of shape [batch_size, n_classes], where each row is a probability distribution.

y_true: np.ndarray or jnp.ndarray

Ground truth labels. Expected to be of shape [batch_size, n_classes], where each row is a one-hot vector.

__call__(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Call self as a function.

__init__(use_jax: bool = False, epsilon: float = 1e-09) → None

Initialise a multiclass cross-entropy loss function.

Parameters:
use_jax : bool, default: False

Whether to use the JAX variant of NumPy (jax.numpy).

epsilon : float, default: 1e-9

Smoothing constant used to prevent calculating log(0).

calculate_loss(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Calculate value of CE loss function.
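For one-hot targets, multiclass cross-entropy keeps only the log-probability of the true class in each row. The NumPy sketch below implements the standard formula. Averaging over the batch and the placement of epsilon are assumptions, and `cross_entropy` is a hypothetical name.

```python
import numpy as np


def cross_entropy(y_pred, y_true, epsilon=1e-9):
    """Multiclass cross-entropy.

    y_pred: shape [batch_size, n_classes], rows are probability distributions.
    y_true: shape [batch_size, n_classes], rows are one-hot vectors.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    # Multiplying by one-hot rows keeps only the true class's log-probability
    per_example = -np.sum(y_true * np.log(y_pred + epsilon), axis=1)
    # Assumption: the loss is averaged over the batch
    return float(np.mean(per_example))


y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8]])
y_true = np.array([[1, 0, 0],
                   [0, 0, 1]])
print(cross_entropy(y_pred, y_true))
```

Suppose n_classes = 2 and each row is [1 - p, p]. The sum then reduces to y log(p) + (1 - y) log(1 - p). This agrees with BinaryCrossEntropyLoss listing CrossEntropyLoss as its base.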

class lambeq.training.Dataset(data: list[Any], targets: list[Any], batch_size: int = 0, shuffle: bool = True)

Bases: object

Dataset class for the training of a lambeq model.

Data is returned in the format of lambeq’s numerical backend, which by default is set to NumPy. For example, to access the dataset as PyTorch tensors:

>>> from lambeq.backend import numerical_backend
>>> dataset = Dataset(['data1'], [[0, 1, 2, 3]])
>>> with numerical_backend.backend('pytorch'):
...     print(dataset[0])  # becomes pytorch tensor
('data1', tensor([0, 1, 2, 3]))
>>> print(dataset[0])  # numpy array again
('data1', array([0, 1, 2, 3]))
__init__(data: list[Any], targets: list[Any], batch_size: int = 0, shuffle: bool = True) → None

Initialise a Dataset for lambeq training.

Parameters:
data : list

Data used for training.

targets : list

List of labels.

batch_size : int, default: 0

Batch size for batch generation. The default of 0 uses the full dataset as a single batch.

shuffle : bool, default: True

Enable data shuffling during training.

Raises:
ValueError

When ‘data’ and ‘targets’ do not match in size.

static shuffle_data(data: list[Any], targets: list[Any]) → tuple[list[Any], list[Any]]

Shuffle a given dataset.

Parameters:
data : list

List of data points.

targets : list

List of labels.

Returns:
Tuple of list and list

The shuffled dataset.
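The batching and shuffling behaviour above can be sketched in plain Python. The sketch rests on two assumptions. First, batch_size=0 yields a single batch holding the full dataset, as the parameter description states. Second, shuffling applies one permutation to both lists, so every data point stays paired with its label. The function names `shuffle_data` and `iterate_batches` are hypothetical and do not reproduce lambeq's code.

```python
import random
from typing import Any, Iterator


def shuffle_data(data: list[Any], targets: list[Any],
                 rng: random.Random) -> tuple[list[Any], list[Any]]:
    # Shuffle the indices once and apply that order to both lists
    order = list(range(len(data)))
    rng.shuffle(order)
    return [data[i] for i in order], [targets[i] for i in order]


def iterate_batches(data: list[Any], targets: list[Any], batch_size: int = 0,
                    shuffle: bool = True, seed: int = 0
                    ) -> Iterator[tuple[list[Any], list[Any]]]:
    if len(data) != len(targets):
        raise ValueError("'data' and 'targets' do not match in size.")
    # batch_size=0 means one batch containing the whole dataset
    size = batch_size if batch_size > 0 else max(len(data), 1)
    if shuffle:
        data, targets = shuffle_data(data, targets, random.Random(seed))
    for start in range(0, len(data), size):
        yield data[start:start + size], targets[start:start + size]


batches = list(iterate_batches(['a', 'b', 'c', 'd', 'e'], [0, 1, 0, 1, 0],
                               batch_size=2))
print([len(x) for x, _ in batches])  # [2, 2, 1]
```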

class lambeq.training.LossFunction(use_jax: bool = False)

Bases: ABC

Loss function base class.

Attributes:
backend : ModuleType

The module to use for array numerical functions: either numpy or jax.numpy.

__call__(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Call self as a function.

__init__(use_jax: bool = False) → None

Initialise a loss function.

Parameters:
use_jax : bool, default: False

Whether to use the JAX variant of NumPy as backend.

abstract calculate_loss(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Calculate value of loss function.
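The base class defines __call__, and calculate_loss is the abstract hook, so a subclass presumably only has to implement calculate_loss. The sketch below mimics that structure with a local stand-in (`LossFunctionSketch`, hypothetical, not lambeq's class) so it runs without lambeq. The mean-absolute-error loss is an illustrative example that lambeq does not ship. Against the real library one would subclass `lambeq.training.LossFunction` and use `self.backend` in the same way.

```python
from abc import ABC, abstractmethod
from types import ModuleType

import numpy as np


class LossFunctionSketch(ABC):
    """Hypothetical stand-in for the documented LossFunction interface."""

    def __init__(self, use_jax: bool = False) -> None:
        self.backend: ModuleType = np
        if use_jax:
            import jax.numpy as jnp  # imported lazily so NumPy-only setups work
            self.backend = jnp

    def __call__(self, y_pred, y_true) -> float:
        # The public entry point delegates to the subclass hook
        return self.calculate_loss(y_pred, y_true)

    @abstractmethod
    def calculate_loss(self, y_pred, y_true) -> float:
        """Calculate value of loss function."""


class MAELoss(LossFunctionSketch):
    """Illustrative mean absolute error loss (not part of lambeq)."""

    def calculate_loss(self, y_pred, y_true) -> float:
        xp = self.backend  # numpy or jax.numpy, chosen in __init__
        return float(xp.mean(xp.abs(xp.asarray(y_pred) - xp.asarray(y_true))))


loss_fn = MAELoss()
print(loss_fn(np.array([1.0, 2.0]), np.array([0.0, 4.0])))  # 1.5
```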

class lambeq.training.MSELoss(use_jax: bool = False)

Bases: LossFunction

Mean squared error loss function.

Parameters:
y_pred: np.ndarray or jnp.ndarray

Predicted values from model. Shape must match y_true.

y_true: np.ndarray or jnp.ndarray

Ground truth values.

__call__(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Call self as a function.

__init__(use_jax: bool = False) → None

Initialise a loss function.

Parameters:
use_jax : bool, default: False

Whether to use the JAX variant of NumPy as backend.

calculate_loss(y_pred: np.ndarray | jnp.ndarray, y_true: np.ndarray | jnp.ndarray) → float

Calculate value of MSE loss function.
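Mean squared error averages the element-wise squared differences, which is why the shapes of y_pred and y_true must match. The NumPy sketch below shows the computation. The explicit shape check and the name `mse` are illustrative assumptions and do not reproduce lambeq's code.

```python
import numpy as np


def mse(y_pred, y_true):
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    # Reject mismatched shapes instead of letting NumPy broadcast silently
    if y_pred.shape != y_true.shape:
        raise ValueError(f'Shape mismatch: {y_pred.shape} vs {y_true.shape}')
    return float(np.mean((y_pred - y_true) ** 2))


print(mse([[0.5, 1.0], [2.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]))  # 0.3125
```

Without the check, predictions of shape [batch_size] against targets of shape [batch_size, 1] would broadcast to a [batch_size, batch_size] array. The result would look plausible but be wrong.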

class lambeq.training.Model

Bases: ABC

Model base class.

Attributes:
symbols : list of symbols

A sorted list of all Symbols occurring in the data.

weights : Collection

A data structure containing the numeric values of the model’s parameters.

__call__(*args: Any, **kwds: Any) → Any

Call self as a function.

__init__() → None

Initialise an instance of Model base class.

abstract forward(x: list[Any]) → Any

The forward pass of the model.

classmethod from_checkpoint(checkpoint_path: StrPathT, **kwargs: Any) → Model

Load the weights and symbols from a training checkpoint.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) → Model

Build model from a list of Diagrams.

Parameters:
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

abstract get_diagram_output(diagrams: list[Diagram]) → Any

Return the diagram prediction.

Parameters:
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

abstract initialise_weights() → None

Initialise the weights of the model.

load(checkpoint_path: StrPathT) → None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: StrPathT) → None

Create a lambeq Checkpoint and save it to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

class lambeq.training.NelderMeadOptimizer(*, model: QuantumModel, loss_fn: Callable[[Any, Any], float], hyperparams: dict[str, float] | None = None, bounds: ArrayLike | None = None)

Bases: Optimizer

An optimizer based on the Nelder-Mead algorithm.

This implementation is based heavily on SciPy’s optimize.minimize.
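For readers new to the method, the NumPy sketch below shows the core simplex loop: reflect, expand, contract and shrink. It uses the standard coefficients (1, 2, 0.5, 0.5), an axis-aligned initial simplex, a fixed iteration budget, and no bounds or adaptive parameters. It is a textbook illustration and not lambeq's or SciPy's implementation; `nelder_mead` is a hypothetical name.

```python
import numpy as np


def nelder_mead(f, x0, step=0.1, max_iter=1000, fatol=1e-10):
    """Minimise f from x0 with the basic Nelder-Mead simplex method."""
    x0 = np.asarray(x0, dtype=float)
    n = x0.size
    # Initial simplex: x0 plus one vertex displaced along each axis
    simplex = np.vstack([x0] + [x0 + step * np.eye(n)[i] for i in range(n)])
    values = np.array([f(v) for v in simplex])
    for _ in range(max_iter):
        # Sort the vertices from best (lowest f) to worst
        order = np.argsort(values)
        simplex, values = simplex[order], values[order]
        if values[-1] - values[0] < fatol:
            break
        worst = simplex[-1].copy()
        centroid = simplex[:-1].mean(axis=0)  # centroid of all but the worst
        reflected = centroid + (centroid - worst)
        f_r = f(reflected)
        if f_r < values[0]:
            # The reflection beats the best vertex: try moving further
            expanded = centroid + 2.0 * (centroid - worst)
            f_e = f(expanded)
            simplex[-1], values[-1] = (expanded, f_e) if f_e < f_r else (reflected, f_r)
        elif f_r < values[-2]:
            # Better than the second-worst vertex: accept the reflection
            simplex[-1], values[-1] = reflected, f_r
        else:
            # Contract from whichever of worst/reflected is better
            anchor = reflected if f_r < values[-1] else worst
            contracted = centroid + 0.5 * (anchor - centroid)
            f_c = f(contracted)
            if f_c < min(f_r, values[-1]):
                simplex[-1], values[-1] = contracted, f_c
            else:
                # Shrink every vertex towards the best one
                simplex[1:] = simplex[0] + 0.5 * (simplex[1:] - simplex[0])
                values[1:] = [f(v) for v in simplex[1:]]
    best = int(np.argmin(values))
    return simplex[best], float(values[best])


x_best, f_best = nelder_mead(lambda w: (w[0] - 1) ** 2 + (w[1] + 2) ** 2, [0.0, 0.0])
print(x_best, f_best)
```

The method needs only loss values, never gradients. This is why it suits loss functions that can be evaluated but not easily differentiated.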

__init__(*, model: QuantumModel, loss_fn: Callable[[Any, Any], float], hyperparams: dict[str, float] | None = None, bounds: ArrayLike | None = None) → None

Initialise the Nelder-Mead optimizer.

The hyperparameters may contain the following key-value pairs:

  • adaptive: bool, default: False

    Adjust the algorithm’s parameters based on the dimensionality of the problem. This is particularly helpful when minimizing functions in high-dimensional spaces.

  • maxfev: int, default: 1000

    Maximum number of function evaluations allowed.

  • initial_simplex: ArrayLike (N+1, N), default: None

    If provided, replaces the initial model weights. Each row should contain the coordinates of the i-th of the N+1 vertices of the simplex, where N is the dimension.

  • xatol: float, default: 1e-4

    The acceptable level of absolute error in the optimal model weights (optimal solution) between iterations that indicates convergence.

  • fatol: float, default: 1e-4

    The acceptable level of absolute error in the loss value between iterations that indicates convergence.
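When adaptive is True, SciPy replaces the fixed reflection, expansion, contraction and shrink coefficients with dimension-dependent values from the Gao and Han (2012) paper in the References. This optimizer is described as based heavily on SciPy, so the sketch below assumes it follows SciPy here. For N parameters the coefficients are ρ = 1, χ = 1 + 2/N, ψ = 3/4 - 1/(2N) and σ = 1 - 1/N. The function name is hypothetical.

```python
def nelder_mead_coefficients(n_params: int, adaptive: bool = False):
    """Return (rho, chi, psi, sigma): reflection, expansion,
    contraction and shrink coefficients, following SciPy's convention."""
    if adaptive:
        n = float(n_params)
        return 1.0, 1.0 + 2.0 / n, 0.75 - 1.0 / (2.0 * n), 1.0 - 1.0 / n
    return 1.0, 2.0, 0.5, 0.5


print(nelder_mead_coefficients(2, adaptive=True))  # (1.0, 2.0, 0.5, 0.5)
print(nelder_mead_coefficients(100, adaptive=True))
```

At N = 2 the adaptive values equal the standard ones. As N grows, expansion weakens towards 1 and contraction and shrinking become gentler (towards 0.75 and 1), which is what helps in high-dimensional spaces.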

Parameters:
model : QuantumModel

A lambeq quantum model.

hyperparams : dict of str to float

A dictionary containing the model’s hyperparameters.

loss_fn : Callable[[ArrayLike, ArrayLike], float]

A loss function of the form loss(prediction, labels).

bounds : ArrayLike, optional

The range of each of the model parameters.

Raises:
ValueError
  • If the hyperparameters are not set correctly, or if the length of bounds does not match the number of the model parameters.

  • If the lower bounds are greater than the upper bounds.

  • If the initial simplex is not a 2D array.

  • If the initial simplex does not have N+1 rows, where N is the number of model parameters.

Warning
  • If the initial model weights are not within the bounds.
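The checks listed under Raises and Warning can be sketched as plain NumPy validation. The sketch assumes bounds is given as one (lower, upper) pair per parameter, that is, an (N, 2) array. It also checks the simplex's column count in addition to the documented row count. `check_nelder_mead_inputs` is a hypothetical name and does not reproduce lambeq's code.

```python
import warnings

import numpy as np


def check_nelder_mead_inputs(weights, initial_simplex=None, bounds=None):
    """Hypothetical re-implementation of the documented input checks."""
    weights = np.asarray(weights, dtype=float)
    n = weights.size
    if bounds is not None:
        # Assumption: one (lower, upper) row per model parameter
        bounds = np.asarray(bounds, dtype=float)
        if bounds.shape != (n, 2):
            raise ValueError('Length of bounds must match the number of parameters.')
        lower, upper = bounds[:, 0], bounds[:, 1]
        if np.any(lower > upper):
            raise ValueError('Lower bounds must not exceed upper bounds.')
        if np.any(weights < lower) or np.any(weights > upper):
            warnings.warn('Initial model weights are not within the bounds.')
    if initial_simplex is not None:
        simplex = np.asarray(initial_simplex, dtype=float)
        if simplex.ndim != 2:
            raise ValueError('Initial simplex must be a 2D array.')
        # N+1 rows (documented) of N coordinates each (added check)
        if simplex.shape != (n + 1, n):
            raise ValueError(f'Initial simplex must have shape ({n + 1}, {n}).')


check_nelder_mead_inputs([0.5, 2.0], initial_simplex=np.zeros((3, 2)),
                         bounds=[[0, 1], [0, 3]])
```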

References

Gao, Fuchang & Han, Lixing. (2012). Implementing the Nelder-Mead Simplex Algorithm with Adaptive Parameters. Computational Optimization and Applications, 51. 259-277. 10.1007/s10589-010-9329-3.

backward(batch: tuple[Iterable[Any], ndarray]) → float

Calculate the gradients of the loss function.

The gradients are calculated with respect to the model parameters.

Parameters:
batch : tuple of Iterable and numpy.ndarray

Current batch. Contains an Iterable of diagrams at index 0 and the targets at index 1.

Returns:
float

The calculated loss.

bounds: ndarray | None
load_state_dict(state_dict: Mapping[str, Any]) → None

Load state of the optimizer from the state dictionary.

Parameters:
state_dict : dict

A dictionary containing a snapshot of the optimizer state.

model: QuantumModel
objective(x: Iterable[Any], y: ArrayLike, w: ArrayLike) → float

The objective function to be minimized.

Parameters:
x : ArrayLike

The input data.

y : ArrayLike

The labels.

w : ArrayLike

The model parameters.

Returns:
result : float

The result of the objective function.

Raises:
ValueError

If the objective function does not return a scalar value.
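From the optimizer's point of view the objective depends only on the weights: the batch (x, y) is held fixed while w varies. The sketch below builds such a function as a closure and applies the scalar check from the Raises section. The toy linear model is a hypothetical stand-in for a QuantumModel, and `make_objective` is a hypothetical name.

```python
import numpy as np


def make_objective(model_fn, loss_fn, x, y):
    """Return f(w) = loss(model_fn(x, w), y) for a fixed batch (x, y)."""
    def objective(w):
        result = np.asarray(loss_fn(model_fn(x, np.asarray(w, dtype=float)), y))
        # Nelder-Mead compares plain numbers, so reject non-scalar results
        if result.size != 1:
            raise ValueError('The objective function must return a scalar value.')
        return float(result.item())
    return objective


# Hypothetical toy model: predictions are inputs @ w
x = np.array([[1.0, 0.0], [0.0, 1.0]])
y = np.array([0.5, -1.0])
f = make_objective(lambda inputs, w: inputs @ w,
                   lambda p, t: np.mean((p - t) ** 2), x, y)
print(f([0.5, -1.0]))  # 0.0
```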

project(x: ndarray) → ndarray
state_dict() → dict[str, Any]

Return optimizer states as dictionary.

Returns:
dict

A dictionary containing the current state of the optimizer.

step() → None

Perform optimisation step.

update_hyper_params() → None

Update the hyperparameters of the Nelder-Mead algorithm.

zero_grad() → None

Reset the gradients to zero.

class lambeq.training.NumpyModel(use_jit: bool = False)

Bases: QuantumModel

A lambeq model for an exact classical simulation of a quantum pipeline.

__call__(*args: Any, **kwargs: Any) → Any

Call self as a function.

__init__(use_jit: bool = False) → None

Initialise a NumpyModel.

Parameters:
use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation.

forward(x: list[Diagram]) → Any

Perform the default forward pass of a lambeq model.

If the data points have a different form (e.g. a list of tuples) or require additional computational steps, override this method.

Parameters:
x : list of Diagram

The Circuits to be evaluated.

Returns:
numpy.ndarray

Array containing model’s prediction.
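The override hint applies when each data point contains more than one diagram, such as a pair of sentences. The sketch below shows the pattern using a hypothetical stand-in (`ModelSketch`) whose get_diagram_output returns fixed vectors for placeholder strings. In real code one would subclass NumpyModel and keep lambeq's get_diagram_output. Combining the two outputs with a renormalised element-wise product is an arbitrary illustrative choice.

```python
import numpy as np


class ModelSketch:
    """Hypothetical stand-in: default forward evaluates a flat list of diagrams."""

    def get_diagram_output(self, diagrams):
        # Placeholder for exact evaluation: one 2-dim vector per 'diagram'
        return np.array([[0.8, 0.2] if d == 'pos' else [0.3, 0.7] for d in diagrams])

    def forward(self, x):
        return self.get_diagram_output(x)


class PairModelSketch(ModelSketch):
    def forward(self, x):
        # x is a list of (left, right) tuples; evaluate each side in one batch
        lefts = self.get_diagram_output([left for left, _ in x])
        rights = self.get_diagram_output([right for _, right in x])
        combined = lefts * rights
        # Renormalise so every row is again a probability distribution
        return combined / combined.sum(axis=1, keepdims=True)


model = PairModelSketch()
out = model.forward([('pos', 'pos'), ('pos', 'neg')])
print(out)
```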

classmethod from_checkpoint(checkpoint_path: StrPathT, **kwargs: Any) → Model

Load the weights and symbols from a training checkpoint.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) → Model

Build model from a list of Diagrams.

Parameters:
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

get_diagram_output(diagrams: list[Diagram]) → jnp.ndarray | numpy.ndarray

Return the exact prediction for each diagram.

Parameters:
diagrams : list of Diagram

The Circuits to be evaluated.

Returns:
np.ndarray

Resulting array.

Raises:
ValueError

If model.weights or model.symbols are not initialised.

initialise_weights() → None

Initialise the weights of the model.

Raises:
ValueError

If model.symbols are not initialised.

load(checkpoint_path: StrPathT) → None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: StrPathT) → None

Create a lambeq Checkpoint and save it to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

weights: np.ndarray
class lambeq.training.Optimizer(*, model: Model, loss_fn: Callable[[Any, Any], float], hyperparams: dict[Any, Any] | None = None, bounds: ArrayLike | None = None)

Bases: ABC

Optimizer base class.

__init__(*, model: Model, loss_fn: Callable[[Any, Any], float], hyperparams: dict[Any, Any] | None = None, bounds: ArrayLike | None = None) → None

Initialise the optimizer base class.

Parameters:
model : Model

A lambeq model.

loss_fn : Callable

A loss function of the form loss(prediction, labels).

hyperparams : dict of str to float, optional

A dictionary containing the model’s hyperparameters.

bounds : ArrayLike, optional

The range of each of the model’s parameters.

abstract backward(batch: tuple[Iterable[Any], ndarray]) → float

Calculate the gradients of the loss function.

The gradient is calculated with respect to the model parameters.

Parameters:
batch : tuple of list and numpy.ndarray

Current batch.

Returns:
float

The calculated loss.

abstract load_state_dict(state: Mapping[str, Any]) → None

Load state of the optimizer from the state dictionary.

abstract state_dict() → dict[str, Any]

Return optimizer states as dictionary.

abstract step() → None

Perform optimisation step.

zero_grad() → None

Reset the gradients to zero.
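In the usual loop order the abstract methods are called as zero_grad, then backward on a batch (which returns the loss), then step. The sketch below is a hypothetical optimizer that follows that contract on a toy model with a NumPy weight vector. It estimates gradients by central finite differences. The toy model, the finite-difference scheme and the hyperparameter names 'lr' and 'h' are illustrative assumptions; none of them is one of lambeq's optimizers.

```python
import numpy as np


class ToyModel:
    """Hypothetical model whose predictions are x @ weights."""

    def __init__(self, n_weights: int) -> None:
        self.weights = np.zeros(n_weights)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x @ self.weights


class FiniteDifferenceOptimizer:
    """Hypothetical optimizer following the documented Optimizer contract."""

    def __init__(self, *, model, loss_fn, hyperparams=None) -> None:
        self.model = model
        self.loss_fn = loss_fn
        # Illustrative defaults: learning rate and finite-difference step
        self.hyperparams = {'lr': 0.1, 'h': 1e-5, **(hyperparams or {})}
        self.gradient = np.zeros_like(model.weights)

    def zero_grad(self) -> None:
        self.gradient = np.zeros_like(self.model.weights)

    def backward(self, batch) -> float:
        x, y = batch
        w, h = self.model.weights, self.hyperparams['h']
        for i in range(w.size):
            original = w[i]
            w[i] = original + h
            loss_plus = self.loss_fn(self.model(x), y)
            w[i] = original - h
            loss_minus = self.loss_fn(self.model(x), y)
            w[i] = original  # restore before perturbing the next weight
            # Central-difference estimate of d(loss)/d(w_i), accumulated
            self.gradient[i] += (loss_plus - loss_minus) / (2 * h)
        return float(self.loss_fn(self.model(x), y))

    def step(self) -> None:
        self.model.weights -= self.hyperparams['lr'] * self.gradient

    def state_dict(self) -> dict:
        return {'hyperparams': dict(self.hyperparams)}

    def load_state_dict(self, state) -> None:
        self.hyperparams = dict(state['hyperparams'])


def mean_squared_error(pred, target):
    return float(np.mean((pred - target) ** 2))


x = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = x @ np.array([2.0, -1.0])
model = ToyModel(2)
opt = FiniteDifferenceOptimizer(model=model, loss_fn=mean_squared_error)
for epoch in range(200):
    opt.zero_grad()              # clear gradients from the previous batch
    loss = opt.backward((x, y))  # estimate gradients, return current loss
    opt.step()                   # update the model weights
print(model.weights, loss)
```

The gradient accumulates in backward and is cleared only by zero_grad. This is why zero_grad must be called once per batch.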

class lambeq.training.PennyLaneModel(probabilities: bool = True, normalize: bool = True, diff_method: str = 'best', backend_config: dict[str, Any] | None = None)

Bases: Model, Module

A lambeq model for the quantum and hybrid quantum/classical pipeline using PennyLane circuits. It uses PyTorch as a backend for all tensor operations.

T_destination = ~T_destination
__call__(*args: Any, **kwds: Any) → Any

Call self as a function.

__init__(probabilities: bool = True, normalize: bool = True, diff_method: str = 'best', backend_config: dict[str, Any] | None = None) → None

Initialise a PennyLaneModel instance with an empty circuit_map dictionary.

Parameters:
probabilities : bool, default: True

Whether to use probabilities or states for the output.

backend_config : dict, optional

Configuration for hardware or simulator to be used. Defaults to using the default.qubit PennyLane simulator analytically, with normalized probability outputs. Keys that can be used include ‘backend’, ‘device’, ‘probabilities’, ‘normalize’, ‘shots’, and ‘noise_model’.

add_module(name: str, module: Module | None) → None

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

Args:
name (str): name of the child module. The child module can be accessed from this module using the given name.

module (Module): child module to be added to the module.

apply(fn: Callable[[Module], None]) → T

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also torch.nn.init).

Args:

fn (Module -> None): function to be applied to each submodule

Returns:

Module: self

Example:

>>> import torch
>>> from torch import nn
>>> @torch.no_grad()
... def init_weights(m):
...     print(m)
...     if isinstance(m, nn.Linear):
...         m.weight.fill_(1.0)
...         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> _ = net.apply(init_weights)  # apply returns net; assign to avoid a second repr
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
bfloat16() → T

Casts all floating point parameters and buffers to bfloat16 datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

buffers(recurse: bool = True) → Iterator[Tensor]

Return an iterator over module buffers.

Args:
recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields:

torch.Tensor: module buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
...     print(type(buf), buf.size())
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
call_super_init: bool = False
children() → Iterator[Module]

Return an iterator over immediate children modules.

Yields:

Module: a child module

compile(*args, **kwargs)

Compile this Module’s forward using torch.compile().

This Module’s __call__ method is compiled and all arguments are passed as-is to torch.compile().

See torch.compile() for details on the arguments for this function.

cpu() → T

Move all model parameters and buffers to the CPU.

Note

This method modifies the module in-place.

Returns:

Module: self

cuda(device: int | device | None = None) → T

Move all model parameters and buffers to the GPU.

This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the GPU while being optimized.

Note

This method modifies the module in-place.

Args:
device (int, optional): if specified, all parameters will be copied to that device.

Returns:

Module: self

double() → T

Casts all floating point parameters and buffers to double datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

dump_patches: bool = False
eval() → T

Set the module in evaluation mode.

This has an effect only on certain modules, such as Dropout and BatchNorm. See the documentation of the particular modules for details of their behavior in training and evaluation mode.

This is equivalent to self.train(False).

See the PyTorch notes on locally disabling gradient computation for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

Module: self

extra_repr() → str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

float() → T

Casts all floating point parameters and buffers to float datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

forward(x: list[Diagram]) → Tensor

Perform the default forward pass by running circuits.

If the data points have a different form (e.g. a list of tuples) or require additional computational steps, override this method.

Parameters:
x : list of Diagram

The Circuits to be evaluated.

Returns:
torch.Tensor

Tensor containing model’s prediction.

classmethod from_checkpoint(checkpoint_path: StrPathT, **kwargs: Any) → Model

Load the weights and symbols from a training checkpoint.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], probabilities: bool = True, normalize: bool = True, diff_method: str = 'best', backend_config: dict[str, Any] | None = None, **kwargs: Any) → PennyLaneModel

Build model from a list of Circuits.

Parameters:
diagrams : list of Diagram

The circuit diagrams to be evaluated.

backend_config : dict, optional

Configuration for hardware or simulator to be used. Defaults to using the default.qubit PennyLane simulator analytically, with normalized probability outputs. Keys that can be used include ‘backend’, ‘device’, ‘probabilities’, ‘normalize’, ‘shots’, and ‘noise_model’.

get_buffer(target: str) → Tensor

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Args:
target: The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

torch.Tensor: The buffer referenced by target

Raises:
AttributeError: If the target string references an invalid path or resolves to something that is not a buffer

get_diagram_output(diagrams: list[Diagram]) → Tensor

Evaluate outputs of circuits using PennyLane.

Parameters:
diagrams : list of Diagram

The Diagrams to be evaluated.

Returns:
torch.Tensor

Resulting tensor.

Raises:
ValueError

If model.weights or model.symbols are not initialised.

get_extra_state() → Any

Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

object: Any extra state to store in the module’s state_dict

get_parameter(target: str) → Parameter

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Args:
target: The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

torch.nn.Parameter: The Parameter referenced by target

Raises:
AttributeError: If the target string references an invalid path or resolves to something that is not an nn.Parameter

get_submodule(target: str) → Module

Return the submodule given by target if it exists, otherwise throw an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Args:
target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)

Returns:

torch.nn.Module: The submodule referenced by target

Raises:
AttributeError: If the target string references an invalid path or resolves to something that is not an nn.Module
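The runtime bound follows from how a dotted target is resolved: each segment is one attribute lookup, so the cost grows with the depth of the path rather than with the total number of modules. The pure-Python sketch below shows that segment-by-segment resolution on plain objects. `Node` and `resolve` are hypothetical names, and PyTorch's own method additionally checks that every step is an nn.Module.

```python
class Node:
    """Hypothetical stand-in for a module that holds named children."""

    def __init__(self, name, **children):
        self.name = name
        for child_name, child in children.items():
            setattr(self, child_name, child)


def resolve(root, target):
    """Follow a dotted path one attribute at a time (cost bounded by depth)."""
    node = root
    for atom in target.split('.'):
        if not hasattr(node, atom):
            raise AttributeError(f'{node.name} has no attribute `{atom}`')
        node = getattr(node, atom)
        if not isinstance(node, Node):
            raise AttributeError(f'`{atom}` is not a Node')
    return node


# The module tree from the example above
a = Node('A', net_b=Node('net_b', net_c=Node('net_c', conv=Node('conv')),
                         linear=Node('linear')))
print(resolve(a, 'net_b.net_c.conv').name)  # conv
```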

half() → T

Casts all floating point parameters and buffers to half datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

initialise_weights() → None

Initialise the weights of the model.

Raises:
ValueError

If model.symbols are not initialised.

ipu(device: int | device | None = None) → T

Move all model parameters and buffers to the IPU.

This also makes the associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the IPU while being optimized.

Note

This method modifies the module in-place.

Arguments:
device (int, optional): if specified, all parameters will be copied to that device.

Returns:

Module: self

load(checkpoint_path: StrPathT) → None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

Args:
state_dict (dict): a dict containing parameters and persistent buffers.

strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

assign (bool, optional): When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False

Returns:
NamedTuple with missing_keys and unexpected_keys fields:
  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.

  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Note:

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
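The missing_keys and unexpected_keys fields are set differences between the keys the module expects and the keys present in state_dict. The pure-Python sketch below shows that bookkeeping. The names are hypothetical, and PyTorch additionally handles nested-module prefixes and tensor shape checks.

```python
from typing import NamedTuple


class IncompatibleKeysSketch(NamedTuple):
    missing_keys: list[str]
    unexpected_keys: list[str]


def compare_state_dict_keys(expected, provided, strict=True):
    """Hypothetical sketch of the key bookkeeping in load_state_dict."""
    expected, provided = list(expected), list(provided)
    expected_set, provided_set = set(expected), set(provided)
    missing = [k for k in expected if k not in provided_set]
    unexpected = [k for k in provided if k not in expected_set]
    if strict and (missing or unexpected):
        # PyTorch likewise raises RuntimeError when strict=True and keys differ
        raise RuntimeError(
            f'Error(s) in loading state_dict: missing {missing}, '
            f'unexpected {unexpected}')
    return IncompatibleKeysSketch(missing, unexpected)


result = compare_state_dict_keys(['0.weight', '0.bias'], ['0.weight', 'extra'],
                                 strict=False)
print(result)
# IncompatibleKeysSketch(missing_keys=['0.bias'], unexpected_keys=['extra'])
```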

modules() → Iterator[Module]

Return an iterator over all modules in the network.

Yields:

Module: a module in the network

Note:

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)
0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) → Iterator[Tuple[str, Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Args:

prefix (str): prefix to prepend to all buffer names.

recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor): Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
...     if name in ['running_var']:
...         print(buf.size())
named_children() → Iterator[Tuple[str, Module]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

(str, Module): Tuple containing a name and child module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
...     if name in ['conv4', 'conv5']:
...         print(module)
named_modules(memo: Set[Module] | None = None, prefix: str = '', remove_duplicate: bool = True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Args:

memo: a memo to store the set of modules already added to the result.

prefix: a prefix that will be added to the name of the module.

remove_duplicate: whether or not to remove the duplicated module instances in the result.

Yields:

(str, Module): Tuple of name and module

Note:

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)
0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) → Iterator[Tuple[str, Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Args:

prefix (str): prefix to prepend to all parameter names.

recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter): Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
...     if name in ['bias']:
...         print(param.size())
parameters(recurse: bool = True) Iterator[Parameter]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Args:
recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter: module parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
...     print(type(param), param.size())
<class 'torch.nn.parameter.Parameter'> torch.Size([20])
<class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])
register_backward_hook(hook: Callable[[Module, Tuple[Tensor, ...] | Tensor, Tuple[Tensor, ...] | Tensor], None | Tuple[Tensor, ...] | Tensor]) RemovableHandle

Register a backward hook on the module.

This function is deprecated in favor of register_full_backward_hook() and the behavior of this function will change in future versions.

Returns:
torch.utils.hooks.RemovableHandle:

a handle that can be used to remove the added hook by calling handle.remove()

register_buffer(name: str, tensor: Tensor | None, persistent: bool = True) None

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.

Buffers can be accessed as attributes using given names.

Args:
name (str): name of the buffer. The buffer can be accessed from this module using the given name.

tensor (Tensor or None): buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict.

persistent (bool): whether the buffer is part of this module’s state_dict.

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
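A minimal sketch of the difference between persistent and non-persistent buffers, assuming PyTorch 2.x. The RunningStats class and the scratch buffer are hypothetical names used only for illustration:

```python
import torch
from torch import nn


class RunningStats(nn.Module):
    """Hypothetical toy module with one persistent and one non-persistent buffer."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Persistent (the default): saved in state_dict alongside parameters.
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Non-persistent: still moved by .to()/.cuda(), but not saved.
        self.register_buffer("scratch", torch.ones(num_features), persistent=False)
        self.weight = nn.Parameter(torch.randn(num_features))


stats = RunningStats(3)
print(sorted(stats.state_dict().keys()))  # ['running_mean', 'weight']
print([name for name, _ in stats.named_buffers()])  # ['running_mean', 'scratch']
print(stats.running_mean)  # buffers are reachable as attributes
```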
register_forward_hook(hook: Callable[[T, Tuple[Any, ...], Any], Any | None] | Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Any | None], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle

Register a forward hook on the module.

The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments are passed only to forward, not to the hooks. The hook can modify the output. It can modify the input in place, but this has no effect on forward, since the hook is called after forward() has run. The hook should have the following signature:

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and be expected to return the output possibly modified. The hook should have the following signature:

hook(module, args, kwargs, output) -> None or modified output
Args:

hook (Callable): The user-defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing forward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.modules.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False

with_kwargs (bool): If True, the hook will be passed the kwargs given to the forward function. Default: False

always_call (bool): If True, the hook will be run regardless of whether an exception is raised while calling the Module. Default: False

Returns:
torch.utils.hooks.RemovableHandle:

a handle that can be used to remove the added hook by calling handle.remove()
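The sketch below assumes PyTorch 2.x. It registers a forward hook that doubles a Linear layer's output, then detaches the hook through the returned handle. The name double_output is a hypothetical hook name:

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)


def double_output(module, args, output):
    # Returning a non-None value replaces the module's output.
    return output * 2


handle = linear.register_forward_hook(double_output)
x = torch.ones(1, 2)
with torch.no_grad():
    hooked = linear(x)
    handle.remove()  # detach the hook via its RemovableHandle
    plain = linear(x)

print(torch.allclose(hooked, plain * 2))  # True
```

Keeping the returned handle is the simplest way to remove a hook once it is no longer needed.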

register_forward_pre_hook(hook: Callable[[T, Tuple[Any, ...]], Any | None] | Callable[[T, Tuple[Any, ...], Dict[str, Any]], Tuple[Any, Dict[str, Any]] | None], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before forward() is invoked.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments are passed only to forward, not to the hooks. The hook can modify the input. The hook can return either a tuple or a single modified value; a single value is wrapped into a tuple (unless it is already a tuple). The hook should have the following signature:

hook(module, args) -> None or modified input

If with_kwargs is True, the forward pre-hook will be passed the kwargs given to the forward function. If the hook modifies the input, it should return both the args and the kwargs. The hook should have the following signature:

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
Args:

hook (Callable): The user-defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.modules.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False

with_kwargs (bool): If True, the hook will be passed the kwargs given to the forward function. Default: False

Returns:
torch.utils.hooks.RemovableHandle:

a handle that can be used to remove the added hook by calling handle.remove()
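The following minimal sketch, assuming PyTorch 2.x, shows a forward pre-hook that replaces the positional input. The name zero_input is a hypothetical hook. Because the hook returns a single tensor rather than a tuple, it also exercises the wrapping behavior described above:

```python
import torch
from torch import nn

linear = nn.Linear(2, 1)


def zero_input(module, args):
    # args is the tuple of positional inputs. Returning a single tensor
    # is allowed: PyTorch wraps it into a one-element tuple.
    (x,) = args
    return torch.zeros_like(x)


handle = linear.register_forward_pre_hook(zero_input)
with torch.no_grad():
    out = linear(torch.randn(4, 2))
handle.remove()

# With an all-zero input, a Linear layer returns its bias for every row.
print(torch.allclose(out, linear.bias.detach().expand(4, 1)))  # True
```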

register_full_backward_hook(hook: Callable[[Module, Tuple[Tensor, ...] | Tensor, Tuple[Tensor, ...] | Tensor], None | Tuple[Tensor, ...] | Tensor], prepend: bool = False) RemovableHandle

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

Args:

hook (Callable): The user-defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing backward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.modules.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.

Returns:
torch.utils.hooks.RemovableHandle:

a handle that can be used to remove the added hook by calling handle.remove()
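This sketch assumes PyTorch 2.x. It registers a full backward hook that only records the gradients it receives. The names save_grads and captured are hypothetical:

```python
import torch
from torch import nn

linear = nn.Linear(3, 1)
captured = {}


def save_grads(module, grad_input, grad_output):
    # Both arguments are tuples; do not modify them in place.
    captured["grad_output"] = grad_output[0].detach().clone()
    captured["grad_input"] = grad_input[0].detach().clone()
    # Returning None keeps grad_input unchanged.


handle = linear.register_full_backward_hook(save_grads)
x = torch.randn(2, 3, requires_grad=True)
linear(x).sum().backward()
handle.remove()

print(captured["grad_output"])  # d(sum)/d(output): a (2, 1) tensor of ones
```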

register_full_backward_pre_hook(hook: Callable[[Module, Tuple[Tensor, ...] | Tensor], None | Tuple[Tensor, ...] | Tensor], prepend: bool = False) RemovableHandle

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature:

hook(module, grad_output) -> tuple[Tensor] or None

The grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of grad_output in subsequent computations. Entries in grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

Args:

hook (Callable): The user-defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.modules.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method.

Returns:
torch.utils.hooks.RemovableHandle:

a handle that can be used to remove the added hook by calling handle.remove()

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module’s load_state_dict is called.

It should have the following signature:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:
torch.utils.hooks.RemovableHandle:

a handle that can be used to remove the added hook by calling handle.remove()
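The following minimal sketch, assuming PyTorch 2.x, shows the in-place modification described above. The hypothetical ignore_missing_bias hook clears the missing bias key, so a strict load that omits the bias succeeds:

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)


def ignore_missing_bias(module, incompatible_keys):
    # Edit the lists in place; a load_state_dict post hook must return None.
    if "bias" in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove("bias")


handle = linear.register_load_state_dict_post_hook(ignore_missing_bias)

# Without the hook, this strict load would raise because "bias" is missing.
result = linear.load_state_dict({"weight": torch.eye(2)}, strict=True)
print(result.missing_keys)  # []
handle.remove()
```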

register_module(name: str, module: Module | None) None

Alias for add_module().

register_parameter(name: str, param: Parameter | None) None

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Args:
name (str): name of the parameter. The parameter can be accessed from this module using the given name.

param (Parameter or None): parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored. If None, the parameter is not included in the module’s state_dict.
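This sketch assumes PyTorch 2.x. It shows a registered parameter alongside a None placeholder. The Scale module is hypothetical:

```python
import torch
from torch import nn


class Scale(nn.Module):
    """Hypothetical module that multiplies its input by a learnable scalar."""

    def __init__(self) -> None:
        super().__init__()
        self.register_parameter("alpha", nn.Parameter(torch.tensor(2.0)))
        # A None placeholder: reachable as an attribute, but skipped by
        # parameters() and state_dict().
        self.register_parameter("beta", None)

    def forward(self, x):
        return self.alpha * x


m = Scale()
print([name for name, _ in m.named_parameters()])  # ['alpha']
print(list(m.state_dict().keys()))  # ['alpha']
print(m.beta)  # None
```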

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

requires_grad_(requires_grad: bool = True) T

Change whether autograd should record operations on parameters in this module.

This method sets the parameters’ requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See the "Locally disabling gradient computation" section of the PyTorch autograd notes for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Args:
requires_grad (bool): whether autograd should record operations on parameters in this module. Default: True.

Returns:

Module: self
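The following minimal sketch, assuming PyTorch 2.x, freezes part of a model for fine-tuning. The layer sizes and the SGD learning rate are arbitrary:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# Freeze the first layer so that only the head is trained.
model[0].requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)
```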

save(checkpoint_path: StrPathT) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

set_extra_state(state: Any) None

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Args:

state (dict): Extra state from the state_dict

share_memory() T

See torch.Tensor.share_memory_().

state_dict(*args, destination=None, prefix='', keep_vars=False)

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Args:
destination (dict, optional): If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

keep_vars (bool, optional): by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:
dict:

a dictionary containing a whole state of the module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
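The shallow-copy note and the keep_vars default can be checked directly. This sketch assumes PyTorch 2.x:

```python
import torch
from torch import nn

linear = nn.Linear(2, 2)
sd = linear.state_dict()
print(list(sd.keys()))  # ['weight', 'bias']

# Shallow copy: the entries share storage with the live parameters,
# so in-place updates to the module are visible through `sd`.
with torch.no_grad():
    linear.weight.fill_(0.5)
print(bool((sd["weight"] == 0.5).all()))  # True

# keep_vars=False (the default) detaches the entries from autograd.
print(sd["weight"].requires_grad)  # False
print(linear.state_dict(keep_vars=True)["weight"].requires_grad)  # True

# Cloning gives an independent snapshot.
snapshot = {k: v.clone() for k, v in linear.state_dict().items()}
with torch.no_grad():
    linear.weight.fill_(1.0)
print(snapshot["weight"][0, 0].item())  # 0.5
```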
symbols: list[Symbol]
to(*args, **kwargs)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

Args:
device (torch.device): the desired device of the parameters and buffers in this module

dtype (torch.dtype): the desired floating point or complex dtype of the parameters and buffers in this module

tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

memory_format (torch.memory_format): the desired memory format for 4D parameters and buffers in this module (keyword only argument)

Returns:

Module: self

Examples:

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device: int | str | device | None, recurse: bool = True) T

Move the parameters and buffers to the specified device without copying storage.

Args:
device (torch.device): The desired device of the parameters and buffers in this module.

recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device.

Returns:

Module: self
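One common use, sketched here under the assumption of PyTorch 2.x, is to build a large module on the meta device (no storage allocated) and materialise it later. Because to_empty() leaves the new storage uninitialised, the parameters are re-initialised explicitly with nn.Linear.reset_parameters():

```python
import torch
from torch import nn

# Build the module on the "meta" device: shapes only, no real storage.
linear = nn.Linear(1024, 1024, device="meta")
print(linear.weight.device)  # meta

# Allocate real but uninitialised storage on the CPU, then initialise it.
linear.to_empty(device="cpu")
linear.reset_parameters()
print(linear.weight.device)  # cpu
```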

train(mode: bool = True) T

Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Args:
mode (bool): whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

Module: self
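This sketch, assuming PyTorch 2.x, shows how the mode propagates to submodules and changes the behavior of Dropout. The model layout is arbitrary:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(8, 4)

model.train()  # training mode: Dropout randomly zeroes activations
print(model.training, model[1].training)  # True True

model.eval()  # same as model.train(False); Dropout becomes the identity
print(model.training, model[1].training)  # False False

with torch.no_grad():
    out1, out2 = model(x), model(x)
print(torch.equal(out1, out2))  # True: no randomness in evaluation mode
```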

training: bool
type(dst_type: dtype | str) T

Casts all parameters and buffers to dst_type.

Note

This method modifies the module in-place.

Args:

dst_type (type or string): the desired type

Returns:

Module: self

weights: torch.nn.ParameterList
xpu(device: int | device | None = None) T

Move all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on XPU while being optimized.

Note

This method modifies the module in-place.

Arguments:
device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

zero_grad(set_to_none: bool = True) None

Reset gradients of all model parameters.

See similar function under torch.optim.Optimizer for more context.

Args:
set_to_none (bool): instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details.
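This sketch shows the two set_to_none settings. It assumes PyTorch 2.x, where set_to_none defaults to True:

```python
import torch
from torch import nn

linear = nn.Linear(3, 1)
x = torch.ones(2, 3)

linear(x).sum().backward()
print(linear.weight.grad)  # tensor([[2., 2., 2.]])

linear.zero_grad()  # default set_to_none=True: gradients become None
grad_default = linear.weight.grad
print(grad_default)  # None

linear(x).sum().backward()
linear.zero_grad(set_to_none=False)  # keep the tensors, fill them with zeros
print(linear.weight.grad)  # tensor([[0., 0., 0.]])
```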

class lambeq.training.PytorchModel[source]

Bases: Model, Module

A lambeq model for the classical pipeline using PyTorch.

T_destination = ~T_destination
__call__(*args: Any, **kwds: Any) Any

Call self as a function.

__init__() None[source]

Initialise a PytorchModel.

add_module(name: str, module: Module | None) None

Add a child module to the current module.

The module can be accessed as an attribute using the given name.

Args:
name (str): name of the child module. The child module can be accessed from this module using the given name

module (Module): child module to be added to the module.
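The following minimal sketch, assuming PyTorch 2.x, registers children under names computed at runtime. The names layer0 and layer1 are arbitrary:

```python
from torch import nn

model = nn.Module()
# Like `model.layer0 = nn.Linear(2, 2)`, but the name is built at runtime.
for i in range(2):
    model.add_module(f"layer{i}", nn.Linear(2, 2))

print([name for name, _ in model.named_children()])  # ['layer0', 'layer1']
print(model.layer1)  # Linear(in_features=2, out_features=2, bias=True)
```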

apply(fn: Callable[[Module], None]) T

Apply fn recursively to every submodule (as returned by .children()) as well as self.

Typical use includes initializing the parameters of a model (see also the torch.nn.init documentation).

Args:

fn (Module -> None): function to be applied to each submodule

Returns:

Module: self

Example:

>>> @torch.no_grad()
... def init_weights(m):
...     print(m)
...     if type(m) == nn.Linear:
...         m.weight.fill_(1.0)
...         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> _ = net.apply(init_weights)  # apply() returns self; assign to suppress the echo
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
bfloat16() T

Casts all floating point parameters and buffers to bfloat16 datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

buffers(recurse: bool = True) Iterator[Tensor]

Return an iterator over module buffers.

Args:
recurse (bool): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.

Yields:

torch.Tensor: module buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
...     print(type(buf), buf.size())
<class 'torch.Tensor'> torch.Size([20])
<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
call_super_init: bool = False
children() Iterator[Module]

Return an iterator over immediate children modules.

Yields:

Module: a child module

compile(*args, **kwargs)

Compile this Module’s forward using torch.compile().

This Module’s __call__ method is compiled and all arguments are passed as-is to torch.compile().

See torch.compile() for details on the arguments for this function.

cpu() T

Move all model parameters and buffers to the CPU.

Note

This method modifies the module in-place.

Returns:

Module: self

cuda(device: int | device | None = None) T

Move all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on GPU while being optimized.

Note

This method modifies the module in-place.

Args:
device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

double() T

Casts all floating point parameters and buffers to double datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

dump_patches: bool = False
eval() T

Set the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

This is equivalent to self.train(False).

See the "Locally disabling gradient computation" section of the PyTorch autograd notes for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

Module: self

extra_repr() str

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

float() T

Casts all floating point parameters and buffers to float datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

forward(x: list[Diagram]) Tensor[source]

Perform default forward pass by contracting tensors.

In case of a different datapoint (e.g. list of tuples) or additional computational steps, please override this method.

Parameters:
x : list of Diagram

The Diagrams to be evaluated.

Returns:
torch.Tensor

Tensor containing model’s prediction.

classmethod from_checkpoint(checkpoint_path: StrPathT, **kwargs: Any) Model

Load the weights and symbols from a training checkpoint.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters:
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit : bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

get_buffer(target: str) Tensor

Return the buffer given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Args:
target: The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

torch.Tensor: The buffer referenced by target

Raises:
AttributeError: If the target string references an invalid path or resolves to something that is not a buffer

get_diagram_output(diagrams: list[Diagram]) Tensor[source]

Contract diagrams using tensornetwork.

Parameters:
diagrams : list of Diagram

The Diagrams to be evaluated.

Returns:
torch.Tensor

Resulting tensor.

Raises:
ValueError

If model.weights or model.symbols are not initialised.

get_extra_state() Any

Return any extra state to include in the module’s state_dict.

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.

Returns:

object: Any extra state to store in the module’s state_dict
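This sketch assumes PyTorch 2.x. It defines a module that round-trips a plain-Python config through its state_dict. The Tagged class and its config attribute are hypothetical. PyTorch stores the extra state under the key _extra_state, prefixed by the module's path:

```python
from torch import nn


class Tagged(nn.Module):
    """Hypothetical module carrying a picklable, non-tensor config."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.config = {"temperature": 1.0}

    def get_extra_state(self):
        # Called by state_dict(); the value must be picklable for torch.save.
        return dict(self.config)

    def set_extra_state(self, state):
        # Called by load_state_dict() with the saved value.
        self.config = dict(state)


src = Tagged()
src.config["temperature"] = 0.3
sd = src.state_dict()
print(sd["_extra_state"])  # {'temperature': 0.3}

dst = Tagged()
dst.load_state_dict(sd)
print(dst.config)  # {'temperature': 0.3}
```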

get_parameter(target: str) Parameter

Return the parameter given by target if it exists, otherwise throw an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify target.

Args:
target: The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

torch.nn.Parameter: The Parameter referenced by target

Raises:
AttributeError: If the target string references an invalid path or resolves to something that is not an nn.Parameter

get_submodule(target: str) Module

Return the submodule given by target if it exists, otherwise throw an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Args:
target: The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)

Returns:

torch.nn.Module: The submodule referenced by target

Raises:
AttributeError: If the target string references an invalid path or resolves to something that is not an nn.Module
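The nested structure from the diagram above can be rebuilt and queried as follows. This sketch assumes PyTorch 2.x, with plain nn.Module() containers standing in for the unnamed Module blocks:

```python
from torch import nn

# Rebuild A -> net_b -> {net_c -> conv, linear} from the diagram above.
a = nn.Module()
a.net_b = nn.Module()
a.net_b.net_c = nn.Module()
a.net_b.net_c.conv = nn.Conv2d(16, 33, kernel_size=3, stride=2)
a.net_b.linear = nn.Linear(100, 200)

print(a.get_submodule("net_b.linear"))  # Linear(in_features=100, out_features=200, bias=True)
print(a.get_parameter("net_b.net_c.conv.bias").shape)  # torch.Size([33])

# An invalid path raises AttributeError.
try:
    a.get_submodule("net_b.missing")
    lookup_failed = False
except AttributeError:
    lookup_failed = True
print(lookup_failed)  # True
```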

half() T

Casts all floating point parameters and buffers to half datatype.

Note

This method modifies the module in-place.

Returns:

Module: self

initialise_weights() None[source]

Initialise the weights of the model.

Raises:
ValueError

If model.symbols are not initialised.

ipu(device: int | device | None = None) T

Move all model parameters and buffers to the IPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on IPU while being optimized.

Note

This method modifies the module in-place.

Arguments:
device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

load(checkpoint_path: StrPathT) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

Args:
state_dict (dict): a dict containing parameters and persistent buffers.

strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

assign (bool, optional): When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False

Returns:
NamedTuple with missing_keys and unexpected_keys fields:
  • missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
  • unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.

Note:

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
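This sketch assumes PyTorch 2.x. It performs a non-strict load between two Linear layers that differ only in whether they have a bias. The returned keys report the mismatch instead of raising:

```python
import torch
from torch import nn

source = nn.Linear(2, 2)
target = nn.Linear(2, 2, bias=False)  # registers bias as None

# strict=True (the default) would raise a RuntimeError here, because
# `target` does not expect the "bias" entry of source's state_dict.
result = target.load_state_dict(source.state_dict(), strict=False)
print(result.missing_keys)  # []
print(result.unexpected_keys)  # ['bias']
print(torch.equal(target.weight, source.weight))  # True
```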

modules() Iterator[Module]

Return an iterator over all modules in the network.

Yields:

Module: a module in the network

Note:

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

Args:

prefix (str): prefix to prepend to all buffer names.

recurse (bool, optional): if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.

remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.

Yields:

(str, torch.Tensor): Tuple containing the name and buffer

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
...     if name in ['running_var']:
...         print(buf.size())
named_children() Iterator[Tuple[str, Module]]

Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

Yields:

(str, Module): Tuple containing a name and child module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
...     if name in ['conv4', 'conv5']:
...         print(module)
named_modules(memo: Set[Module] | None = None, prefix: str = '', remove_duplicate: bool = True)

Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

Args:

memo: a memo to store the set of modules already added to the result

prefix: a prefix that will be added to the name of the module

remove_duplicate: whether to remove the duplicated module instances in the result or not

Yields:

(str, Module): Tuple of name and module

Note:

Duplicate modules are returned only once. In the following example, l will be returned only once.

Example:

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
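To contrast named_children() with named_modules(), the following sketch builds a small nested network (the layer choices are arbitrary and only illustrative) and collects the names each iterator yields:

```python
from torch import nn

# A nested container: the inner Sequential is a child, its layers are grandchildren.
net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sequential(nn.ReLU(), nn.Linear(2, 1)),
)

# named_children() stops at the first level.
child_names = [name for name, _ in net.named_children()]

# named_modules() walks the whole tree depth-first, starting with the root ('').
module_names = [name for name, _ in net.named_modules()]

print(child_names)   # ['0', '1']
print(module_names)  # ['', '0', '1', '1.0', '1.1']
```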
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Args:

prefix (str): prefix to prepend to all parameter names.

recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

remove_duplicate (bool, optional): whether to remove the duplicated parameters in the result. Defaults to True.

Yields:

(str, Parameter): Tuple containing the name and parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
...     if name in ['bias']:
...         print(param.size())
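A sketch of how remove_duplicate affects named_parameters() when one layer is shared, reusing the Linear-twice setup from the named_modules() example. It assumes a PyTorch version whose named_parameters() accepts remove_duplicate, as in the signature above.

```python
from torch import nn

# The same Linear layer appears twice, as in the named_modules() example.
l = nn.Linear(2, 2)
net = nn.Sequential(l, l)

# Default: shared parameters are reported once, under their first name.
dedup = [name for name, _ in net.named_parameters()]

# remove_duplicate=False reports every path to the shared tensors.
all_paths = [name for name, _ in net.named_parameters(remove_duplicate=False)]

print(dedup)      # ['0.weight', '0.bias']
print(all_paths)  # ['0.weight', '0.bias', '1.weight', '1.bias']
```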
parameters(recurse: bool = True) Iterator[Parameter]

Return an iterator over module parameters.

This is typically passed to an optimizer.

Args:
recurse (bool): if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.

Yields:

Parameter: module parameter

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
...     print(type(param), param.size())
<class 'torch.nn.parameter.Parameter'> torch.Size([20])
<class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])
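Since parameters() is typically passed to an optimizer, here is a minimal sketch (toy architecture and learning rate chosen arbitrarily) that also uses it to count parameter values:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# parameters() is typically handed straight to an optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Each call returns a fresh iterator, so it can also be used for counting.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # (3*2 + 2) + (2*1 + 1) = 11
```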
register_backward_hook(hook: Callable[[Module, Tuple[Tensor, ...] | Tensor, Tuple[Tensor, ...] | Tensor], None | Tuple[Tensor, ...] | Tensor]) RemovableHandle

Register a backward hook on the module.

This function is deprecated in favor of register_full_backward_hook() and the behavior of this function will change in future versions.

Returns:
torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

register_buffer(name: str, tensor: Tensor | None, persistent: bool = True) None

Add a buffer to the module.

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.

Buffers can be accessed as attributes using given names.

Args:
name (str): name of the buffer. The buffer can be accessed from this module using the given name.

tensor (Tensor or None): buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict.

persistent (bool): whether the buffer is part of this module’s state_dict.

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
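A sketch of persistent versus non-persistent buffers. RunningMean is a hypothetical module written only for this illustration; the point is which names appear in state_dict() and named_buffers().

```python
import torch
from torch import nn


class RunningMean(nn.Module):
    """Hypothetical module that tracks a running mean of its inputs."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Persistent buffer: saved in state_dict, but not a parameter.
        self.register_buffer('running_mean', torch.zeros(num_features))
        # Non-persistent buffer: moves with .to(), but is not saved.
        self.register_buffer('scratch', torch.zeros(num_features), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            # Buffers are accessible as attributes under their registered names.
            self.running_mean.mul_(0.9).add_(0.1 * x.mean(dim=0))
        return x - self.running_mean


m = RunningMean(3)
m(torch.ones(2, 3))
print(list(m.state_dict().keys()))        # ['running_mean']
print([n for n, _ in m.named_buffers()])  # ['running_mean', 'scratch']
print(len(list(m.parameters())))          # 0
```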
register_forward_hook(hook: Callable[[T, Tuple[Any, ...], Any], Any | None] | Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Any | None], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle

Register a forward hook on the module.

The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input in place, but this has no effect on forward() because the hook runs after forward() has been called. The hook should have the following signature:

hook(module, args, output) -> None or modified output

If with_kwargs is True, the forward hook will be passed the kwargs given to the forward function and is expected to return the output, possibly modified. The hook should have the following signature:

hook(module, args, kwargs, output) -> None or modified output

Args:

hook (Callable): The user defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing forward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.modules.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False

with_kwargs (bool): If True, the hook will be passed the kwargs given to the forward function. Default: False

always_call (bool): If True the hook will be run regardless of whether an exception is raised while calling the Module. Default: False

Returns:
torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()
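A minimal forward-hook sketch, assuming an arbitrary three-layer network. The hook records the activation after the ReLU, and the returned RemovableHandle detaches it again.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
captured = {}


def save_output(module, args, output):
    # Called after forward(); returning None leaves the output unchanged.
    captured['hidden'] = output.detach()


handle = model[1].register_forward_hook(save_output)
out = model(torch.randn(5, 4))
hidden_shape = captured['hidden'].shape
print(hidden_shape)  # torch.Size([5, 3])

# Remove the hook once it is no longer needed.
handle.remove()
captured.clear()
model(torch.randn(5, 4))
print(captured)  # {} because the hook no longer fires
```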

register_forward_pre_hook(hook: Callable[[T, Tuple[Any, ...]], Any | None] | Callable[[T, Tuple[Any, ...], Dict[str, Any]], Tuple[Any, Dict[str, Any]] | None], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle

Register a forward pre-hook on the module.

The hook will be called every time before forward() is invoked.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the input: it can return either a tuple or a single modified value, and a single value is wrapped into a tuple (unless it is already a tuple). The hook should have the following signature:

hook(module, args) -> None or modified input

If with_kwargs is True, the forward pre-hook will be passed the kwargs given to the forward function, and if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature:

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

Args:

hook (Callable): The user defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.modules.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False

with_kwargs (bool): If True, the hook will be passed the kwargs given to the forward function. Default: False

Returns:
torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()
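A forward pre-hook sketch that rewrites the positional input before forward() runs. The doubling hook and the identity weight are illustrative choices that make the effect easy to check.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.eye(2))  # identity weight, so output == input


def double_input(module, args):
    # args is the tuple of positional inputs; return a modified tuple.
    (x,) = args
    return (x * 2,)


handle = layer.register_forward_pre_hook(double_input)
y = layer(torch.tensor([[1.0, 3.0]]))
print(y.tolist())   # [[2.0, 6.0]]: the doubled input passed through the identity

handle.remove()
y2 = layer(torch.tensor([[1.0, 3.0]]))
print(y2.tolist())  # [[1.0, 3.0]]: without the hook the input is unchanged
```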

register_full_backward_hook(hook: Callable[[Module, Tuple[Tensor, ...] | Tensor, Tuple[Tensor, ...] | Tensor], None | Tuple[Tensor, ...] | Tensor], prepend: bool = False) RemovableHandle

Register a backward hook on the module.

The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs or outputs inplace is not allowed when using backward hooks and will raise an error.

Args:

hook (Callable): The user-defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing backward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.modules.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.

Returns:
torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()
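A sketch of register_full_backward_hook() used to inspect grad_output. With loss = output.sum(), the gradient with respect to every output element is 1, which makes the captured value easy to verify.

```python
import torch
from torch import nn

layer = nn.Linear(3, 1)
grads = {}


def log_grads(module, grad_input, grad_output):
    # grad_output[0] is dLoss/d(output); the arguments must not be modified in place.
    grads['out'] = grad_output[0].detach().clone()
    return None  # keep grad_input unchanged


handle = layer.register_full_backward_hook(log_grads)
x = torch.randn(4, 3, requires_grad=True)
loss = layer(x).sum()
loss.backward()
handle.remove()

print(grads['out'].tolist())  # [[1.0], [1.0], [1.0], [1.0]]
```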

register_full_backward_pre_hook(hook: Callable[[Module, Tuple[Tensor, ...] | Tensor], None | Tuple[Tensor, ...] | Tensor], prepend: bool = False) RemovableHandle

Register a backward pre-hook on the module.

The hook will be called every time the gradients for the module are computed. The hook should have the following signature:

hook(module, grad_output) -> tuple[Tensor] or None

The grad_output is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of grad_output in subsequent computations. Entries in grad_output will be None for all non-Tensor arguments.

For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module’s forward function.

Warning

Modifying inputs inplace is not allowed when using backward hooks and will raise an error.

Args:

hook (Callable): The user-defined hook to be registered.

prepend (bool): If True, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.modules.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method.

Returns:
torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()

register_load_state_dict_post_hook(hook)

Register a post hook to be run after module’s load_state_dict is called.

The hook should have the following signature:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

The given incompatible_keys can be modified inplace if needed.

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

Returns:
torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove()
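A sketch of a load_state_dict post hook that records the incompatible keys. It uses strict=False so that a deliberately incomplete state dict (missing bias, plus an extra key named extra, both illustrative) loads without raising.

```python
import torch
from torch import nn

model = nn.Linear(2, 2)
seen = []


def report_keys(module, incompatible_keys):
    # incompatible_keys is a NamedTuple with missing_keys and unexpected_keys.
    seen.append((list(incompatible_keys.missing_keys),
                 list(incompatible_keys.unexpected_keys)))


handle = model.register_load_state_dict_post_hook(report_keys)

# A state dict that lacks 'bias' and carries an unexpected entry.
partial = {'weight': torch.zeros(2, 2), 'extra': torch.ones(1)}
model.load_state_dict(partial, strict=False)
handle.remove()

print(seen)  # [(['bias'], ['extra'])]
```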

register_module(name: str, module: Module | None) None

Alias for add_module().

register_parameter(name: str, param: Parameter | None) None

Add a parameter to the module.

The parameter can be accessed as an attribute using given name.

Args:
name (str): name of the parameter. The parameter can be accessed from this module using the given name.

param (Parameter or None): parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored. If None, the parameter is not included in the module’s state_dict.

register_state_dict_pre_hook(hook)

Register a pre-hook for the state_dict() method.

These hooks will be called with arguments: self, prefix, and keep_vars before calling state_dict on self. The registered hooks can be used to perform pre-processing before the state_dict call is made.

requires_grad_(requires_grad: bool = True) T

Change whether autograd should record operations on parameters in this module.

This method sets the parameters’ requires_grad attributes in-place.

This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).

See the PyTorch autograd notes on locally disabling gradient computation for a comparison between .requires_grad_() and several similar mechanisms that may be confused with it.

Args:
requires_grad (bool): whether autograd should record operations on parameters in this module. Default: True.

Returns:

Module: self
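Freezing part of a model with requires_grad_(False), sketched on an arbitrary two-layer network where the first layer plays the role of a pretrained feature extractor:

```python
from torch import nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Freeze the first layer; the last layer stays trainable.
model[0].requires_grad_(False)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(trainable, frozen)  # 10 36
```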

save(checkpoint_path: StrPathT) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

set_extra_state(state: Any) None

Set extra state contained in the loaded state_dict.

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

Args:

state (dict): Extra state from the state_dict

share_memory() T

See torch.Tensor.share_memory_().

state_dict(*args, destination=None, prefix='', keep_vars=False)

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Args:
destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

keep_vars (bool, optional): by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:
dict: a dictionary containing a whole state of the module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
odict_keys(['weight', 'bias'])
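The shallow-copy note and the keep_vars flag can be checked directly. This sketch, assuming a plain nn.Linear, shows that default entries are detached but share storage with the module, while keep_vars=True returns the Parameter objects themselves.

```python
import torch
from torch import nn

layer = nn.Linear(2, 2)
sd = layer.state_dict()

# By default, entries are detached tensors that share storage with the module.
print(list(sd.keys()))             # ['weight', 'bias']
print(sd['weight'].requires_grad)  # False

with torch.no_grad():
    sd['weight'].zero_()           # an in-place edit is visible through the module
print(bool((layer.weight == 0).all()))  # True

# keep_vars=True returns the Parameter objects themselves.
sd_vars = layer.state_dict(keep_vars=True)
print(sd_vars['weight'] is layer.weight)  # True
```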
symbols: list[Symbol]
to(*args, **kwargs)

Move and/or cast the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

Args:
device (torch.device): the desired device of the parameters and buffers in this module

dtype (torch.dtype): the desired floating point or complex dtype of the parameters and buffers in this module

tensor (torch.Tensor): Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

memory_format (torch.memory_format): the desired memory format for 4D parameters and buffers in this module (keyword only argument)

Returns:

Module: self

Examples:

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device: int | str | device | None, recurse: bool = True) T

Move the parameters and buffers to the specified device without copying storage.

Args:
device (torch.device): The desired device of the parameters and buffers in this module.

recurse (bool): Whether parameters and buffers of submodules should be recursively moved to the specified device.

Returns:

Module: self

train(mode: bool = True) T

Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Args:
mode (bool): whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

Module: self
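A sketch of the training/evaluation switch on a module that contains Dropout (architecture arbitrary). eval() is equivalent to train(False), and in evaluation mode Dropout passes activations through unchanged, so repeated forward passes agree.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

model.train()             # modules start in training mode; this sets it explicitly
print(model[1].training)  # True: dropout zeroes random activations

model.eval()              # equivalent to model.train(False), applied recursively
print(model[1].training)  # False: dropout is now the identity

x = torch.randn(3, 4)
with torch.no_grad():
    # In evaluation mode, repeated forward passes are deterministic.
    same = torch.equal(model(x), model(x))
print(same)  # True
```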

training: bool
type(dst_type: dtype | str) T

Casts all parameters and buffers to dst_type.

Note

This method modifies the module in-place.

Args:

dst_type (type or string): the desired type

Returns:

Module: self

weights: torch.nn.ParameterList
xpu(device: int | device | None = None) T

Move all model parameters and buffers to the XPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the XPU while being optimized.

Note

This method modifies the module in-place.

Arguments:
device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

zero_grad(set_to_none: bool = True) None

Reset gradients of all model parameters.

See similar function under torch.optim.Optimizer for more context.

Args:
set_to_none (bool): instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details.
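A sketch contrasting the two zero_grad() behaviours on a single nn.Linear:

```python
import torch
from torch import nn

layer = nn.Linear(3, 1)
layer(torch.ones(2, 3)).sum().backward()
print(layer.weight.grad is None)  # False: backward() populated .grad

layer.zero_grad()                 # default set_to_none=True
print(layer.weight.grad is None)  # True

layer(torch.ones(2, 3)).sum().backward()
layer.zero_grad(set_to_none=False)
print(layer.weight.grad)          # tensor([[0., 0., 0.]])
```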

class lambeq.training.PytorchTrainer(model: PytorchModel, loss_function: Callable[..., torch.Tensor], epochs: int, optimizer: type[torch.optim.Optimizer] = <class 'torch.optim.adamw.AdamW'>, learning_rate: float = 0.001, device: int = -1, *, optimizer_args: dict[str, Any] | None = None, evaluate_functions: Mapping[str, EvalFuncT] | None = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: StrPathT | None = None, from_checkpoint: bool = False, verbose: str = 'text', seed: int | None = None)[source]

Bases: Trainer

A PyTorch trainer for the classical pipeline.

__init__(model: PytorchModel, loss_function: Callable[..., torch.Tensor], epochs: int, optimizer: type[torch.optim.Optimizer] = <class 'torch.optim.adamw.AdamW'>, learning_rate: float = 0.001, device: int = -1, *, optimizer_args: dict[str, Any] | None = None, evaluate_functions: Mapping[str, EvalFuncT] | None = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: StrPathT | None = None, from_checkpoint: bool = False, verbose: str = 'text', seed: int | None = None) None[source]

Initialise a Trainer instance using the PyTorch backend.

Parameters:
model : PytorchModel

A lambeq Model using PyTorch for tensor computation.

loss_function : callable

A PyTorch loss function from torch.nn.

epochs : int

Number of training epochs.

optimizer : torch.optim.Optimizer, default: torch.optim.AdamW

A PyTorch optimizer from torch.optim.

learning_rate : float, default: 1e-3

The learning rate provided to the optimizer for training.

device : int, default: -1

CUDA device ID used for tensor operation speed-up. A negative value uses the CPU.

optimizer_args : dict of str to Any, optional

Any extra arguments to pass to the optimizer.

evaluate_functions : mapping of str to callable, optional

Mapping of evaluation metric functions from their names, with structure {"metric": func}. Each function takes the prediction y_hat and the label y as input; the validation step calls func(y_hat, y).

evaluate_on_train : bool, default: True

Evaluate the metrics on the train dataset.

use_tensorboard : bool, default: False

Use Tensorboard for visualisation of the training logs.

log_dir : str or PathLike, optional

Location of model checkpoints (and tensorboard log). Default is runs/CURRENT_DATETIME_HOSTNAME.

from_checkpoint : bool, default: False

Start training from the checkpoint saved in log_dir.

verbose : str, default: 'text'

See VerbosityLevel for options.

seed : int, optional

Random seed.
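The evaluate_functions mapping pairs a metric name with a callable func(y_hat, y). A minimal sketch of such a metric, assuming labels are one-hot rows and predictions are per-class scores; the name accuracy is illustrative and not part of lambeq.

```python
import torch


def accuracy(y_hat: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of rows where the predicted class matches the one-hot label."""
    return (y_hat.argmax(dim=1) == y.argmax(dim=1)).float().mean().item()


# Structure expected by evaluate_functions: {"metric": func}.
eval_metrics = {"acc": accuracy}

y_hat = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
y = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
print(eval_metrics["acc"](y_hat, y))  # 2 of 3 rows correct, about 0.667
```

Such a mapping would then be supplied as PytorchTrainer(..., evaluate_functions=eval_metrics).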

fit(train_dataset: Dataset, val_dataset: Dataset | None = None, log_interval: int = 1, eval_interval: int = 1, eval_mode: str = 'epoch', early_stopping_criterion: str | None = None, early_stopping_interval: int | None = None, minimize_criterion: bool = True, full_timing_report: bool = False) None

Fit the model on the training data and, optionally, evaluate it on the validation data.

Parameters:
train_dataset : Dataset

Dataset used for training.

val_dataset : Dataset, optional

Validation dataset.

log_interval : int, default: 1

Sets the intervals at which the training statistics are printed if verbose = 'text' (otherwise ignored). If None, the statistics are printed at the end of each epoch.

eval_interval : int, default: 1

Sets the interval, in epochs or steps depending on eval_mode, at which the metrics are evaluated on the validation dataset. If None, the validation is performed at the end of each epoch.

eval_mode : EvalMode, default: 'epoch'

Sets the evaluation mode. If 'epoch', the metrics are evaluated after multiples of eval_interval epochs. If 'step', the metrics are evaluated after multiples of eval_interval steps. Ignored if val_dataset is None.

early_stopping_criterion : str, optional

If specified, the value of this metric on val_dataset (if provided) will be used as the stopping criterion instead of the (default) validation loss.

early_stopping_interval : int, optional

If specified, training is stopped if the early stopping criterion (by default, the validation loss) does not improve for early_stopping_interval validation cycles.

minimize_criterion : bool, default: True

Flag indicating whether the early stopping criterion should be minimized or maximized.

full_timing_report : bool, default: False

Flag for including mean timing statistics in the logs.

Raises:
ValueError

If eval_mode is not a valid EvalMode.
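The interplay of early_stopping_interval and minimize_criterion can be sketched as a patience counter over validation cycles. This is a stand-alone illustration of the described behaviour under that assumption, not lambeq's internal implementation; the helper stopping_cycle is hypothetical.

```python
def stopping_cycle(history, interval, minimize=True):
    """Hypothetical helper: index of the validation cycle at which training
    would stop, or None if the criterion keeps improving often enough.

    history: criterion value after each validation cycle.
    interval: number of cycles without improvement that triggers a stop.
    """
    best = None
    cycles_without_improvement = 0
    for i, value in enumerate(history):
        improved = best is None or (value < best if minimize else value > best)
        if improved:
            best = value
            cycles_without_improvement = 0
        else:
            cycles_without_improvement += 1
            if cycles_without_improvement >= interval:
                return i
    return None


val_loss = [0.9, 0.7, 0.72, 0.71, 0.75, 0.6]
print(stopping_cycle(val_loss, interval=3))  # 4: no improvement at cycles 2, 3 and 4

val_acc = [0.5, 0.6, 0.65, 0.7]
print(stopping_cycle(val_acc, interval=2, minimize=False))  # None: always improving
```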

load_training_checkpoint(log_dir: str | PathLike[str]) Checkpoint

Load model from a checkpoint.

Parameters:
log_dir : str or PathLike

The path to the model.lt checkpoint file.

Returns:
Checkpoint

Checkpoint containing the model weights, symbols and the training history.

Raises:
FileNotFoundError

If the file does not exist.

model: PytorchModel
save_checkpoint(save_dict: Mapping[str, Any], log_dir: str | PathLike[str], prefix: str = '') None

Save checkpoint.

Parameters:
save_dict : mapping of str to any

Mapping containing the checkpoint information.

log_dir : str or PathLike

The directory in which to store the model.lt checkpoint file.

prefix : str, default: ''

Prefix for the checkpoint file name.

training_step(batch: tuple[list[Any], Tensor]) tuple[Tensor, float][source]

Perform a training step.

Parameters:
batch : tuple of list and torch.Tensor

Current batch.

Returns:
Tuple of torch.Tensor and float

The model predictions and the calculated loss.

validation_step(batch: tuple[list[Any], Tensor]) tuple[Tensor, float][source]

Perform a validation step.

Parameters:
batch : tuple of list and torch.Tensor

Current batch.

Returns:
Tuple of torch.Tensor and float

The model predictions and the calculated loss.
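For orientation, a generic PyTorch training and validation step with the same (predictions, loss) return shape is sketched below. It assumes the batch's first element can be fed directly to the model, which is a simplification: in lambeq that element is a list of diagrams evaluated by a PytorchModel. The functions train_step and val_step are hypothetical, not lambeq's code.

```python
import torch
from torch import nn


def train_step(model, loss_fn, optimizer, batch):
    """One optimisation step; returns (predictions, loss as float)."""
    x, y = batch
    model.train()
    y_hat = model(x)
    loss = loss_fn(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return y_hat.detach(), loss.item()


def val_step(model, loss_fn, batch):
    """Evaluate without tracking gradients; returns (predictions, loss as float)."""
    x, y = batch
    model.eval()
    with torch.no_grad():
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
    return y_hat, loss.item()


torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
batch = (torch.randn(8, 2), torch.randn(8, 1))

_, before = val_step(model, loss_fn, batch)
for _ in range(50):
    train_step(model, loss_fn, optimizer, batch)
_, after = val_step(model, loss_fn, batch)
print(before, after)  # the loss on this batch should have decreased
```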

class lambeq.training.QuantumModel[source]

Bases: Model

Quantum Model base class.

Attributes:
symbols : list of symbols

A sorted list of all Symbols occurring in the data.

weights : array

A data structure containing the numeric values of the model parameters.

__call__(*args: Any, **kwargs: Any) Any[source]

Call self as a function.

__init__() None[source]

Initialise a QuantumModel.

abstract forward(x: list[Diagram]) Any[source]

Compute the forward pass of the model using get_diagram_output.

classmethod from_checkpoint(checkpoint_path: StrPathT, **kwargs: Any) Model

Load the weights and symbols from a training checkpoint.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields 'backend', 'compilation' and 'shots'.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters:
diagrams : list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters:
backend_config : dict

Dictionary containing the backend configuration for the TketModel. Must include the fields 'backend', 'compilation' and 'shots'.

use_jit : bool, default: False

Whether to use JAX's Just-In-Time compilation in NumpyModel.

abstract get_diagram_output(diagrams: list[Diagram]) jnp.ndarray | np.ndarray[source]

Return the diagram prediction.

Parameters:
diagrams : list of Diagram

The circuits to be evaluated.

initialise_weights() None[source]

Initialise the weights of the model.

Raises:
ValueError

If model.symbols are not initialised.

load(checkpoint_path: StrPathT) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: StrPathT) None

Create a lambeq Checkpoint and save to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters:
checkpoint_path : str or PathLike

Path that points to the checkpoint file.

weights: np.ndarray
class lambeq.training.QuantumTrainer(model: QuantumModel, loss_function: Callable[..., float], epochs: int, optimizer: type[Optimizer], optim_hyperparams: dict[str, float], *, optimizer_args: dict[str, Any] | None = None, evaluate_functions: Mapping[str, EvalFuncT] | None = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: StrPathT | None = None, from_checkpoint: bool = False, verbose: str = 'text', seed: int | None = None)[source]

Bases: Trainer

A Trainer for the quantum pipeline.

__init__(model: QuantumModel, loss_function: Callable[..., float], epochs: int, optimizer: type[Optimizer], optim_hyperparams: dict[str, float], *, optimizer_args: dict[str, Any] | None = None, evaluate_functions: Mapping[str, EvalFuncT] | None = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: StrPathT | None = None, from_checkpoint: bool = False, verbose: str = 'text', seed: int | None = None) None[source]

Initialise a Trainer using a quantum backend.

Parameters:
model : QuantumModel

A lambeq Model.

loss_function : callable

A loss function.

epochs : int

Number of training epochs.

optimizer : Optimizer

An optimizer of type lambeq.training.Optimizer.

optim_hyperparams : dict of str to float

The hyperparameters to be used by the optimizer.

optimizer_args : dict of str to Any, optional

Any extra arguments to pass to the optimizer.

evaluate_functions : mapping of str to callable, optional

Mapping of evaluation metric functions from their names, with structure {"metric": func}. Each function takes the prediction y_hat and the label y as input; the validation step calls func(y_hat, y).

evaluate_on_train : bool, default: True

Evaluate the metrics on the train dataset.

use_tensorboard : bool, default: False

Use Tensorboard for visualisation of the training logs.

log_dir : str or PathLike, optional

Location of model checkpoints (and tensorboard log). Default is runs/CURRENT_DATETIME_HOSTNAME.

from_checkpoint : bool, default: False

Start training from the checkpoint saved in log_dir.

verbose : str, default: 'text'

See VerbosityLevel for options.

seed : int, optional

Random seed.

fit(train_dataset: Dataset, val_dataset: Dataset | None = None, log_interval: int = 1, eval_interval: int = 1, eval_mode: str = 'epoch', early_stopping_criterion: str | None = None, early_stopping_interval: int | None = None, minimize_criterion: bool = True, full_timing_report: bool = False) None[source]

Fit the model on the training data and, optionally, evaluate it on the validation data.

Parameters:
train_dataset : Dataset

Dataset used for training.

val_dataset : Dataset, optional

Validation dataset.

log_interval : int, default: 1

Sets the intervals at which the training statistics are printed if verbose = 'text' (otherwise ignored). If None, the statistics are printed at the end of each epoch.

eval_interval : int, default: 1

Sets the interval, in epochs or steps depending on eval_mode, at which the metrics are evaluated on the validation dataset. If None, the validation is performed at the end of each epoch.

eval_mode : EvalMode, default: 'epoch'

Sets the evaluation mode. If 'epoch', the metrics are evaluated after multiples of eval_interval epochs. If 'step', the metrics are evaluated after multiples of eval_interval steps. Ignored if val_dataset is None.

early_stopping_criterion : str, optional

If specified, the value of this metric on val_dataset (if provided) will be used as the stopping criterion instead of the (default) validation loss.

early_stopping_interval : int, optional

If specified, training is stopped if the early stopping criterion (by default, the validation loss) does not improve for early_stopping_interval validation cycles.

minimize_criterion : bool, default: True

Flag indicating whether the early stopping criterion should be minimized or maximized.

full_timing_report : bool, default: False

Flag for including mean timing statistics in the logs.

Raises:
ValueError

If eval_mode is not a valid EvalMode.

load_training_checkpoint(log_dir: str | PathLike[str]) Checkpoint

Load model from a checkpoint.

Parameters:
log_dir : str or PathLike

The path to the model.lt checkpoint file.

Returns:
Checkpoint

Checkpoint containing the model weights, symbols and the training history.

Raises:
FileNotFoundError

If the file does not exist.

model: QuantumModel
save_checkpoint(save_dict: Mapping[str, Any], log_dir: str | PathLike[str], prefix: str = '') None

Save checkpoint.

Parameters:
save_dict : mapping of str to any

Mapping containing the checkpoint information.

log_dir : str or PathLike

The directory in which to store the model.lt checkpoint file.

prefix : str, default: ''

Prefix for the checkpoint file name.

training_step(batch: tuple[list[Any], ndarray]) tuple[ndarray, float][source]

Perform a training step.

Parameters:
batch: tuple of list and np.ndarray

Current batch.

Returns:
Tuple of np.ndarray and float

The model predictions and the calculated loss.

validation_step(batch: tuple[list[Any], ndarray]) tuple[ndarray, float][source]

Perform a validation step.

Parameters:
batch: tuple of list and np.ndarray

Current batch.

Returns:
tuple of np.ndarray and float

The model predictions and the calculated loss.

class lambeq.training.RotosolveOptimizer(*, model: QuantumModel, loss_fn: Callable[[Any, Any], float], hyperparams: dict[str, float] | None = None, bounds: ArrayLike | None = None)[source]

Bases: Optimizer

An optimizer using the Rotosolve algorithm.

Rotosolve is an optimizer for parametrized quantum circuits. It applies a shift of ±π/2 radians to each parameter, then updates the parameter based on the resulting loss. The loss function is assumed to be a linear combination of Hamiltonian measurements.

This optimizer is designed to work with ansätze that are composed of single-qubit rotations, such as the StronglyEntanglingAnsatz, Sim14Ansatz and Sim15Ansatz.

See Ostaszewski et al. for details.
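The closed-form update behind Rotosolve can be sketched in NumPy. This is a minimal illustration of the rule from Ostaszewski et al., not lambeq's implementation: it assumes the loss is sinusoidal in each parameter when the others are held fixed (of the form A·sin(θ + B) + C), evaluates it at the reference angle and at ±π/2 around it, and jumps straight to the minimiser of the fitted sinusoid. The names rotosolve_sweep and toy_loss are hypothetical.

```python
import numpy as np


def rotosolve_sweep(loss, theta):
    """Update each parameter in turn with the Rotosolve closed-form rule.

    Assumes loss(theta) behaves like A*sin(theta[d] + B) + C in each
    parameter d when all other parameters are held fixed.
    """
    theta = np.array(theta, dtype=float)
    for d in range(theta.size):
        phi = 0.0  # reference angle; any value works for a pure sinusoid
        probe = theta.copy()
        probe[d] = phi
        m_phi = loss(probe)
        probe[d] = phi + np.pi / 2
        m_plus = loss(probe)
        probe[d] = phi - np.pi / 2
        m_minus = loss(probe)
        # Minimiser of the sinusoid fitted through the three evaluations
        best = phi - np.pi / 2 - np.arctan2(2 * m_phi - m_plus - m_minus,
                                            m_plus - m_minus)
        # Wrap the angle into [-pi, pi); the loss is 2*pi-periodic
        theta[d] = (best + np.pi) % (2 * np.pi) - np.pi
    return theta


# Toy loss: a sum of independent sinusoids, so one sweep reaches the minimum
A = np.array([1.0, 0.5, 2.0])
B = np.array([0.3, -1.2, 2.0])
C = 0.1


def toy_loss(theta):
    return float(np.sum(A * np.sin(theta + B)) + C)


theta = rotosolve_sweep(toy_loss, np.zeros(3))
print(toy_loss(theta))  # close to C - A.sum() = -3.4
```

Because each parameter is solved exactly rather than nudged along a gradient, no learning rate is needed, which is consistent with hyperparams being unused by this optimizer.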

__init__(*, model: QuantumModel, loss_fn: Callable[[Any, Any], float], hyperparams: dict[str, float] | None = None, bounds: ArrayLike | None = None) None[source]

Initialise the Rotosolve optimizer.

Parameters:
model: QuantumModel

A lambeq quantum model.

loss_fn: callable

A loss function of the form loss(prediction, labels).

hyperparams: dict of str to float, optional

Unused.

bounds: ArrayLike, optional

Unused.

backward(batch: tuple[Iterable[Any], ndarray]) float[source]

Perform a single backward pass.

Rotosolve does not calculate a global gradient. Instead, the parameters are updated after applying a shift of ±π/2 radians to each parameter. Therefore, there is no global step to take.

Parameters:
batch: tuple of Iterable and numpy.ndarray

Current batch. Contains an Iterable of diagrams at index 0 and the targets at index 1.

Returns:
float

The calculated loss after the backward pass.

load_state_dict(state_dict: Mapping[str, Any]) None[source]

Load state of the optimizer from the state dictionary.

model: QuantumModel
static project(x: ndarray) ndarray[source]
state_dict() dict[str, Any][source]

Return optimizer states as dictionary.

step() None[source]

Perform optimisation step.

zero_grad() None

Reset the gradients to zero.

class lambeq.training.SPSAOptimizer(*, model: QuantumModel, loss_fn: Callable[[Any, Any], float], hyperparams: dict[str, Any] | None, bounds: ArrayLike | None = None)[source]

Bases: Optimizer

An optimizer using SPSA.

SPSA stands for Simultaneous Perturbation Stochastic Approximation. See https://ieeexplore.ieee.org/document/705889 for details.

__init__(*, model: QuantumModel, loss_fn: Callable[[Any, Any], float], hyperparams: dict[str, Any] | None, bounds: ArrayLike | None = None) None[source]

Initialise the SPSA optimizer.

The hyperparameters must contain the following key-value pairs:

hyperparams = {
    'a': ...,  # learning rate parameter (float)
    'c': ...,  # parameter shift scaling factor (float)
    'A': ...,  # stability constant (float)
}

A good value for ‘A’ is approximately 0.01 × the number of training steps.
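A rough sketch of how the three hyperparameters enter an SPSA iteration, following Spall's standard formulation rather than lambeq's source: at step k the gains are a_k = a / (k + A)^α and c_k = c / k^γ, and the gradient is estimated from only two loss evaluations along a random ±1 perturbation of all parameters at once. The exponents α = 0.602 and γ = 0.101 are Spall's recommended values; whether lambeq uses the same ones is an assumption here, as are the helper names spsa_gradient and spsa_minimise.

```python
import numpy as np


def spsa_gradient(loss, x, ck, rng):
    """Two-evaluation SPSA estimate of the gradient of loss at x."""
    # Rademacher perturbation: every component is independently +1 or -1
    delta = rng.choice([-1.0, 1.0], size=x.shape)
    diff = loss(x + ck * delta) - loss(x - ck * delta)
    # Dividing by delta equals multiplying by it when its entries are +-1
    return diff / (2 * ck) * delta


def spsa_minimise(loss, x0, a, c, A, n_steps, seed=0, alpha=0.602, gamma=0.101):
    """Plain SPSA loop; alpha and gamma default to Spall's recommended exponents."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for k in range(1, n_steps + 1):
        ak = a / (k + A) ** alpha  # step size: scaled by 'a', damped early on by 'A'
        ck = c / k ** gamma        # perturbation size: scaled by 'c'
        x = x - ak * spsa_gradient(loss, x, ck, rng)
    return x


# Usage on a toy quadratic; 'A' follows the 0.01 * training-steps rule of thumb
target = np.array([1.0, -2.0, 0.5])


def quad(x):
    return float(np.sum((x - target) ** 2))


x = spsa_minimise(quad, np.zeros(3), a=0.15, c=0.1, A=5.0, n_steps=500)
```

The cost per step is two loss evaluations regardless of the number of parameters, which is why SPSA suits shot-based quantum models where each evaluation is expensive.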

Parameters:
model: QuantumModel

A lambeq quantum model.

loss_fn: callable

A loss function of the form loss(prediction, labels).

hyperparams: dict of str to float

A dictionary containing the model’s hyperparameters.

bounds: ArrayLike, optional

The range of each of the model parameters.

Raises:
ValueError

If the hyperparameters are not set correctly, or if the length of bounds does not match the number of model parameters.

backward(batch: tuple[Iterable[Any], ndarray]) float[source]

Calculate the gradients of the loss function.

The gradients are calculated with respect to the model parameters.

Parameters:
batch: tuple of Iterable and numpy.ndarray

Current batch. Contains an Iterable of diagrams at index 0 and the targets at index 1.

Returns:
float

The calculated loss.

load_state_dict(state_dict: Mapping[str, Any]) None[source]

Load state of the optimizer from the state dictionary.

Parameters:
state_dict: dict

A dictionary containing a snapshot of the optimizer state.

model: QuantumModel
state_dict() dict[str, Any][source]

Return optimizer states as dictionary.

Returns:
dict

A dictionary containing the current state of the optimizer.

step() None[source]

Perform optimisation step.

update_hyper_params() None[source]

Update the hyperparameters of the SPSA algorithm.

zero_grad() None

Reset the gradients to zero.

class lambeq.training.TketModel(backend_config: dict[str, Any])[source]

Bases: QuantumModel

Model based on tket.

This can run either shot-based simulations of a quantum pipeline or experiments run on quantum hardware using tket.

__call__(*args: Any, **kwargs: Any) Any

Call self as a function.

__init__(backend_config: dict[str, Any]) None[source]

Initialise TketModel based on the t|ket> backend.

Other Parameters:
backend_config: dict

Dictionary containing the backend configuration. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

Raises:
KeyError

If backend_config is not provided or has missing fields.
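A typical backend_config might look like the fragment below. It assumes the pytket-qiskit extension is installed; AerBackend, the optimisation level and the shot count are illustrative choices, not requirements of TketModel.

```python
# Assumes the pytket-qiskit extension is installed (pip install pytket-qiskit)
from pytket.extensions.qiskit import AerBackend

backend = AerBackend()

# The three fields that TketModel requires in backend_config
backend_config = {
    'backend': backend,
    'compilation': backend.default_compilation_pass(2),  # optimisation level 2
    'shots': 8192,  # number of shots per circuit
}
```

The dictionary is then passed as a keyword argument, for example TketModel.from_diagrams(circuits, backend_config=backend_config), where circuits is your list of circuit diagrams.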

forward(x: list[Diagram]) ndarray[source]

Perform default forward pass of a lambeq quantum model.

If the datapoints have a different form (e.g. a list of tuples) or additional computational steps are needed, override this method.

Parameters:
x: list of Diagram

The circuits to be evaluated.

Returns:
np.ndarray

Array containing the model’s predictions.

classmethod from_checkpoint(checkpoint_path: StrPathT, **kwargs: Any) Model

Load the weights and symbols from a training checkpoint.

Parameters:
checkpoint_path: str or PathLike

Path that points to the checkpoint file.

Other Parameters:
backend_config: dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

classmethod from_diagrams(diagrams: list[Diagram], **kwargs: Any) Model

Build model from a list of Diagrams.

Parameters:
diagrams: list of Diagram

The tensor or circuit diagrams to be evaluated.

Other Parameters:
backend_config: dict

Dictionary containing the backend configuration for the TketModel. Must include the fields ‘backend’, ‘compilation’ and ‘shots’.

use_jit: bool, default: False

Whether to use JAX’s Just-In-Time compilation in NumpyModel.

get_diagram_output(diagrams: list[Diagram]) ndarray[source]

Return the prediction for each diagram using t|ket>.

Parameters:
diagrams: list of Diagram

The circuits to be evaluated.

Returns:
np.ndarray

Resulting array.

Raises:
ValueError

If model.weights or model.symbols are not initialised.

initialise_weights() None

Initialise the weights of the model.

Raises:
ValueError

If model.symbols are not initialised.

load(checkpoint_path: StrPathT) None

Load model data from a path pointing to a lambeq checkpoint.

Checkpoints that are created by a lambeq Trainer usually have the extension .lt.

Parameters:
checkpoint_path: str or PathLike

Path that points to the checkpoint file.

save(checkpoint_path: StrPathT) None

Create a lambeq Checkpoint and save it to a path.

Example:

>>> from lambeq import PytorchModel
>>> model = PytorchModel()
>>> model.save('my_checkpoint.lt')

Parameters:
checkpoint_path: str or PathLike

Path that points to the checkpoint file.

symbols: list[Symbol | SymPySymbol]
weights: np.ndarray
class lambeq.training.Trainer(model: Model, loss_function: Callable[[...], Any], epochs: int, evaluate_functions: Mapping[str, Callable[[Any, Any], Any]] | None = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: str | PathLike[str] | None = None, from_checkpoint: bool = False, verbose: str = 'text', seed: int | None = None)[source]

Bases: ABC

Base class for a lambeq trainer.

__init__(model: Model, loss_function: Callable[[...], Any], epochs: int, evaluate_functions: Mapping[str, Callable[[Any, Any], Any]] | None = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, log_dir: str | PathLike[str] | None = None, from_checkpoint: bool = False, verbose: str = 'text', seed: int | None = None) None[source]

Initialise a lambeq trainer.

Parameters:
model: Model

A lambeq Model.

loss_function: callable

A loss function to compare the prediction to the true label.

epochs: int

Number of training epochs.

evaluate_functions: mapping of str to callable, optional

Mapping from metric names to evaluation metric functions.

evaluate_on_train: bool, default: True

Whether to evaluate the metrics on the training dataset.

use_tensorboard: bool, default: False

Whether to use TensorBoard to visualise the training logs.

log_dir: str or PathLike, optional

Location of model checkpoints (and TensorBoard logs). Default is runs/CURRENT_DATETIME_HOSTNAME.

from_checkpoint: bool, default: False

Whether to resume training from the checkpoint saved in log_dir.

verbose: str, default: ‘text’

See VerbosityLevel for options.

seed: int, optional

Random seed.
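Both loss_function and the values of evaluate_functions are plain callables that take a prediction and a label. A minimal NumPy sketch with the hypothetical names mse_loss and accuracy, assuming predictions and labels are arrays of shape [batch_size, n_classes] with one-hot labels:

```python
import numpy as np


def mse_loss(y_pred, y_true):
    """Mean squared error, called as loss_function(prediction, label)."""
    return float(np.mean((np.asarray(y_pred) - np.asarray(y_true)) ** 2))


def accuracy(y_pred, y_true):
    """Fraction of rows whose arg-max class matches the one-hot label."""
    pred_cls = np.argmax(np.asarray(y_pred), axis=1)
    true_cls = np.argmax(np.asarray(y_true), axis=1)
    return float(np.mean(pred_cls == true_cls))


# Mapping from metric name to metric function
evaluate_functions = {'acc': accuracy}
```

These could then be handed to a concrete Trainer subclass as loss_function=mse_loss and evaluate_functions=evaluate_functions.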

fit(train_dataset: Dataset, val_dataset: Dataset | None = None, log_interval: int = 1, eval_interval: int = 1, eval_mode: str = 'epoch', early_stopping_criterion: str | None = None, early_stopping_interval: int | None = None, minimize_criterion: bool = True, full_timing_report: bool = False) None[source]

Fit the model on the training data and, optionally, evaluate it on the validation data.

Parameters:
train_dataset: Dataset

Dataset used for training.

val_dataset: Dataset, optional

Validation dataset.

log_interval: int, default: 1

Sets the interval at which the training statistics are printed if verbose = ‘text’ (otherwise ignored). If None, the statistics are printed at the end of each epoch.

eval_interval: int, default: 1

Sets the interval at which the metrics are evaluated on the validation dataset, counted in epochs or steps depending on eval_mode. If None, the validation is performed at the end of each epoch.

eval_mode: EvalMode, default: ‘epoch’

Sets the evaluation mode. If ‘epoch’, the metrics are evaluated after multiples of eval_interval epochs. If ‘step’, the metrics are evaluated after multiples of eval_interval steps. Ignored if val_dataset is None.

early_stopping_criterion: str, optional

If specified, the metric with this name, evaluated on val_dataset (if provided), is used as the stopping criterion instead of the default validation loss.

early_stopping_interval: int, optional

If specified, training is stopped if the stopping criterion (by default, the validation loss) does not improve for early_stopping_interval validation cycles.

minimize_criterion: bool, default: True

Whether the early stopping criterion should be minimized (True) or maximized (False).

full_timing_report: bool, default: False

Flag for including mean timing statistics in the logs.

Raises:
ValueError

If eval_mode is not a valid EvalMode.
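The patience rule described by early_stopping_interval and minimize_criterion can be sketched in plain Python. This is a hypothetical helper (early_stopping_cycle is not part of lambeq) that assumes only a strict improvement resets the counter; lambeq's handling of ties may differ.

```python
def early_stopping_cycle(values, patience, minimize=True):
    """Index of the validation cycle at which training would stop, or None.

    values: criterion value after each validation cycle, in order.
    patience: plays the role of early_stopping_interval.
    minimize: plays the role of minimize_criterion.
    """
    best = None
    cycles_without_improvement = 0
    for i, value in enumerate(values):
        improved = best is None or (value < best if minimize else value > best)
        if improved:
            best = value
            cycles_without_improvement = 0
        else:
            cycles_without_improvement += 1
            if cycles_without_improvement >= patience:
                return i
    return None


print(early_stopping_cycle([1.0, 0.8, 0.85, 0.9, 0.7], patience=2))  # 3
print(early_stopping_cycle([0.5, 0.6, 0.6, 0.7], patience=2, minimize=False))  # None
```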

load_training_checkpoint(log_dir: str | PathLike[str]) Checkpoint[source]

Load model from a checkpoint.

Parameters:
log_dir: str or PathLike

The path to the model.lt checkpoint file.

Returns:
Checkpoint

Checkpoint containing the model weights, symbols and the training history.

Raises:
FileNotFoundError

If the file does not exist.

save_checkpoint(save_dict: Mapping[str, Any], log_dir: str | PathLike[str], prefix: str = '') None[source]

Save checkpoint.

Parameters:
save_dict: mapping of str to any

Mapping containing the checkpoint information.

log_dir: str or PathLike

The path at which to store the model.lt checkpoint file.

prefix: str, default: ‘’

Prefix for the checkpoint file name.

abstract training_step(batch: tuple[list[Any], Any]) tuple[Any, float][source]

Perform a training step.

Parameters:
batch: tuple of list and any

Current batch.

Returns:
Tuple of any and float

The model predictions and the calculated loss.

abstract validation_step(batch: tuple[list[Any], Any]) tuple[Any, float][source]

Perform a validation step.

Parameters:
batch: tuple of list and any

Current batch.

Returns:
Tuple of any and float

The model predictions and the calculated loss.