{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using Piquasso with JAX\n", "\n", "Piquasso can be used together with JAX. To do so, pass a [JaxConnector](../advanced/connectors.rst#piquasso._simulators.connectors.jax_.connector.JaxConnector) to the simulator, as the following simple example demonstrates:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[8.3118737e-01 3.8421631e-02 1.1526491e-01 8.8801968e-04 5.3281188e-03\n", " 7.9921819e-03 1.3682902e-05 1.2314616e-04 3.6943855e-04 3.6943876e-04\n", " 1.5812304e-07 1.8974771e-06 8.5386482e-06 1.7077307e-05 1.2807981e-05]\n" ] } ], "source": [ "import numpy as np\n", "import piquasso as pq\n", "\n", "\n", "jax_connector = pq.JaxConnector()\n", "\n", "simulator = pq.PureFockSimulator(\n", "    d=2,\n", "    config=pq.Config(cutoff=5, dtype=np.float32),\n", "    connector=jax_connector,\n", ")\n", "\n", "with pq.Program() as program:\n", "    pq.Q() | pq.Vacuum()\n", "\n", "    pq.Q(0) | pq.Displacement(r=0.43)\n", "\n", "    pq.Q(0, 1) | pq.Beamsplitter(theta=np.pi / 3)\n", "\n", "state = simulator.execute(program).state\n", "\n", "print(state.fock_probabilities)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The speed of this calculation can be improved by using [jax.jit](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html):" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1. original runtime:\t0.053116798400878906 s\n", "1. compiled runtime:\t0.48732662200927734 s\n", "2. original runtime:\t0.04149889945983887 s\n", "2. compiled runtime:\t3.361701965332031e-05 s\n", "3. original runtime:\t0.03750133514404297 s\n", "3. compiled runtime:\t3.218650817871094e-05 s\n", "4. original runtime:\t0.03447985649108887 s\n", "4. compiled runtime:\t3.170967102050781e-05 s\n", "5. original runtime:\t0.033095598220825195 s\n", "5. 
compiled runtime:\t3.0994415283203125e-05 s\n", "6. original runtime:\t0.031104087829589844 s\n", "6. compiled runtime:\t3.123283386230469e-05 s\n", "7. original runtime:\t0.030498027801513672 s\n", "7. compiled runtime:\t3.504753112792969e-05 s\n", "8. original runtime:\t0.0323636531829834 s\n", "8. compiled runtime:\t3.123283386230469e-05 s\n", "9. original runtime:\t0.034629106521606445 s\n", "9. compiled runtime:\t3.218650817871094e-05 s\n", "10. original runtime:\t0.03313851356506348 s\n", "10. compiled runtime:\t3.218650817871094e-05 s\n" ] } ], "source": [ "import time\n", "\n", "import numpy as np\n", "import piquasso as pq\n", "\n", "from jax import jit\n", "\n", "\n", "def func(r, theta):\n", "    jax_connector = pq.JaxConnector()\n", "\n", "    simulator = pq.PureFockSimulator(\n", "        d=2,\n", "        config=pq.Config(cutoff=5, dtype=np.float32),\n", "        connector=jax_connector,\n", "    )\n", "\n", "    with pq.Program() as program:\n", "        pq.Q() | pq.Vacuum()\n", "\n", "        pq.Q(0) | pq.Displacement(r=r)\n", "\n", "        pq.Q(0, 1) | pq.Beamsplitter(theta=theta)\n", "\n", "    return simulator.execute(program).state.fock_probabilities\n", "\n", "\n", "compiled_func = jit(func)\n", "\n", "iterations = 10\n", "\n", "for i in range(iterations):\n", "    r = np.random.rand()\n", "    theta = np.random.rand()\n", "\n", "    start_time = time.time()\n", "    func(r, theta)\n", "    print(f\"{i + 1}. original runtime:\\t{time.time() - start_time} s\")\n", "    start_time = time.time()\n", "    compiled_func(r, theta)\n", "    print(f\"{i + 1}. compiled runtime:\\t{time.time() - start_time} s\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the first call to the compiled function is considerably slower, since it includes tracing and compilation, but subsequent calls are significantly faster. 
One can also calculate the Jacobian of this function, e.g., with [jax.jacfwd](https://jax.readthedocs.io/en/latest/_autosummary/jax.jacfwd.html):" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Jacobian by 'r': [-7.1482110e-01 1.4566265e-01 4.3698803e-01 7.4969521e-03\n", " 4.4981722e-02 6.7472607e-02 1.7915692e-04 1.6124130e-03\n", " 4.8372396e-03 4.8372415e-03 2.8058382e-06 3.3670065e-05\n", " 1.5151533e-04 3.0303077e-04 2.2727314e-04]\n", "Jacobian by 'theta': [ 0.0000000e+00 -1.3309644e-01 1.3309643e-01 -6.1523812e-03\n", " -1.2304768e-02 1.8457148e-02 -1.4219691e-04 -7.1098475e-04\n", " -4.2659120e-04 1.2797731e-03 -2.1910173e-06 -1.7528142e-05\n", " -3.9438339e-05 -1.7730152e-11 5.9157519e-05]\n" ] } ], "source": [ "from jax import jacfwd\n", "\n", "jacobian_func = jit(jacfwd(compiled_func, argnums=(0, 1)))\n", "\n", "jacobian = jacobian_func(0.43, np.pi / 3)\n", "\n", "print(\"Jacobian by 'r': \", jacobian[0])\n", "print(\"Jacobian by 'theta': \", jacobian[1])" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }