Identification of the unknown parameters of a pendulum model using JAX

The second-order ordinary differential equation (ODE) for the angle, $\theta$, of a pendulum acted on by gravity with friction is written as:

\begin{align} \frac{d^2\theta(t)}{dt^2} + b \frac{d\theta(t)}{dt} + c \sin{\theta(t)} = 0 \end{align}

where $b$ and $c$ are positive constants.

In order to solve for the angle dynamics using jax.experimental.ode.odeint, the second-order equation above must be converted to a system of two first-order equations. By defining the angular velocity $\omega(t) = \theta'(t)$, the following system of ODEs is obtained:

\begin{align} \frac{d\theta(t)}{dt} &= \omega(t) \\ \frac{d\omega(t)}{dt} &= - b \omega(t) - c \sin{\theta(t)} \end{align}

This blog post covers the scenario where we have to learn the two constants of a pendulum system, $b$ and $c$, from recorded measurements of the angle over time. The data will be generated synthetically, but in practice it would come from experimental observations.

First things first, let's import all the necessary packages and update matplotlib's default parameters:
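A minimal setup sketch; the specific matplotlib defaults below are illustrative choices, not taken from the original post.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
import matplotlib.pyplot as plt

# Illustrative plotting defaults (assumed values)
plt.rcParams.update({"figure.figsize": (10, 5), "font.size": 12})
```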

Pendulum ODE system

Let $y$ be the vector $[\theta, \omega]$; the system is then implemented as:
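A sketch of such a right-hand-side function, assuming the signature expected by odeint (state first, time second, extra arguments after):

```python
def pend(y, t, b, c):
    """Right-hand side of the pendulum ODE system, with y = [theta, omega]."""
    theta, omega = y
    # d(theta)/dt = omega,  d(omega)/dt = -b*omega - c*sin(theta)
    return jnp.array([omega, -b * omega - c * jnp.sin(theta)])
```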

Now let's wrap the jax.experimental.ode.odeint solver in a function called model for more convenient use:
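One possible wrapper, assuming the parameters are packed into a single array so the whole solve can later be differentiated with respect to them:

```python
def model(params, y0, t):
    """Integrate the pendulum ODE for parameters params = [b, c]."""
    b, c = params
    # Returns an array of shape (len(t), 2): theta and omega traces
    return odeint(pend, y0, t, b, c)
```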

Synthetic data simulation

In practice this data would be collected experimentally, but here it is generated using the pend function.

White Gaussian noise is added after the simulation to create a more realistic situation, where various measurement errors (instrumental error, gross error, error due to external causes, imperfections in experimental technique or procedure, etc.) may occur.
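A sketch of the data generation; the ground-truth values $b = 0.25$ and $c = 5.0$, the initial condition, the time grid, and the noise level are all illustrative assumptions:

```python
key = jax.random.PRNGKey(0)

# Assumed ground-truth constants and simulation setup
b_true, c_true = 0.25, 5.0
y0 = jnp.array([jnp.pi - 0.1, 0.0])   # initial angle and angular velocity
t = jnp.linspace(0.0, 10.0, 101)      # measurement time grid

# Clean simulation, then white Gaussian noise on the observed angle
y_clean = model(jnp.array([b_true, c_true]), y0, t)
theta_meas = y_clean[:, 0] + 0.1 * jax.random.normal(key, (t.shape[0],))
```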

Learning unknown constants from measurements

Other than the underlying physics, the noisy measurements of the pendulum angle over time are the only information we have.

We are interested in determining (approximately) the true values of both constants, $b$ and $c$. One way is to perform gradient descent, iteratively minimizing a loss function defined as the $L_2$ norm between the measurements and the model output, until we reach some level of convergence or exceed a maximum number of iterations.
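A minimal sketch of such a loss, assuming only the angle channel is observed and using the sum of squared residuals:

```python
def loss_fn(params, y0, t, theta_meas):
    """Squared L2 distance between measured and model-predicted angles."""
    theta_pred = model(params, y0, t)[:, 0]
    return jnp.sum((theta_pred - theta_meas) ** 2)
```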

Rather than using standard gradient-based methods that approximate Jacobians with finite-difference schemes, we exploit the automatic differentiation (AD) capabilities embedded in jax. For more details on AD, check out The Autodiff Cookbook.

We will start the optimization with random values for $b$ and $c$ and then perform gradient descent for $2000$ epochs with a learning rate of $0.01$. Both the number of epochs and the learning rate could be tuned, and a more sophisticated optimization procedure could be implemented. The idea here is just to show how simple it is to achieve a high level of accuracy with minimal code.
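A sketch of the training loop under those settings; the random initialization scheme and the use of jax.jit are my assumptions, not necessarily the original code:

```python
learning_rate = 0.01
num_epochs = 2000

# Random starting guesses for b and c
key, subkey = jax.random.split(key)
params = jax.random.uniform(subkey, shape=(2,))

# Gradient of the loss w.r.t. params via automatic differentiation;
# odeint supports reverse-mode AD, so this differentiates through the solver
grad_fn = jax.jit(jax.grad(loss_fn))

for epoch in range(num_epochs):
    params = params - learning_rate * grad_fn(params, y0, t, theta_meas)
```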

Evaluation

Let's print out our findings and compare the fitted parameters with the actual ones.
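For example, reusing the names from the sketches above:

```python
b_fit, c_fit = params
print(f"b: fitted = {b_fit:.4f}, true = {b_true}")
print(f"c: fitted = {c_fit:.4f}, true = {c_true}")
```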

Finally, let's use these parameters to generate fitted traces and visualize the system against the measurements.
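One possible plotting sketch, again reusing the names introduced above; the styling choices are arbitrary:

```python
y_fit = model(params, y0, t)

plt.plot(t, theta_meas, "k.", label="measurements")
plt.plot(t, y_fit[:, 0], "r-", label="fitted model")
plt.xlabel("t [s]")
plt.ylabel(r"$\theta(t)$ [rad]")
plt.legend()
plt.show()
```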