The second-order ordinary differential equation (ODE) for the angle $\theta$ of a pendulum acted on by gravity with friction is:
\begin{align} \frac{d^2\theta(t)}{dt^2} + b \frac{d\theta(t)}{dt} + c \sin{\theta(t)} = 0 \end{align}where $b$ and $c$ are positive constants.
In order to solve for the angle dynamics using jax.experimental.ode.odeint, the second-order equation above must be converted to a system of two first-order equations. Defining the angular velocity $\omega(t) = \theta'(t)$ yields the following system of ODEs:
\begin{align} \theta'(t) &= \omega(t) \\ \omega'(t) &= -b\,\omega(t) - c \sin{\theta(t)} \end{align}
This blog post covers the scenario where we have to learn the two constants of a pendulum system, $b$ and $c$, from recorded values of the angle over time. The data will be generated synthetically, but in practice it would come from experimental observation.
First things first, let's import all the necessary packages and update matplotlib's default parameters:
import jax.numpy as np
from jax import value_and_grad, jit
from jax import random
from jax.experimental.ode import odeint
import matplotlib.pyplot as plt
from tqdm.auto import trange
%config InlineBackend.figure_format = 'retina'
plt.style.use('seaborn-paper')
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 12,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'grid.linewidth': 0.7,
    'legend.fontsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'lines.linewidth': 2.5,
    'lines.markersize': 8,
    'lines.markeredgecolor': 'k',
    'lines.markeredgewidth': 1.0
})
Let y be the vector $[\theta, \omega]$; the system is then implemented as:
def pend(y, t, b, c):
    theta, omega = y
    return np.hstack([
        omega,
        -b * omega - c * np.sin(theta)
    ])
Now let's wrap the jax.experimental.ode.odeint solver in a function called model for more convenient use:
def model(y0, t, params):
    return odeint(pend, y0, t, *params)
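As a quick sanity check (an illustrative snippet, not part of the original post), we can integrate a short trajectory and confirm that model returns one trace per state variable; the dynamics and wrapper are repeated here so the snippet is self-contained:

```python
import jax.numpy as np
from jax.experimental.ode import odeint

# Same pendulum dynamics and wrapper as above, repeated for self-containment.
def pend(y, t, b, c):
    theta, omega = y
    return np.hstack([omega, -b * omega - c * np.sin(theta)])

def model(y0, t, params):
    return odeint(pend, y0, t, *params)

t = np.linspace(0.0, 1.0, 11)        # short time grid
y0 = [0.1, 0.0]                      # small initial angle, initially at rest
theta, omega = model(y0, t, (0.25, 5.0))
print(theta.shape, omega.shape)      # one value per time point
```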
In practice this data would be collected experimentally, but here it is generated with the pend function. Gaussian noise is added after the simulation to mimic a more realistic situation, where various measurement errors (instrumental error, gross error, error due to external causes, imperfection in experimental technique or procedure, etc.) may occur.
# model constants
b = 0.25
c = 5.0
true_params = np.asarray([b, c])
# initial conditions: nearly vertical pendulum, initially at rest
y0 = [np.pi - 0.1, 0.0]
# simulation time in seconds
t = np.linspace(0, 10, 101)
# simulation
data = model(y0, t, (b, c))
theta, omega = data
# add noise to angle measurements: training data
key = random.PRNGKey(0)
noisy_theta = theta + 0.3 * theta * random.normal(key, shape=theta.shape)
# visualization
fig = plt.figure(figsize=(9, 5))
ax = fig.add_subplot()
ax.plot(t, theta, 'm-', label=r'$\theta(t)$')
ax.plot(t, omega, 'c-', label=r'$\omega(t)$')
ax.plot(t, noisy_theta, 'mo', label='training data')
ax.set_xlabel('$t$ [s]')
ax.legend(bbox_to_anchor=(0.8, 1.15), ncol=3, frameon=True, edgecolor='k')
ax.grid()
plt.show()
Other than the underlying physics of the pendulum, the noisy measurements of the pendulum angle in time are the only information we have.
We are interested in determining the (approximately) true values of both constants, $b$ and $c$. One way is to perform gradient descent: iteratively minimize the loss function, defined as the $L_2$-norm between the measurements and the model output, until we reach some level of convergence or exceed the maximum number of iterations.
The automatic differentiation (AD) capabilities embedded in jax are exploited, rather than standard gradient-based methods that approximate Jacobians with finite-difference schemes.
For more details on AD, check out The Autodiff Cookbook.
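To see the difference AD makes, here is a tiny comparison (illustrative, not from the post) of jax.grad against a central finite-difference estimate on a toy scalar function:

```python
import jax
import jax.numpy as np

f = lambda x: np.sin(x) ** 2                # toy scalar function

x = 1.0
ad = jax.grad(f)(x)                         # exact derivative: 2 sin(x) cos(x)
h = 1e-4
fd = (f(x + h) - f(x - h)) / (2.0 * h)      # O(h^2) finite-difference estimate
print(float(ad), float(fd))
```

AD gives machine-precision derivatives at roughly the cost of one extra function evaluation, with no step-size h to tune.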
def mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def loss_fn(params, data, y0, t):
    pred = model(y0, t, params)
    return mse(data, pred[0])
# where the magic happens
value_and_grad_loss_fn = jit(value_and_grad(loss_fn))
We will start the optimization from random values for $b$ and $c$ and then perform gradient descent for $2000$ epochs with a learning rate of $0.01$. Both the number of epochs and the learning rate could be tuned, and a better optimization procedure could be used. The idea here is just to show how simple it is to achieve a great level of accuracy with minimal code.
key, subkey = random.split(key)
rand_params = random.uniform(subkey, shape=(2,))
lr = 0.01
num_epochs = 2000
loss_value_list = []
loss_grad_list = []
opt_params = rand_params
tr = trange(num_epochs, desc='Loss', leave=True)
for epoch in tr:
    loss_value, loss_grad = value_and_grad_loss_fn(
        opt_params, noisy_theta, y0, t
    )
    loss_value_list.append(loss_value)
    loss_grad_list.append(loss_grad)
    opt_params = opt_params - lr * loss_grad
    if epoch % 100 == 0:
        tr.set_description(f'[Loss = {loss_value:.5f}]')
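The plain update rule above could be swapped for something more sophisticated, as mentioned earlier. One simple option is a heavy-ball (momentum) step; a minimal sketch on a toy quadratic loss, kept self-contained for illustration (the minimum is placed at the true $(b, c)$, and all names here are illustrative):

```python
import jax
import jax.numpy as np

# Toy quadratic loss whose minimum sits at the true parameters (0.25, 5.0).
target = np.asarray([0.25, 5.0])
loss = lambda p: np.sum((p - target) ** 2)
grad_fn = jax.grad(loss)

params = np.asarray([0.5, 0.2])      # arbitrary starting guess
velocity = np.zeros_like(params)
lr, momentum = 0.01, 0.9
for _ in range(2000):
    # Momentum accumulates past gradients, accelerating along stable directions.
    velocity = momentum * velocity - lr * grad_fn(params)
    params = params + velocity
print(params)                        # converges toward [0.25, 5.0]
```

The same two-line change (a velocity buffer plus the damped update) would drop into the training loop above.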
Let's print out our findings and compare the fitted parameters with the actual ones.
header = '\tActual value | Initial guess | Graddesc fit'
print(header)
print('\t', '-' * len(header), sep='')
print(f'b\t {b}\t\t{rand_params[0]:.5f}\t\t{opt_params[0]:.5f}')
print(f'c\t {c}\t\t{rand_params[1]:.5f}\t\t{opt_params[1]:.5f}')
	Actual value | Initial guess | Graddesc fit
	--------------------------------------------
b	 0.25		0.48026		0.23510
c	 5.0		0.23837		5.03360
Finally, let's use these parameters to generate fitted traces and visualize the system against the measurements.
fit = model(y0, t, opt_params)
fig = plt.figure(figsize=(9, 5))
ax = fig.add_subplot()
ax.plot(t, theta, 'm-', label=r'$\theta(t)$')
ax.plot(t, noisy_theta, 'mo', markevery=2, label='training data')
ax.plot(t, fit[0], 'k--', label=r'$\tilde \theta(t)$')
ax.plot(t, omega, 'c-', label=r'$\omega(t)$')
ax.plot(t, fit[1], 'k-.', label=r'$\tilde\omega(t)$')
ax.set_xlabel(r'$t$ [s]')
ax.legend(bbox_to_anchor=(0.7155, 1.3), ncol=2, frameon=True, edgecolor='k')
ax.grid()
plt.show()