JAX Calculus: antiderivatives and derivatives using JAX

This post's aim is to check the fundamental theorem of calculus numerically with JAX. The theorem states that the function $F$ defined as

$$ F ( x ) \triangleq \int_a^x f ( t ) \mathrm{d} t$$

satisfies the property $F'(x) = f(x)$ for every $x$, provided $f$ is continuous (here $a$ is some constant). A corollary that we shall investigate is that, for $f$ with a continuous derivative, $f(x) = f(a) + \int_a^x f'(t)\,\mathrm{d}t$.
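As a quick illustration of both statements, here is a minimal JAX sketch. It assumes the composite trapezoidal rule (described in the Integral approximation section below) as the quadrature, and uses $f = \cos$ with $a = 0$ as a stand-in test function; `trapezoid`, `F` and `n=1001` are illustrative choices, not the post's own code. `jax.grad` differentiates through the quadrature with respect to the upper bound $x$, and `jax.vmap(jax.grad(f))` evaluates $f'$ on a whole grid.

```python
import jax
import jax.numpy as jnp

def trapezoid(f, a, b, n=1001):
    # Composite trapezoidal rule on n equally spaced points x_1 = a, ..., x_n = b
    x = jnp.linspace(a, b, n)
    y = f(x)
    h = (b - a) / (n - 1)
    return h / 2 * (y[0] + y[-1] + 2 * jnp.sum(y[1:-1]))

a = 0.0
f = jnp.cos  # stand-in test function (assumption)

def F(x):
    # F(x) = integral of f from a to x, approximated by quadrature
    return trapezoid(f, a, x)

x0 = 1.3

# Fundamental theorem: F'(x0) should be close to f(x0)
print(jax.grad(F)(x0), f(x0))

# Corollary: the integral of f' from a to x0 should be close to f(x0) - f(a)
f_prime = jax.vmap(jax.grad(f))
print(trapezoid(f_prime, a, x0), f(x0) - f(a))
```

Each printed pair should agree up to the quadrature error and float32 rounding.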

Test functions

We'll construct test functions as compositions of SciPy cubic splines interpolating random knots. We create four random base splines, which we denote $g_i$.
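The spline construction can be sketched as follows. The setup is an assumption, not the post's code: the knot positions are 10 equally spaced points on $[0, 1]$ and only the knot values are random (standard normal, fixed seed); `scipy.interpolate.CubicSpline` then interpolates through the knots.

```python
import numpy as np
from scipy.interpolate import CubicSpline

rng = np.random.default_rng(0)

# Hypothetical setup: 10 equally spaced knots on [0, 1]
knots = np.linspace(0.0, 1.0, 10)

# Four base splines g_0, ..., g_3, each interpolating its own random knot values
values = [rng.standard_normal(knots.size) for _ in range(4)]
g = [CubicSpline(knots, v) for v in values]

# Evaluate all four splines on a fine grid, e.g. for plotting
x = np.linspace(0.0, 1.0, 200)
g_on_grid = np.stack([g_i(x) for g_i in g])  # shape (4, 200)
```

Since `CubicSpline` objects are evaluated by SciPy rather than traced by JAX, `jax.grad` cannot differentiate through them directly; their exact derivative is available from SciPy instead, via `g_i.derivative()` or `g_i(x, 1)`.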

Our resulting test functions $\phi_i$ are:

In the figure below, the four splines $g_i$ are on the left in blue, and their transformations $\phi_i$ are on the right in red:

[Figure: test functions]

Finally, we create a fifth function $\phi_4(x) = \phi_0(x)\phi_1(x)\phi_2(x) - \phi_3(x)$. This function is fairly complex, as can be seen in the figure below:
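The composition is plain function arithmetic, so JAX can differentiate it as long as the $\phi_i$ are written with `jax.numpy`. A minimal sketch, where `make_phi4` is a hypothetical helper and `jnp.sin`, `jnp.cos`, `jnp.exp`, `jnp.tanh` are stand-ins for $\phi_0, \dots, \phi_3$ (the post's own $\phi_i$ are built from SciPy splines):

```python
import jax
import jax.numpy as jnp

def make_phi4(phi0, phi1, phi2, phi3):
    # phi_4(x) = phi_0(x) * phi_1(x) * phi_2(x) - phi_3(x)
    def phi4(x):
        return phi0(x) * phi1(x) * phi2(x) - phi3(x)
    return phi4

# Hypothetical stand-ins for phi_0, ..., phi_3
phi4 = make_phi4(jnp.sin, jnp.cos, jnp.exp, jnp.tanh)

x = jnp.linspace(0.0, 1.0, 5)
print(phi4(x))

# Derivative of phi_4 by automatic differentiation, vectorised over the grid
dphi4 = jax.vmap(jax.grad(phi4))
print(dphi4(x))
```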

[Figure: fifth function $\phi_4$]

Integral approximation

Most integral approximations of the quantity $\int_a^b f(t)\,\mathrm{d}t$ work by evaluating the integrand $f$ on a grid $\{x_i\}_{i=1}^n \subset [a,b]$ with $x_1 = a$, $x_n = b$ and $x_i < x_{i+1}$ for all $i$. The continuous area under the curve $(x, f(x))$ is then split into $n-1$ discrete areas $a_i = a(x_i, x_{i+1})$, one per sub-interval $[x_i, x_{i+1}]$, so that

$$ \int_a^b f( t )\mathrm{d}t \approx \sum_{i=1}^{n-1} a_i $$

The areas are often defined using the sub-interval lengths $\Delta x_i = x_{i+1} - x_i$ and the function values $f(x_i), f(x_{i+1})$. In our case, we'll use equally spaced points, so that $h = \Delta x = \Delta x_i = (b-a)/(n-1)$.
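The generic scheme (grid, one area per sub-interval, sum of areas) maps directly onto vectorised `jax.numpy` code. In this sketch, `uniform_grid` and `sum_of_areas` are hypothetical helpers, and the left-endpoint rectangle $a_i = \Delta x_i\, f(x_i)$ is only an illustrative choice of area function, the simplest one possible:

```python
import jax.numpy as jnp

def uniform_grid(a, b, n):
    # n equally spaced points with x_1 = a and x_n = b, so h = (b - a) / (n - 1)
    x = jnp.linspace(a, b, n)
    h = (b - a) / (n - 1)
    return x, h

def sum_of_areas(area, f, a, b, n):
    # The integral is approximated by the sum of the n - 1 areas
    # a_i = area(x_i, x_{i+1}, f(x_i), f(x_{i+1})), one per sub-interval
    x, _ = uniform_grid(a, b, n)
    y = f(x)
    areas = area(x[:-1], x[1:], y[:-1], y[1:])  # vectorised over sub-intervals
    return jnp.sum(areas)

def left_rectangle(x0, x1, y0, y1):
    # Illustrative area function: a_i = (x_{i+1} - x_i) * f(x_i)
    return (x1 - x0) * y0

print(sum_of_areas(left_rectangle, jnp.exp, 0.0, 1.0, 1001))
```

Swapping `left_rectangle` for another area function changes the method without touching the grid or the summation.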

Trapezoidal method

On the sub-interval $[x_i, x_{i+1}]$ the area is approximated by a trapezoid with area $h\cdot (f(x_i) + f(x_{i+1}))/2$. Summing over the $n-1$ sub-intervals, every interior point $x_2, \dots, x_{n-1}$ is counted twice and the two end points once, so the resulting discrete sum is

\begin{align*} \int_a^b f( t )\mathrm{d}t &\approx \dfrac{h}{2}\sum_{i=1}^{n-1} \bigg[ f(x_i) + f(x_{i+1}) \bigg] \\ &= \dfrac{h}{2}\bigg[f(x_1) + f(x_n) + 2\sum_{i=2}^{n-1}f(x_i)\bigg] \end{align*}
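Both forms of the trapezoidal sum translate to a few lines of `jax.numpy`. This is a sketch with hypothetical names (`trapezoid_pairwise`, `trapezoid`) and $f(t) = t^2$ on $[0, 1]$ as a test integrand, whose exact integral is $1/3$; the second form simply regroups the interior terms of the first.

```python
import jax.numpy as jnp

def trapezoid_pairwise(f, a, b, n):
    # First form: (h/2) * sum over i of [f(x_i) + f(x_{i+1})]
    x = jnp.linspace(a, b, n)
    h = (b - a) / (n - 1)
    y = f(x)
    return h / 2 * jnp.sum(y[:-1] + y[1:])

def trapezoid(f, a, b, n):
    # Second form: (h/2) * [f(x_1) + f(x_n) + 2 * sum of interior values]
    x = jnp.linspace(a, b, n)
    h = (b - a) / (n - 1)
    y = f(x)
    return h / 2 * (y[0] + y[-1] + 2 * jnp.sum(y[1:-1]))

square = lambda t: t ** 2
print(trapezoid_pairwise(square, 0.0, 1.0, 101), trapezoid(square, 0.0, 1.0, 101))
```

For smooth $f$ the error of the trapezoidal rule scales like $h^2$, so doubling the number of sub-intervals should roughly divide it by four.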

Simpson's rule

Instead of computing areas on sub-intervals $[x_i, x_{i+1}]$ from two points, we can improve the accuracy of the method by considering 3-tuples of points $(x_{i-1}, x_i, x_{i+1})$ and integrating exactly the parabola that passes through them. Consecutive triples share their end points, so the $n$ points are grouped into $(n-1)/2$ overlapping triples instead of $n-1$ pairs; this requires an even number of sub-intervals, i.e. an odd number of points $n$. For smooth $f$, the error then decreases like $h^4$ instead of $h^2$ for the trapezoidal method. The resulting formula is, for odd $n$,

$$ \int_a^b f( t )\mathrm{d}t \approx \frac{h}{3}\sum_{i=1}^{(n-1)/2}\bigg[f(x_{2i-1})+4f(x_{2i})+f(x_{2i+1})\bigg] $$
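A minimal JAX sketch of the composite Simpson's rule, assuming $n$ equally spaced points with $x_1 = a$, $x_n = b$ and $n$ odd. The name `simpson` is hypothetical; in 0-based indexing, the three slices pick the left end, midpoint and right end of every triple. Because Simpson's rule integrates cubics exactly, $\int_0^1 t^3\,\mathrm{d}t = 1/4$ is recovered with just three points.

```python
import jax.numpy as jnp

def simpson(f, a, b, n):
    # Composite Simpson's rule on n equally spaced points x_1 = a, ..., x_n = b.
    # Needs an even number of sub-intervals, i.e. an odd number of points n.
    if n % 2 == 0:
        raise ValueError("Simpson's rule needs an odd number of points n")
    x = jnp.linspace(a, b, n)
    h = (b - a) / (n - 1)
    y = f(x)
    # Triple i is (x_{2i-1}, x_{2i}, x_{2i+1}) in the 1-based notation above:
    # left ends y[0:-2:2], midpoints y[1::2], right ends y[2::2]
    return h / 3 * jnp.sum(y[0:-2:2] + 4 * y[1::2] + y[2::2])

print(simpson(lambda t: t ** 3, 0.0, 1.0, 3))
print(simpson(jnp.exp, 0.0, 1.0, 11))
```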