Pseudo-Invertible networks using JAX

This post is about training pseudo-invertible neural networks. To be specific, the goal is to learn two functions $f_\theta, g_\phi : \mathbb{R} \to \mathbb{R}$ such that $f_\theta \approx g_\phi^{-1}$.

This relationship between $f_\theta$ and $g_\phi$ means that for any two points $x, y$ we should have

$$ \begin{align*} y &\approx f_\theta \circ g_\phi (y) \\ x &\approx g_\phi \circ f_\theta (x) \\ \end{align*} $$

This condition can be enforced through the constraint $\Omega(\mu, \lambda)$ defined as

$$ \Omega(\mu, \lambda) = \mu\,\mathbb{E}_Y\lVert Y- f_\theta \circ g_\phi (Y) \rVert^2 + \lambda \,\mathbb{E}_X\lVert X- g_\phi\circ f_\theta (X) \rVert^2$$

These expectations are taken over the distributions of the inputs $x, y$, so that in practice we obtain discrete sums over the dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^n$. A downside of this approach is that the functions are approximately invertible only around points present in $\mathcal{D}$.
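Concretely, replacing the expectations with empirical averages over $\mathcal{D}$ gives the estimate

$$ \Omega(\mu, \lambda) \approx \dfrac{\mu}{n}\sum_{i=1}^n\lVert y_i - f_\theta \circ g_\phi (y_i) \rVert^2 + \dfrac{\lambda}{n}\sum_{i=1}^n\lVert x_i - g_\phi \circ f_\theta (x_i) \rVert^2 $$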

This $\Omega$ constraint is added to the base $L_2$ goodness-of-fit loss $\mathcal{L}$ for the two networks: $$ \mathcal{L}(\mathcal{D}) = \dfrac{1}{n}\sum_{i=1}^n\lVert f_\theta(x_i) - y_i\rVert^2 + \dfrac{1}{n}\sum_{i=1}^n\lVert g_\phi(y_i) - x_i\rVert^2$$

So that our final loss is $\tilde{\mathcal{L}}(\mathcal{D}, \lambda, \mu) = \mathcal{L}(\mathcal{D}) + \Omega(\mu, \lambda)$.
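As a sketch, the full loss $\tilde{\mathcal{L}}$ can be computed directly from the four network outputs. The function name `total_loss` and the variable names are assumptions of this example: `Y_hat` stands for $f_\theta(x)$, `X_hat` for $g_\phi(y)$, `inv_X` for $g_\phi \circ f_\theta(x)$, and `inv_Y` for $f_\theta \circ g_\phi(y)$.

```python
import jax.numpy as jnp

def total_loss(X, Y, X_hat, Y_hat, inv_X, inv_Y, lam=1.0, mu=1.0):
    """L2 goodness of fit plus the invertibility penalty Omega(mu, lambda)."""
    # fit terms: f(x) should match y, g(y) should match x
    fit = jnp.mean((Y_hat - Y) ** 2) + jnp.mean((X_hat - X) ** 2)
    # invertibility terms: f(g(y)) should match y, g(f(x)) should match x
    omega = mu * jnp.mean((Y - inv_Y) ** 2) + lam * jnp.mean((X - inv_X) ** 2)
    return fit + omega
```

With perfect fits and perfect inversion, every term vanishes and the loss is zero.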

Neural Networks using Haiku

A standard library to define and train neural networks with JAX is Haiku. Declaring my two networks $f_\theta$ and $g_\phi$ using Haiku is fairly straightforward:

import jax
import haiku as hk

def two_layers_net(width: int = 30,
                   output_dim: int = 1
                   ) -> hk.Module:
    """A basic two-layer network with ReLU activations."""
    network = hk.Sequential([
        hk.Linear(width), jax.nn.relu,
        hk.Linear(width), jax.nn.relu,
        hk.Linear(output_dim),
    ])
    return network

However, in order to evaluate the networks at specific combinations of $\theta, \phi$ for specific points $x, y$, I had to do something that felt strange to me: initialize and evaluate the networks in the same function:

from typing import Tuple

def net_evaluate(X: array,
                 Y: array,
                 width: int = 30,
                 ) -> Tuple[array, ...]:
    """Evaluates the two networks on data `X`, `Y`."""
    output_dim = Y.shape[1]
    net_frwd = two_layers_net(width, output_dim)
    net_bkwd = two_layers_net(width, output_dim)
    Y_hat = net_frwd(X)
    X_hat = net_bkwd(Y)
    inv_X = net_bkwd(Y_hat)
    inv_Y = net_frwd(X_hat)
    return X_hat, Y_hat, inv_X, inv_Y

Here, array corresponds to a wrapper around the basic JAX array type, net_frwd corresponds to $f_\theta$ and net_bkwd to $g_\phi$. We also use the variables $\hat{y} = f_\theta(x)$ and $\hat{x} = g_\phi(y)$. Finally, inv_X and inv_Y correspond to the outputs of the compositions $g_\phi \circ f_\theta$ and $f_\theta \circ g_\phi$.

In order to transform our network evaluation function into a “pure function”, we need to make the following call:

net = hk.without_apply_rng(hk.transform(net_evaluate))

This transforms our evaluation function into a function which explicitly depends on the parameters $\theta, \phi$ and the datapoints $x,y$.

Training phase

Following the MNIST Haiku example, we use optax to optimize the parameters of the two networks. The rest of the code can be found on GitHub. The point cloud we're fitting consists of $10^3$ points sampled from a post-nonlinear model $y = h_1( h_2(x) + \varepsilon)$, where $h_1, h_2$ are generated using random cubic splines:

scatter xy
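For reference, a minimal way to sample such a point cloud; the smooth maps `h1`, `h2` below are simple stand-ins, not the random cubic splines used in the post:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1_000
x = rng.uniform(-2.0, 2.0, size=(n, 1))
eps = 0.1 * rng.normal(size=(n, 1))

# post-nonlinear model y = h1(h2(x) + eps), with stand-in smooth maps
h2 = lambda t: t ** 3 - t
h1 = np.tanh
y = h1(h2(x) + eps)
```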

We train and plot the learned networks for different values of $\mu, \lambda$. That is, these hyperparameters decide whether we prioritize invertibility or goodness of fit.

Equal priority

We set $\lambda = \mu = 1$. This means the invertibility constraint $\Omega$ is as important as the $L_2$ fit $\mathcal{L}$.

We train the two networks using optax's Adam optimizer optax.adam(...) with a learning rate of $10^{-3}$ for $500$ epochs. The training loss $\tilde{\mathcal{L}}$ decreases smoothly:

training loss

The networks $f_\theta, g_\phi$ do not perfectly fit our scatter, because they must be approximately invertible. In the following figure, we plot $f,g$ by swapping the axes; the first plot illustrates the $x,y$ scatter and $f$'s fit, and the second the $y,x$ scatter together with $g$'s fit:

two maps

An interesting plot is the comparison between the inputs $x,y$ and their image under the almost identity maps $g\circ f \approx f \circ g \approx \mathrm{id}$.

invertibility of the maps

The dashed line corresponds to a perfect identity, and the blue and red maps correspond to $g\circ f$ and $f\circ g$. We can clearly see deviations at different intervals.

Fit priority

Let's set $\lambda = \mu = 0.1$. This means the goodness of fit is ten times more important than the invertibility constraint. We obtain the following fits and invertibilities:

fit with 1 to 10 constraint ratio

When we let $\lambda, \mu \to 0$ we have gradually better fits and worse invertibility: for $\lambda = \mu = 0.01$, we get

fit with 1 to 100 constraint ratio

Invertibility priority

Let's do the same in reverse: setting $\mu = \lambda = 10$, we obtain the fits

10 to 1 constraint ratio

As $\lambda, \mu \to \infty$ we care less and less about the quality of the $L_2$ fits and only about having invertible functions. For $\lambda = \mu = 100$ we obtain:

100 to 1 constraint ratio

The fits are terrible, as expected: the functions are almost linear, and do not even capture linear trends in the data.

Unconstrained fits

As a baseline, fitting the two networks separately by optimizing only the $L_2$ loss yields the following fits:

no constraint fit

Our model seems too simple. Let's change the activations to SiLU functions, which is easily done via JAX's jax.nn.silu. The results are a bit better:

no constraint silu fit