The Adjoint Method in a Dozen Lines of JAX

The adjoint method is a powerful technique for computing derivatives of functions involving constrained optimization. It’s been around for a long time, but has recently been popping up in machine learning, for example in the Neural ODE paper and many others. I found it a bit hard to grasp until I implemented some toy examples in JAX. This post outlines the adjoint method and steps through some concrete examples using JAX. I also discuss how this method interacts with modern autodiff.

The Adjoint Method

Setting

Our setting involves two vectors, $x \in \mathbb{R}^{d_x}$ and $\theta \in \mathbb{R}^{d_\theta}$. We want to find the gradient (with respect to $\theta$) of a function $f: \mathbb{R}^{d_x} \times \mathbb{R}^{d_\theta} \to \mathbb{R}$. Here, $x$ is implicitly defined as a solution to $g(x, \theta) = 0$ for $g: \mathbb{R}^{d_x} \times \mathbb{R}^{d_\theta} \to \mathbb{R}^{d_x}$. This is a set of $d_x$ equations.

As a completely trivial example, we could consider $f(x, \theta) = x + \theta$, subject to $x = \theta$, which we can write as $g(x, \theta) = 0$ for $g(x, \theta) = x - \theta$.

As a more complicated example, consider $\theta$ parameterising the design of a building, and $x$ certain stresses on the foundations. The cost $f$ is the monetary cost of the design, with penalties for high stresses. The function $g$ is some complicated simulation computing the stresses given the design.

The Adjoint Equation

The core of the adjoint method is in the following equations. I’m going to use $\partial_a b$ to denote the Jacobian of the function $b$ with respect to $a$. This is a matrix in $\mathbb{R}^{d_\text{out} \times d_\text{in}}$, where $d_\text{out}$ is the dimension of the output of $b$ and $d_\text{in}$ is the dimension of the input of $b$, and $(\partial_a b)_{ij}$ is the derivative of the $i$th coordinate of the output of $b$ with respect to the $j$th coordinate of the input. It’s worth belabouring this point since the adjoint method involves a lot of algebra with Jacobians.
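As a quick illustration of this convention (a small example of my own, not from the post), JAX’s jacfwd, which we’ll lean on below, returns Jacobians in exactly this $d_\text{out} \times d_\text{in}$ layout:

import jax.numpy as np
from jax import jacfwd

# b maps R^3 to R^2, so ∂_a b should be a 2 x 3 matrix
b = lambda a: np.array([a[0] * a[1], np.sin(a[2])])
print(jacfwd(b)(np.ones(3)).shape)  # (2, 3)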

We first define an adjoint vector $\lambda \in \mathbb{R}^{d_x}$ as the solution to the equation

$$(\partial_x g)^\top \lambda = -(\partial_x f)^\top.$$

Then, the adjoint method says we can write the gradient of $f$ with respect to $\theta$ (taking into account the implicit dependence of $x$ on $\theta$) as

$$\nabla_\theta f = \lambda^\top (\partial_\theta g) + \partial_\theta f.$$

The intuition behind this is that $g(x, \theta) = 0$ implicitly defines the $x$ which solves it, $x^*(\theta)$, as a function of $\theta$. The implicit function theorem then gives us this result: differentiating $g(x^*(\theta), \theta) = 0$ yields $\partial_\theta x^* = -(\partial_x g)^{-1} \partial_\theta g$, and substituting this into the chain rule for $\nabla_\theta f$ and collecting the $x$-dependent factors into $\lambda$ gives the two equations above. This is useful if we have to run some complicated procedure to find the $x^*$ satisfying $g(x, \theta) = 0$: to compute $\nabla_\theta f$, we don’t need to differentiate through that complicated procedure, and can just use the equation $g$ directly. This even allows us to use non-differentiable methods to find $x^*$, as long as $g$ itself is available.
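To make the recipe concrete, here is a minimal sketch of the whole procedure in JAX (my own illustration; the helper name adjoint_gradient and its signature are made up, and jacfwd is introduced properly in the examples below). Given $f$, $g$, a solution $x^*$, and parameters $\theta$, it solves the linear system for $\lambda$ and assembles the gradient; note that whatever solver produced $x^*$ is never differentiated.

import jax.numpy as np
from jax import jacfwd

def adjoint_gradient(f, g, x_star, theta):
    # Jacobians of f and g, evaluated at the solution x_star of g(x, theta) = 0
    df_dx = jacfwd(f, argnums=0)(x_star, theta)      # shape (d_x,)
    df_dtheta = jacfwd(f, argnums=1)(x_star, theta)  # shape (d_theta,)
    dg_dx = jacfwd(g, argnums=0)(x_star, theta)      # shape (d_x, d_x)
    dg_dtheta = jacfwd(g, argnums=1)(x_star, theta)  # shape (d_x, d_theta)
    # Adjoint equation: (∂_x g)^T λ = -(∂_x f)^T
    lam = np.linalg.solve(dg_dx.T, -df_dx)
    # Gradient: ∇_θ f = λ^T ∂_θ g + ∂_θ f
    return lam @ dg_dtheta + df_dtheta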

Examples

Pen-and-paper

As a sanity check, let’s solve this for a simple example: $f(x, \theta) = x + \theta$, with $g(x, \theta) = \theta^2 + \theta - x$. If we directly substitute the solution $x^* = \theta^2 + \theta$ into the expression for $f$, we get $f = \theta^2 + 2\theta$, so $\nabla_\theta f = 2\theta + 2$.

Alternatively, using the adjoint method we have to solve $-\lambda = -1$ (since $\partial_x g = -1$ and $\partial_x f = 1$), giving $\lambda = 1$, and then find $\nabla_\theta f = 1 \cdot (2\theta + 1) + 1 = 2\theta + 2$. So we do indeed get the same result.
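We can also check this numerically in JAX (a quick check of my own, not part of the original derivation): differentiate $f$ through the explicit solution $x^* = \theta^2 + \theta$ and compare with the adjoint-method expression at an arbitrary $\theta$.

from jax import grad

theta = 1.5  # arbitrary test point

# Differentiate f(x*(θ), θ) = θ² + 2θ directly through the explicit solution
direct = grad(lambda t: (t ** 2 + t) + t)(theta)

# Adjoint method: λ = 1, ∂_θ g = 2θ + 1, ∂_θ f = 1
adjoint = 1.0 * (2 * theta + 1) + 1.0

print(direct, adjoint)  # both equal 2θ + 2 = 5.0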

Circles

Let’s consider a more involved example, although still a bit contrived. As we run through this on paper, we’ll also use JAX to double-check everything. In this problem, $\theta = [r, \varphi, \hat{x}_0, \hat{x}_1]$, $f(x, \theta) = \lVert x - w \rVert^2$ for some fixed $w \in \mathbb{R}^2$, and $g(x, \theta) = [x_0 - (r\cos\varphi + \hat{x}_0),\; x_1 - (r\sin\varphi + \hat{x}_1)]$. In other words, $\theta$ defines a circle with centre $(\hat{x}_0, \hat{x}_1)$, and $x$ is the point on the circle at polar coordinates $(r, \varphi)$ relative to that centre.

import jax.numpy as np
from jax import random as rnd

# The fixed point w that f penalises distance from
optimal_point = np.array([1., 3.])
# A random initial θ = [r, φ, x̂_0, x̂_1]
theta_0 = rnd.normal(rnd.PRNGKey(1), shape=(4,))

def f(x, theta):
    return np.linalg.norm(x - optimal_point) ** 2

def g(x, theta):
    r, phi, x_hat_0, x_hat_1 = theta
    return np.array((x[0] - (r * np.cos(phi) + x_hat_0), x[1] - (r * np.sin(phi) + x_hat_1)))

By hand, we can see that $\partial_x g$ is the identity matrix, and $\partial_x f$ is the vector $(2(x_0 - w_0), 2(x_1 - w_1))$. Note that we still have to know what $x$ is to evaluate $\partial_x f$. We find this by solving $g(x, \theta) = 0$. We can check this with JAX, using the jacfwd function to get the Jacobian of a function.

from jax import jacfwd

r, phi, x_hat_0, x_hat_1 = theta_0
# Find the point solving g(x, θ) = 0
our_x = np.array([r * np.cos(phi) + x_hat_0, r * np.sin(phi) + x_hat_1])

g_x = lambda x: g(x, theta_0) # Get rid of theta-dependence
partial_x_g = jacfwd(g_x) # Returns a function giving the Jacobian at a point
print(partial_x_g(our_x)) # Evaluate the Jacobian

f_x = lambda x: f(x, theta_0) # Get rid of theta-dependence
partial_x_f_fn = jacfwd(f_x)
partial_x_f = partial_x_f_fn(our_x)
print(partial_x_f)
# Double-check this is the same as what we derived by hand:
print(np.array((2 * (our_x[0] - optimal_point[0]),
		2 * (our_x[1] - optimal_point[1]))))

Next we need to compute $\partial_\theta g$. A bit of algebra gives

$$\partial_\theta g = \begin{pmatrix} -\cos\varphi & r\sin\varphi & -1 & 0 \\ -\sin\varphi & -r\cos\varphi & 0 & -1 \end{pmatrix}.$$

Now, we can double-check this in JAX

g_theta = lambda theta: g(our_x, theta) # Get rid of x-dependence
partial_theta_g_fn = jacfwd(g_theta)
partial_theta_g = partial_theta_g_fn(theta_0)
print(partial_theta_g)
# Double-check this is the same as what we derived by hand:
print(np.array([[-np.cos(phi), r * np.sin(phi), -1., 0.],
		[-np.sin(phi), -r * np.cos(phi), 0., -1.]]))

Now we can plug in to get the gradient derived from the adjoint method, and compare to the gradient of a function that includes solving g:

from jax import grad

def f_g_combined(theta):
    r, phi, x_hat_0, x_hat_1 = theta
    x = np.array((r * np.cos(phi) + x_hat_0, r * np.sin(phi) + x_hat_1))
    return np.linalg.norm(x - optimal_point) ** 2

print(grad(f_g_combined)(theta_0))
print(-partial_x_f @ partial_theta_g)
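Finally, as a sanity check of the general recipe (rather than exploiting the fact that $\partial_x g$ happens to be the identity here), we can run the hypothetical adjoint_gradient sketch from earlier on this example; it should print the same vector as the two lines above.

# Solves (∂_x g)^T λ = -(∂_x f)^T and assembles λ^T ∂_θ g + ∂_θ f
print(adjoint_gradient(f, g, our_x, theta_0))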