We previously saw how we can use JAX/Flax to train a logistic regression model and how that could be extended to a simple MLP (multilayer perceptron).
Here we will use JAX to train an MLP (perhaps more accurately a fully connected feedforward network) to predict something a simpler logistic regression can’t: a sine wave. This is a simple educational example, but these models can scale to much more complicated problems, and are a key component of GPT-style models.
From an engineering perspective, statistical models are tools for solving problems. Each tool has its own properties, inductive biases and limitations, leading to different strengths and weaknesses when applied to a problem. So we’ll explore the properties of these models and, in particular, the situations in which they fail.
Understanding how the model can approximate any function
The Universal Approximation Theorem states:
There exists a single layer feedforward network containing a finite number of neurons that can approximate any continuous function to arbitrary precision on a compact subset of $\mathbb{R}^n$. (Cybenko 1989; Hornik, Stinchcombe, and White 1989; Leshno et al. 1993)
There is also an equivalent version for multiple hidden layers (Lu et al. 2017). However, existence does not guarantee learnability: the theorem specifies neither how large the network needs to be nor how to train it.
To demonstrate the theorem, we’ll use gradient descent to train a simple model to predict a known function (plus some noise): a sine wave.
We’ll use a model with three hidden layers of 10 neurons each.
The output of each node/neuron is the result of an activation function applied to a weighted sum of the outputs of the previous layer, plus a bias. Here each neuron uses a tanh activation function: a sigmoidal function very similar in shape to the logistic function used in the logistic regression model.
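As a concrete sketch, here is such a network trained with plain gradient descent in JAX (the learning rate, step count, noise level and initialisation scheme below are illustrative assumptions, not the exact values used here):

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(1, 10, 10, 10, 1)):
    """Random weights (scaled by fan-in) and zero biases for each layer."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def forward(params, x):
    # tanh on every hidden layer, linear output layer
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b

def loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

key = jax.random.PRNGKey(0)
x = jnp.linspace(-jnp.pi, jnp.pi, 256).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jax.random.normal(key, x.shape)  # noisy sine target

params = init_params(key)
grad_fn = jax.jit(jax.grad(loss))
for _ in range(2000):
    grads = grad_fn(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)
```

Full-batch gradient descent is enough for a problem this small; larger models typically swap in a proper optimiser (e.g. Adam via optax).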
This is a visualization of the trained network where the weights are shown as the thickness of the lines connecting the neurons and the bias is shown as the fill color of the nodes.

The final prediction is pretty close to the true sine wave:

To understand how the model is learning the sine wave (and to get some intuition about how it can approximate any function), we’ll look at a simpler model with just one hidden layer of 10 neurons.
In this one-hidden-layer, univariate case the model is equivalent to a Generalised Additive Model (GAM) with adaptive basis functions. A GAM has the following form for a single input variable $x$:

$$y = \beta_0 + \sum_{i=1}^{k} f_i(x)$$

where each $f_i$ is a smooth basis function; here the basis functions $f_i(x) = w_i \tanh(a_i x + b_i)$ are learned from the data rather than fixed in advance.
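This equivalence is easy to check numerically: with one hidden layer, the matrix form of the network and a GAM-style sum of per-neuron tanh basis functions give identical outputs (the weights below are random placeholders, not trained values):

```python
import numpy as np

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(1, 10)), rng.normal(size=10)   # hidden layer
W2, b2 = rng.normal(size=(10, 1)), rng.normal(size=1)    # output layer
x = np.linspace(-3, 3, 50).reshape(-1, 1)

# standard matrix form of a one-hidden-layer network
y_matrix = np.tanh(x @ W1 + b1) @ W2 + b2

# GAM-style: a sum of one tanh basis function per neuron, plus an intercept
y_sum = b2 + sum(W2[i, 0] * np.tanh(W1[0, i] * x + b1[i]) for i in range(10))

assert np.allclose(y_matrix, y_sum)
```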
The output of each neuron is a tanh function aligned and spaced such that the downwards and upwards curves align with the sine wave, so that the sum of the outputs of the neurons locally approximates the sine wave.

This animation shows the sum step by step:

An alternative visualisation shows the contribution of each neuron as a column stacking either upwards or downwards depending on the sign of the contribution:

In this case a very simple model with just one hidden layer of 10 neurons is able to fit the sine wave well because the inductive biases of the hyperbolic tangent function (smoothness and plateauing) match the properties of the target output.
Piecewise linear inductive bias
What if the activation function was not so similar to part of the sine wave?
ReLU (rectified linear unit) is a simple piecewise linear activation function:

$$\mathrm{ReLU}(x) = \max(0, x)$$
With the same network structure but using ReLU as the activation function, the model does not do so well, and you can clearly see the piecewise linear nature of the activation function in the fit. In fact this network is equivalent to a piecewise linear spline with free knots (Chen 2016; Hansson and Olsson 2017).

Many more linear segments are required to approximate the sine wave - it’s less efficient parameter-wise.
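The piecewise linear structure can be verified directly: a one-hidden-layer ReLU network (random placeholder weights below) has a second difference that vanishes on every linear segment and is nonzero only in the few grid cells containing a knot:

```python
import numpy as np

rng = np.random.default_rng(1)
W1, b1 = rng.normal(size=(1, 10)), rng.normal(size=10)
W2, b2 = rng.normal(size=(10, 1)), rng.normal(size=1)

x = np.linspace(-3, 3, 601).reshape(-1, 1)
y = np.maximum(x @ W1 + b1, 0) @ W2 + b2   # one-hidden-layer ReLU network

# At most 10 knots, one per neuron, located where its pre-activation
# crosses zero; everywhere else the function is exactly linear.
d2 = np.abs(np.diff(y[:, 0], n=2))
print(int((d2 > 1e-8).sum()))  # count of grid cells containing a kink
```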
While the tanh model is more parameter efficient in this case, it is computationally more expensive to evaluate than ReLU. tanh also has a problem with vanishing gradients, as the derivative is close to zero for large $|x|$. That can make gradient descent training slow or stationary. ReLU doesn’t have this problem for large positive $x$, but for negative $x$ the gradient is exactly zero. An alternative is the leaky ReLU activation function, which maintains a small non-zero gradient for negative $x$:

$$\mathrm{LeakyReLU}(x) = \begin{cases} x & x \ge 0 \\ \alpha x & x < 0 \end{cases}$$

for some small slope such as $\alpha = 0.01$.
There are also other activation functions that avoid the non-differentiability of ReLU at 0, such as GELU and Swish, which are smooth and can ultimately result in better performance.
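These gradient behaviours are easy to inspect with `jax.grad` (a quick check, not part of the training code above):

```python
import jax
import jax.numpy as jnp

# d/dx of each activation at a strongly negative and a strongly positive input
for name, f in [("tanh", jnp.tanh),
                ("relu", jax.nn.relu),
                ("leaky_relu", jax.nn.leaky_relu),
                ("gelu", jax.nn.gelu)]:
    g = jax.grad(f)
    print(f"{name:10s} grad(-5) = {float(g(-5.0)):+.4f}  "
          f"grad(+5) = {float(g(5.0)):+.4f}")
```

tanh vanishes on both sides, ReLU is exactly zero for negative inputs, and leaky ReLU keeps its small slope (0.01 by default in `jax.nn.leaky_relu`).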


Increasing the capacity of the model by using four hidden layers of 32 neurons each allows a closer fit:


However, it’s still not smooth. Increasing the capacity of the model further by using five hidden layers of 64 neurons each doesn’t help:


It appears that the model is overfitting to the noise in the training data. For example, in the first negative bend the training data is slightly more negative on average than the true sine wave, and the model also predicts a more negative value there.
Training on more data helps the model to learn the underlying function better and reduces the influence of the noise.

This demonstrates the general principle that model complexity has to match the complexity of the data. If the model is too flexible it will overfit to the noise in the training data. If the model is not flexible enough it will underfit the training data.
Extrapolation and periodicity
While these types of models can approximate any function to arbitrary precision within the range of the training data, in this case they completely fail to predict outside it. In the extrapolation regions the nature of the activation functions is exposed, with the ReLU model extrapolating linearly and the tanh model plateauing.


The model architecture has no inductive bias to predict periodic functions.
One way to fix this is to use an activation function that is itself periodic, e.g. $\sigma(x) = \sin(x)$.

Of course this is now a trivial task - we are using the solution to the problem as the activation function.
Instead let’s demonstrate with a square wave - which is periodic but not smooth.

The sine activation struggles to produce perfect 90-degree corners, which results in some “ringing” at the edges. This model is essentially learning a Fourier series approximation, and this overshoot at the discontinuities is known as the Gibbs phenomenon.
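The overshoot can be reproduced directly with a truncated Fourier series of a square wave (the term count here is illustrative):

```python
import numpy as np

x = np.linspace(-np.pi, np.pi, 4001)

# Fourier series of a unit square wave: (4/pi) * sum over odd n of sin(n x)/n
approx = sum(4 / (np.pi * n) * np.sin(n * x) for n in range(1, 40, 2))

# The partial sum overshoots the +/-1 levels near each jump (Gibbs
# phenomenon); the overshoot approaches roughly 9% of the jump size no
# matter how many terms are added.
print(approx.max())
```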
But it can extrapolate!

… sort of. If you look closely the prediction looks slightly shifted as the extrapolation distance increases.
Extrapolating much further out, the error becomes obvious: what looks like a phase shift, as any slight error in the learnt parameters accumulates with extrapolation distance. The width of each pulse is slightly too narrow for the negative half-cycles and too wide for the positive ones, resulting in a large error even though the prediction remains periodic.

To fix this you’d probably need to collect data in this extrapolation region and retrain the model.
There is an alternative way to bake in periodicity - using Fourier features (Tancik et al. 2020).
This is where the input is first transformed by using sine and cosine functions of different frequencies, and then the model is trained on the transformed input.
The input $x$ is mapped to a vector of Fourier features $[\sin(2\pi f_1 x), \cos(2\pi f_1 x), \ldots, \sin(2\pi f_m x), \cos(2\pi f_m x)]$ for a chosen set of frequencies $f_1, \ldots, f_m$, and this vector is fed to the network in place of $x$.
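A sketch of the feature mapping (the particular frequencies are placeholders; in practice they are chosen to cover the scales of interest):

```python
import numpy as np

def fourier_features(x, freqs):
    """Map scalar inputs to [sin(2*pi*f*x), cos(2*pi*f*x)] per frequency f."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    ang = 2 * np.pi * x * np.asarray(freqs, dtype=float)
    return np.concatenate([np.sin(ang), np.cos(ang)], axis=1)

feats = fourier_features(np.linspace(0.0, 1.0, 5), freqs=[1.0, 2.0, 4.0])
# feats has shape (5, 6); the MLP is trained on these features instead of x
```

With integer frequencies the features are exactly periodic in $x$ with period 1, so the network's output inherits that periodicity by construction.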

This actually looks better than the sine-activation model because it doesn’t show the Gibbs phenomenon, presumably because the activation functions are still ReLU. But Fourier features are less flexible than a sine activation function, as we have to choose the frequencies in advance.
In terms of extrapolation, it can generate high error predictions outside the training range, though it doesn’t have the same phase shift issue as the sine model (perhaps because the fundamental frequency happens to match the frequency of the square wave).

A similar approach is used in LLMs to encode the position of tokens in the input sequence (Vaswani et al. 2023).
Spectral bias
Fourier features also have the advantage of overcoming the spectral bias of MLPs (the tendency for MLPs to learn low frequency functions more easily than high frequency functions) (Rahaman et al. 2019).
For example, if we add a high frequency component to the square wave, the Fourier feature network is able to learn it relatively easily, while the standard MLP basically ignores the high frequency component.
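As a simplified stand-in for the full network, even a linear fit on Fourier features captures an added high-frequency component, because that frequency is explicit in the features (the frequencies, amplitudes and cutoff below are illustrative assumptions):

```python
import numpy as np

x = np.linspace(0.0, 1.0, 500)
# square wave plus a high-frequency component
y = np.sign(np.sin(2 * np.pi * x)) + 0.3 * np.sin(2 * np.pi * 16 * x)

freqs = np.arange(1, 33)  # fixed in advance; must include the high frequency
ang = 2 * np.pi * np.outer(x, freqs)
F = np.concatenate([np.sin(ang), np.cos(ang), np.ones((len(x), 1))], axis=1)

# least-squares fit of the target onto the Fourier feature basis
coef, *_ = np.linalg.lstsq(F, y, rcond=None)
resid = y - F @ coef
print(np.sqrt(np.mean(resid ** 2)))  # small: high frequency is captured
```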


The Fourier feature network converges faster and to a lower loss than the standard MLP network.

Theoretical guarantees vs practical performance
While in theory MLPs are universal function approximators, in practice they have limitations. We have seen that they are closely related to more traditional statistical models and share many of their limitations, such as poor extrapolation, sensitivity to functional form assumptions (e.g. the choice of activation function), and overfitting. They can also be difficult to train, difficult to interpret, lack built-in uncertainty quantification, and have no built-in notion of spatial or sequential structure (motivating extensions like Convolutional Neural Networks and Recurrent Neural Networks). As ever, understanding the properties of our tools and how they interact with the problem - under what conditions they break - allows us to make reasonable trade-offs when forming a solution.