{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Symbolic circuits with `pytket-qujax`\n", "\n", "**Download this notebook - {nb-download}`pytket-qujax_qaoa.ipynb`**\n", "\n", "In this notebook we show how to manipulate symbolic circuits with the `pytket-qujax` extension. In particular, we apply the QAOA to an Ising Hamiltonian." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the docs for [qujax](https://cqcl.github.io/qujax/) and [pytket-qujax](inv:pytket-qujax:*:doc#index)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pytket import Circuit\n", "from pytket.circuit.display import render_circuit_jupyter\n", "from jax import numpy as jnp, random, value_and_grad, jit\n", "from sympy import Symbol\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import qujax\n", "from pytket.extensions.qujax import tk_to_qujax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## QAOA\n", "The Quantum Approximate Optimization Algorithm (QAOA), first introduced by [Farhi et al.](https://arxiv.org/pdf/1411.4028), is a variational quantum algorithm for combinatorial optimization problems. It consists of a unitary $U(\\beta, \\gamma)$ formed by alternating applications of $U(\\beta)=e^{-i\\beta H_B}$ and $U(\\gamma)=e^{-i\\gamma H_P}$, where $H_B$ is the mixing Hamiltonian and $H_P$ the problem Hamiltonian. The goal is to find the parameters that minimize the expectation value of $H_P$ in the state prepared by the circuit.\n", "Given a depth $d$, the final unitary is $U(\\beta, \\gamma) = U(\\beta_d)U(\\gamma_d)\\cdots U(\\beta_1)U(\\gamma_1)$. Note that each repetition $k$ has its own parameters $\\beta_k$ and $\\gamma_k$.\n", "\n", "## Problem Hamiltonian\n", "QAOA uses a problem-dependent ansatz, so we first need to specify the problem we want to solve. Here we consider an Ising Hamiltonian with only $ZZ$ interaction terms. 
Given a set $E$ of qubit-index pairs, the problem Hamiltonian is:\n", "\n", "$$\n", "\\begin{equation}\n", "H_P = \\sum_{(i, j) \\in E}\\alpha_{ij}Z_iZ_j,\n", "\\end{equation}\n", "$$\n", "\n", "where $\\alpha_{ij}$ are real coefficients.\n", "Let's build the problem Hamiltonian for a given number of qubits, with a set of pairs and random coefficients:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_qubits = 4\n", "# Pairs (i, j) in E, and one ZZ term for each pair\n", "hamiltonian_qubit_inds = [(0, 1), (1, 2), (0, 2), (1, 3)]\n", "hamiltonian_gates = [[\"Z\", \"Z\"] for _ in hamiltonian_qubit_inds]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that JAX's `random` module is stateless: we first need to create an explicit, seeded PRNG key and pass it to the sampling function. Here the coefficients are drawn uniformly from $[0, 1)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "seed = 13\n", "key = random.PRNGKey(seed)\n", "coefficients = random.uniform(key, shape=(len(hamiltonian_qubit_inds),))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Gates:\\t\", hamiltonian_gates)\n", "print(\"Qubits:\\t\", hamiltonian_qubit_inds)\n", "print(\"Coefficients:\\t\", coefficients)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Variational Circuit\n", "Before constructing the circuit, we still need to choose the mixing Hamiltonian. We use an $X$ term on each qubit, so $H_B = \\sum_{i=1}^{n}X_i$, where $n$ is the number of qubits. 
Notice that, with this mixing Hamiltonian, $U(\\beta)=\\prod_{i}e^{-i\\beta X_i}$ is an $X$ rotation on each qubit, with an angle proportional to $\\beta$ (it equals $R_x(2\\beta)$ in the convention $R_x(\\theta)=e^{-i\\theta X/2}$).\n", "Since all the $Z_iZ_j$ terms commute, the unitary corresponding to the problem Hamiltonian, $U(\\gamma)$, factorizes as:\n", "\n", "$$\n", "\\begin{equation}\n", "U(\\gamma)=\\prod_{(i, j) \\in E}e^{-i\\gamma\\alpha_{ij}Z_i Z_j}\n", "\\end{equation}\n", "$$\n", "\n", "Each factor $e^{-i\\gamma\\alpha_{ij}Z_iZ_j}$ can be implemented with two CNOT gates, with qubit $i$ as control and qubit $j$ as target, and a $Z$ rotation on qubit $j$ between them, with an angle proportional to $\\gamma\\alpha_{ij}$ ($R_z(2\\gamma\\alpha_{ij})$ in the convention $R_z(\\theta)=e^{-i\\theta Z/2}$). In pytket, rotation angles are expressed in half-turns, so passing $\\gamma\\alpha_{ij}$ and $\\beta$ directly as angles, as the code below does, only rescales the parameters by a constant factor, which the optimization absorbs.\n", "Finally, QAOA usually starts from the equal superposition of all computational basis states, which is prepared by a layer of Hadamard gates on every qubit at the beginning of the circuit." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With all the building blocks in place, let's construct the symbolic circuit using tket. To define the parameters we use the `Symbol` object from the `sympy` package; more information can be found in the manual section on [symbolic circuits](https://docs.quantinuum.com/tket/user-guide/manual/manual_circuit.html#symbolic-circuits). In order to later convert the circuit to qujax, the function also returns the list of symbolic parameters."
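] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before writing the circuit-building function, we can numerically check the CNOT-$R_z$-CNOT decomposition described above. The sketch below builds the $4\\times 4$ matrices with `jax.numpy` (with qubit $i$ as the first tensor factor) and compares the gate sequence to $e^{-i\\theta Z_iZ_j/2}$. It assumes the textbook radian convention $R_z(\\theta)=e^{-i\\theta Z/2}$ and an arbitrary test angle $\\theta=0.7$; the names `CX`, `rz`, `ZZ` and `theta` are illustrative only. Because $(Z_iZ_j)^2=I$, the target exponential can be written in closed form as $\\cos(\\theta/2) I - i\\sin(\\theta/2) Z_iZ_j$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "\n", "# Pauli Z, the identity and a CNOT with the first qubit (i) as control\n", "Z = jnp.diag(jnp.array([1.0, -1.0], dtype=jnp.complex64))\n", "I2 = jnp.eye(2, dtype=jnp.complex64)\n", "CX = jnp.array(\n", "    [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=jnp.complex64\n", ")\n", "\n", "\n", "def rz(theta):\n", "    # Textbook radian convention: Rz(theta) = exp(-i theta Z / 2)\n", "    return jnp.diag(jnp.exp(jnp.array([-0.5j, 0.5j]) * theta))\n", "\n", "\n", "theta = 0.7  # arbitrary test angle\n", "ZZ = jnp.kron(Z, Z)\n", "\n", "# CNOT, then Rz on the target qubit j, then CNOT again\n", "decomposed = CX @ jnp.kron(I2, rz(theta)) @ CX\n", "# exp(-i theta ZZ / 2) in closed form, valid because ZZ @ ZZ = I\n", "target = jnp.cos(theta / 2) * jnp.eye(4) - 1j * jnp.sin(theta / 2) * ZZ\n", "\n", "print(jnp.allclose(decomposed, target, atol=1e-6))"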
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def qaoa_circuit(n_qubits, depth):\n", "    circuit = Circuit(n_qubits)\n", "    p_keys = []\n", "\n", "    # Initial state: equal superposition via a layer of Hadamards\n", "    for i in range(n_qubits):\n", "        circuit.H(i)\n", "\n", "    for d in range(depth):\n", "        # Problem-Hamiltonian unitary U(gamma_d): one CX-Rz-CX block per ZZ term\n", "        gamma_d = Symbol(f\"γ_{d}\")\n", "        for pair, coef in zip(hamiltonian_qubit_inds, coefficients):\n", "            circuit.CX(pair[0], pair[1])\n", "            circuit.Rz(gamma_d * coef, pair[1])\n", "            circuit.CX(pair[0], pair[1])\n", "        circuit.add_barrier(range(n_qubits))\n", "        p_keys.append(gamma_d)\n", "\n", "        # Mixing unitary U(beta_d): an X rotation on every qubit\n", "        beta_d = Symbol(f\"β_{d}\")\n", "        for i in range(n_qubits):\n", "            circuit.Rx(beta_d, i)\n", "        p_keys.append(beta_d)\n", "\n", "    return circuit, p_keys" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "depth = 3\n", "circuit, keys = qaoa_circuit(n_qubits, depth)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "keys" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the circuit:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "render_circuit_jupyter(circuit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Now for `qujax`\n", "The `pytket.extensions.qujax.tk_to_qujax` function generates a function that maps a parameter vector to a statetensor. However, to convert a symbolic circuit we first need to define a `symbol_map`, a dictionary that maps each symbol to its index in the parameter vector. 
In our case, since `keys` already lists the symbols in the order in which they appear in the parameter vector, we can build the dictionary directly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Map each symbol to its position in the parameter vector\n", "symbol_map = {key: i for i, key in enumerate(keys)}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "symbol_map" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we call `tk_to_qujax` with both the circuit and the symbol map." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "param_to_st = tk_to_qujax(circuit, symbol_map=symbol_map)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also use qujax to construct the function that maps a statetensor to the expectation value of the problem Hamiltonian, and compose the two:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "st_to_expectation = qujax.get_statetensor_to_expectation_func(\n", "    hamiltonian_gates, hamiltonian_qubit_inds, coefficients\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def param_to_expectation(param):\n", "    # Parameters -> statetensor -> expectation value of H_P\n", "    return st_to_expectation(param_to_st(param))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training process\n", "We construct a function that, given a parameter vector, returns both the value of the cost function and its gradient.\n", "We also wrap it in `jit`, so that the expensive `cost_and_grad` function is compiled once into a single optimized XLA computation, which is then executed at every iteration instead of dispatching each operation separately from Python. Alternatively, we could get a similar speedup by replacing our `for` loop with `jax.lax.scan`. You can read more about JIT compilation in the [JAX documentation](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)."
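] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As an illustration of the `jax.lax.scan` alternative, here is a minimal sketch on a hypothetical toy cost (a quadratic bowl with its minimum at a vector of ones) that stands in for `param_to_expectation`, so that the cell runs on its own. The names `toy_cost`, `gd_step` and `train` and the hyperparameters are illustrative assumptions, not part of qujax. `jax.lax.scan` traces `gd_step` once and runs the whole loop inside a single compiled XLA computation, returning the final parameters together with the stacked per-step costs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "\n", "def toy_cost(param):\n", "    # Hypothetical stand-in for param_to_expectation: minimum at param = 1\n", "    return jnp.sum((param - 1.0) ** 2)\n", "\n", "\n", "toy_cost_and_grad = jax.value_and_grad(toy_cost)\n", "toy_stepsize = 0.1\n", "toy_n_steps = 100\n", "\n", "\n", "def gd_step(param, _):\n", "    # One vanilla gradient-descent update; the current cost is the scan output\n", "    cost, grad = toy_cost_and_grad(param)\n", "    return param - toy_stepsize * grad, cost\n", "\n", "\n", "@jax.jit\n", "def train(init_param):\n", "    # xs=None with an explicit length: scan just repeats gd_step n times\n", "    return jax.lax.scan(gd_step, init_param, None, length=toy_n_steps)\n", "\n", "\n", "final_param, toy_costs = train(jnp.zeros(3))\n", "print(toy_costs[0], toy_costs[-1])"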
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cost_and_grad = jit(value_and_grad(param_to_expectation))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the training process we'll use vanilla gradient descent with a constant stepsize:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "seed = 123\n", "key = random.PRNGKey(seed)\n", "init_param = random.uniform(key, shape=(len(symbol_map),))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_steps = 150\n", "stepsize = 0.01" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "param = init_param" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cost_vals = jnp.zeros(n_steps)\n", "cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for step in range(1, n_steps):\n", "    # Cost and gradient at the current parameters, then one descent step\n", "    cost_val, cost_grad = cost_and_grad(param)\n", "    cost_vals = cost_vals.at[step].set(cost_val)\n", "    param = param - stepsize * cost_grad\n", "    print(\"Iteration:\", step, \"\\tCost:\", cost_val, end=\"\\r\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualise the cost during gradient descent:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(cost_vals)\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"Cost\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }