{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# VQE example with `pytket-qujax`\n", "\n", "**Download this notebook - {nb-download}`pytket-qujax_heisenberg_vqe.ipynb`**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the docs for [qujax](https://cqcl.github.io/qujax/) and [pytket-qujax](https://cqcl.github.io/pytket-qujax/api/index.html)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from jax import numpy as jnp, random, value_and_grad, jit\n", "from pytket import Circuit\n", "from pytket.circuit.display import render_circuit_jupyter\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Let's start with a TKET circuit" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import qujax\n", "from pytket.extensions.qujax.qujax_convert import tk_to_qujax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We place barriers to stop tket automatically rearranging gates and we also store the number of circuit parameters as we'll need this later." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_circuit(n_qubits, depth):\n", " n_params = 2 * n_qubits * (depth + 1)\n", " param = jnp.zeros((n_params,))\n", " circuit = Circuit(n_qubits)\n", " k = 0\n", " for i in range(n_qubits):\n", " circuit.H(i)\n", " for i in range(n_qubits):\n", " circuit.Rx(param[k], i)\n", " k += 1\n", " for i in range(n_qubits):\n", " circuit.Ry(param[k], i)\n", " k += 1\n", " for _ in range(depth):\n", " for i in range(0, n_qubits - 1):\n", " circuit.CZ(i, i + 1)\n", " circuit.add_barrier(range(0, n_qubits))\n", " for i in range(n_qubits):\n", " circuit.Rx(param[k], i)\n", " k += 1\n", " for i in range(n_qubits):\n", " circuit.Ry(param[k], i)\n", " k += 1\n", " return circuit, n_params" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_qubits = 4\n", "depth = 2\n", "circuit, n_params = get_circuit(n_qubits, depth)\n", "render_circuit_jupyter(circuit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Now let's invoke qujax\n", "The `pytket.extensions.qujax.tk_to_qujax` function will generate a parameters -> statetensor function for us." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "param_to_st = tk_to_qujax(circuit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try it out on some random parameters values. Be aware that's JAX's random number generator requires a `jax.random.PRNGkey` every time it's called - more info on that [here](https://jax.readthedocs.io/en/latest/jax.random.html).\n", "Be aware that we still have convention where parameters are specified as multiples of $\\pi$ - that is in [0,2]." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = random.uniform(random.PRNGKey(0), shape=(n_params,), minval=0.0, maxval=2.0)\n", "statetensor = param_to_st(params)\n", "print(statetensor)\n", "print(statetensor.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this function also has an optional second argument where an initiating `statetensor_in` can be provided. If it is not provided it will default to the all 0s state (as we use here)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can obtain statevector by simply calling `.flatten()`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "statevector = statetensor.flatten()\n", "statevector.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And sampling probabilities by squaring the absolute value of the statevector" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sample_probs = jnp.square(jnp.abs(statevector))\n", "plt.bar(jnp.arange(statevector.size), sample_probs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cost function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have our `param_to_st` function we are free to define a cost function that acts on bitstrings (e.g. maxcut) or integers by directly wrapping a function around `param_to_st`. However, cost functions defined via quantum Hamiltonians are a bit more involved.\n", "Fortunately, we can encode an Hamiltonian in JAX via the `qujax.get_statetensor_to_expectation_func` function which generates a statetensor -> expected value function for us.\n", "It takes three arguments as input\n", "- `gate_seq_seq`: A list of string (or array) lists encoding the gates in each term of the Hamiltonian. I.e. `[['X','X'], ['Y','Y'], ['Z','Z']]` corresponds to $H = aX_iX_j + bY_kY_l + cZ_mZ_n$ with qubit indices $i,j,k,l,m,n$ specified in the second argument and coefficients $a,b,c$ specified in the third argument\n", "- `qubit_inds_seq`: A list of integer lists encoding which qubit indices to apply the aforementioned gates. I.e. `[[0, 1],[0,1],[0,1]]`. Must have the same structure as `gate_seq_seq` above.\n", "- `coefficients`: A list of floats encoding any coefficients in the Hamiltonian. I.e. `[2.3, 0.8, 1.2]` corresponds to $a=2.3,b=0.8,c=1.2$ above. Must have the same length as the two above arguments." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "More specifically let's consider the problem of finding the ground state of the quantum Heisenberg Hamiltonian" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "$$\n", "\\begin{equation}\n", "H = \\sum_{i=1}^{n_\\text{qubits}-1} X_i X_{i+1} + Y_i Y_{i+1} + Z_i Z_{i+1}.\n", "\\end{equation}\n", "$$\n", "\n", "As described, we define the Hamiltonian via its gate strings, qubit indices and coefficients." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hamiltonian_gates = [[\"X\", \"X\"], [\"Y\", \"Y\"], [\"Z\", \"Z\"]] * (n_qubits - 1)\n", "hamiltonian_qubit_inds = [\n", " [int(i), int(i) + 1] for i in jnp.repeat(jnp.arange(n_qubits), 3)\n", "]\n", "coefficients = [1.0] * len(hamiltonian_qubit_inds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Gates:\\t\", hamiltonian_gates)\n", "print(\"Qubits:\\t\", hamiltonian_qubit_inds)\n", "print(\"Coefficients:\\t\", coefficients)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's get the Hamiltonian as a pure JAX function" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "st_to_expectation = qujax.get_statetensor_to_expectation_func(\n", " hamiltonian_gates, hamiltonian_qubit_inds, coefficients\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check it works on the statetensor we've already generated." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "expected_val = st_to_expectation(statetensor)\n", "expected_val" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's wrap the `param_to_st` and `st_to_expectation` together to give us an all in one `param_to_expectation` cost function." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "param_to_expectation = lambda param: st_to_expectation(param_to_st(param))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "param_to_expectation(params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sanity check that a different, randomly generated set of parameters gives us a new expected value." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "new_params = random.uniform(\n", " random.PRNGKey(1), shape=(n_params,), minval=0.0, maxval=2.0\n", ")\n", "param_to_expectation(new_params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exact gradients within a VQE algorithm\n", "The `param_to_expectation` function we created is a pure JAX function and outputs a scalar. This means we can pass it to `jax.grad` (or even better `jax.value_and_grad`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cost_and_grad = value_and_grad(param_to_expectation)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `cost_and_grad` function returns a tuple with the exact cost value and exact gradient evaluated at the parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cost_and_grad(params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Now we have all the tools we need to design our VQE!\n", "We'll just use vanilla gradient descent with a constant stepsize" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def vqe(init_param, n_steps, stepsize):\n", " params = jnp.zeros((n_steps, n_params))\n", " params = params.at[0].set(init_param)\n", " cost_vals = jnp.zeros(n_steps)\n", " cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))\n", " for step in range(1, n_steps):\n", " cost_val, cost_grad = cost_and_grad(params[step - 1])\n", " cost_vals = cost_vals.at[step].set(cost_val)\n", " new_param = params[step - 1] - stepsize * cost_grad\n", " params = params.at[step].set(new_param)\n", " print(\"Iteration:\", step, \"\\tCost:\", cost_val, end=\"\\r\")\n", " print(\"\\n\")\n", " return params, cost_vals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ok enough talking, let's run (and whilst we're at it we'll time it too)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "%time" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vqe_params, vqe_cost_vals = vqe(params, n_steps=250, stepsize=0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the results..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(vqe_cost_vals)\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"Cost\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pretty good!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `jax.jit` speedup\n", "One last thing... 
, { "cell_type": "markdown", "metadata": {}, "source": [ "## `jax.jit` speedup\n", "One last thing... We can significantly speed up the VQE above via `jax.jit`. In our current implementation, the operations inside the expensive `cost_and_grad` function are traced and dispatched one by one every time it is called. By invoking `jax.jit` we compile the whole function to [XLA](https://www.tensorflow.org/xla) once (on the first call) and then simply execute the compiled version on every future call - this is much faster!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cost_and_grad = jit(cost_and_grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll demonstrate this using the second set of initial parameters we randomly generated (to be sure of no caching)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%time new_vqe_params, new_vqe_cost_vals = vqe(new_params, n_steps=250, stepsize=0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's some speedup!\n", "But let's also plot the training curve to be sure it converged correctly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(new_vqe_cost_vals)\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"Cost\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }