The Adjoint Method in a Dozen Lines of JAX
The adjoint method is a powerful technique for computing derivatives of functions that involve constrained optimization. It’s been around for a long time, but it has recently been popping up in machine learning, in papers such as the Neural ODE paper and many others. I found it a bit hard to grasp until I implemented some toy examples in JAX. This post outlines the adjoint method and steps through some concrete examples using JAX. I also discuss how this method interacts with modern autodiff.
The Adjoint Method
Setting
Our setting involves two vectors, $x \in \mathbb{R}^n$ and $\theta \in \mathbb{R}^m$. The vector $\theta$ holds the parameters we control, while $x$ is determined implicitly by a constraint $g(x, \theta) = 0$, where $g : \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}^n$. We also have a scalar objective $f(x, \theta)$, and we want the total derivative $\mathrm{d}f/\mathrm{d}\theta$, taking into account how $x$ changes when $\theta$ does.
As a completely trivial example, we could consider $g(x, \theta) = x - \theta$, so that the constraint just says $x = \theta$.
As a more complicated example, consider $x$ defined as the solution of a parameter-dependent linear system $A(\theta)\,x = b(\theta)$, which fits the template with $g(x, \theta) = A(\theta)\,x - b(\theta)$.
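To make the setup concrete in code, here is a minimal sketch of the trivial example as a pair of JAX functions; the quadratic objective here is an arbitrary choice, purely for illustration.
import jax.numpy as np

def f(x, theta):
    return np.sum(x ** 2)  # an arbitrary scalar objective

def g(x, theta):
    return x - theta  # the constraint g(x, θ) = 0 forces x = θ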
The Adjoint Equation
The core of the adjoint method is in the following equations.
I’m going to use $\partial_x f$, $\partial_\theta f$, $\partial_x g$, and $\partial_\theta g$ for the partial Jacobians, evaluated at a point $(x, \theta)$ satisfying $g(x, \theta) = 0$; since $f$ is scalar-valued, $\partial_x f$ is a row vector of length $n$, while $\partial_x g$ is an $n \times n$ matrix.
We first define an adjoint vector $\lambda \in \mathbb{R}^n$ as the solution of the adjoint equation
$$\partial_x g^\top \lambda = \partial_x f^\top.$$
Then, the adjoint method says we can write the gradient of $f$ with respect to $\theta$, accounting for the implicit dependence of $x$ on $\theta$, as
$$\frac{\mathrm{d}f}{\mathrm{d}\theta} = \partial_\theta f - \lambda^\top \partial_\theta g.$$
The intuition behind this is that $\lambda$ plays the role of a Lagrange multiplier: it measures how sensitive the objective is to a violation of the constraint, so $\lambda^\top \partial_\theta g$ translates the effect of a change in $\theta$ on the constraint into its effect on the objective, and we never need to form $\mathrm{d}x/\mathrm{d}\theta$ explicitly.
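Before working through the examples, here is what this recipe looks like as a generic helper in JAX. This is a sketch rather than a library function: the name adjoint_gradient is my own, and it assumes you have already found a point x satisfying $g(x, \theta) = 0$.
import jax.numpy as np
from jax import jacfwd

def adjoint_gradient(f, g, x, theta):
    # Partial Jacobians at (x, θ); g(x, θ) = 0 is assumed to hold here
    partial_x_f = jacfwd(f, argnums=0)(x, theta)      # shape (n,)
    partial_theta_f = jacfwd(f, argnums=1)(x, theta)  # shape (m,)
    partial_x_g = jacfwd(g, argnums=0)(x, theta)      # shape (n, n)
    partial_theta_g = jacfwd(g, argnums=1)(x, theta)  # shape (n, m)
    # Adjoint equation: ∂ₓgᵀ λ = ∂ₓfᵀ
    lam = np.linalg.solve(partial_x_g.T, partial_x_f)
    # df/dθ = ∂_θ f - λᵀ ∂_θ g
    return partial_theta_f - lam @ partial_theta_g
Forming full Jacobians keeps the code close to the math; for large problems you would replace them with vector-Jacobian products and a linear solver that only needs matrix-vector products.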
Examples
Pen-and-paper
As a sanity check, let’s solve this for a simple scalar example: take $f(x, \theta) = x^2$ and $g(x, \theta) = x - \theta$. Substituting the constraint directly gives $f = \theta^2$, so $\mathrm{d}f/\mathrm{d}\theta = 2\theta$.
Alternatively, using the adjoint method we have to solve the adjoint equation $\partial_x g^\top \lambda = \partial_x f^\top$. Here $\partial_x g = 1$ and $\partial_x f = 2x$, so $\lambda = 2x = 2\theta$. With $\partial_\theta f = 0$ and $\partial_\theta g = -1$, the gradient is $\partial_\theta f - \lambda^\top \partial_\theta g = 0 - 2\theta \cdot (-1) = 2\theta$, which agrees.
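We can also let JAX confirm the arithmetic; the value $\theta = 2.0$ below is just an arbitrary test point.
from jax import grad

theta = 2.0
x = theta  # solve g(x, θ) = x - θ = 0

# Direct route: substitute x = θ into f and differentiate
direct = grad(lambda t: t ** 2)(theta)

# Adjoint route: ∂ₓg = 1 and ∂ₓf = 2x, so λ = 2x; ∂_θ f = 0 and ∂_θ g = -1
lam = 2.0 * x
adjoint = 0.0 - lam * (-1.0)

print(direct, adjoint)  # both give 4.0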
Circles
Let’s consider a more involved example, although still a bit contrived. As we run through this on paper, we’ll also use JAX to double-check everything. In this problem, the parameters $\theta = (r, \phi, \hat{x}_0, \hat{x}_1)$ describe a point on a circle of radius $r$ centered at $(\hat{x}_0, \hat{x}_1)$: the constraint $g(x, \theta) = x - (r\cos\phi + \hat{x}_0,\, r\sin\phi + \hat{x}_1) = 0$ pins $x$ to that point, and the objective $f(x, \theta) = \lVert x - x^* \rVert^2$ is the squared distance from $x$ to a fixed target $x^* = (1, 3)$.
import jax.numpy as np
import jax.random as rnd

optimal_point = np.array([1., 3.])
theta_0 = rnd.normal(rnd.PRNGKey(1), shape=(4,))

def f(x, theta):
    return np.linalg.norm(x - optimal_point) ** 2

def g(x, theta):
    r, phi, x_hat_0, x_hat_1 = theta
    return np.array((x[0] - (r * np.cos(phi) + x_hat_0),
                     x[1] - (r * np.sin(phi) + x_hat_1)))
By hand, we can see that $\partial_x g$ is just the $2 \times 2$ identity matrix, and that $\partial_x f = 2\,(x - x^*)^\top$. We can use JAX’s jacfwd function to get the Jacobian of a function and check both.
from jax import jacfwd
r, phi, x_hat_0, x_hat_1 = theta_0
# Find the point solving g(x, θ) = 0
our_x = np.array([r * np.cos(phi) + x_hat_0, r * np.sin(phi) + x_hat_1])
g_x = lambda x: g(x, theta_0) # Get rid of theta-dependence
partial_x_g = jacfwd(g_x)  # Returns a function giving the Jacobian at a point
print(partial_x_g(our_x)) # Evaluate the Jacobian
f_x = lambda x: f(x, theta_0) # Get rid of theta-dependence
partial_x_f_fn = jacfwd(f_x)
partial_x_f = partial_x_f_fn(our_x)
print(partial_x_f)
# Double-check this is the same as what we derived by hand:
print(np.array((2 * (our_x[0] - optimal_point[0]),
                2 * (our_x[1] - optimal_point[1]))))
Next we need to compute $\partial_\theta g$. Differentiating $g$ with respect to each component of $\theta = (r, \phi, \hat{x}_0, \hat{x}_1)$ by hand gives
$$\partial_\theta g = \begin{pmatrix} -\cos\phi & r\sin\phi & -1 & 0 \\ -\sin\phi & -r\cos\phi & 0 & -1 \end{pmatrix}.$$
Now, we can double-check this in JAX:
g_theta = lambda theta: g(our_x, theta) # Get rid of x-dependence
partial_theta_g_fn = jacfwd(g_theta)
partial_theta_g = partial_theta_g_fn(theta_0)
print(partial_theta_g)
# Double-check this is the same as what we derived by hand:
print(np.array([[-np.cos(phi), r * np.sin(phi), -1., 0.],
                [-np.sin(phi), -r * np.cos(phi), 0., -1.]]))
Now we can plug in to get the gradient derived from the adjoint method, and compare it to the gradient JAX computes for a function that solves $g(x, \theta) = 0$ for $x$ explicitly and then evaluates $f$. Since $\partial_x g$ is the identity, the adjoint vector is simply $\lambda = \partial_x f^\top$, and since $f$ has no explicit $\theta$-dependence, the adjoint gradient reduces to $-\partial_x f\,\partial_\theta g$.
from jax import grad

def f_g_combined(theta):
    r, phi, x_hat_0, x_hat_1 = theta
    x = np.array((r * np.cos(phi) + x_hat_0, r * np.sin(phi) + x_hat_1))
    return np.linalg.norm(x - optimal_point) ** 2

print(grad(f_g_combined)(theta_0))
print(-partial_x_f @ partial_theta_g)
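For completeness, here is the same gradient computed via the general recipe, solving the adjoint equation explicitly (trivial here, since $\partial_x g$ is the identity); this reuses the variables defined above.
# Solve ∂ₓgᵀ λ = ∂ₓfᵀ; with ∂ₓg = I this just copies ∂ₓfᵀ into λ
lam = np.linalg.solve(partial_x_g(our_x).T, partial_x_f)
# f has no explicit θ-dependence, so df/dθ = -λᵀ ∂_θ g
print(-lam @ partial_theta_g)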