# The Adjoint Method in a Dozen Lines of JAX

The adjoint method is a powerful technique for computing derivatives of functions that involve constrained optimization. It’s been around for a long time, but recently has been popping up in machine learning, in papers such as the Neural ODE paper and many others. I found it a bit hard to grasp until I implemented some toy examples in JAX. This post outlines the adjoint method and steps through some concrete examples using JAX. I also discuss how this method interacts with modern autodiff.

### Setting

Our setting involves two vectors, $$x \in \mathbb{R}^{d_x}$$ and $$\theta \in \mathbb{R}^{d_\theta}$$. We want to find the gradient (with respect to $$\theta$$) of a function $$f: \mathbb{R}^{d_x} \times \mathbb{R}^{d_\theta} \to \mathbb{R}$$. Here, $$x$$ is implicitly defined as a solution to $$g(x, \theta) = \boldsymbol{0}$$ for $$g: \mathbb{R}^{d_x} \times \mathbb{R}^{d_\theta} \to \mathbb{R}^{d_x}$$. This is a set of $$d_x$$ equations.

As a completely trivial example, we could consider $$f(x, \theta) = x + \theta$$, subject to $$x = \theta$$, which we can write as $$g(x, \theta) = 0$$ for $$g(x, \theta) = x - \theta$$.

As a more complicated example, consider $$\theta$$ parameterising the design of a building, and $$x$$ certain stresses on the foundations. The cost $$f$$ is the monetary cost of the design, with penalties for high stresses. The function $$g$$ is some complicated simulation computing the stresses given the design.

The core of the adjoint method is in the following equations. I’m going to use $$\partial_a b$$ to denote the Jacobian of the function $$b$$ with respect to $$a$$. This is a matrix in $$\mathbb{R}^{d_{\text{out}}\times d_{\text{in}}}$$, where $$d_\text{out}$$ is the dimension of the output of $$b$$ and $$d_{\text{in}}$$ is the dimension of the input of $$b$$, and $$(\partial _a b)_{ij}$$ is the derivative of the $$i$$th coordinate of the output of $$b$$ with respect to the $$j$$th coordinate of the input. It’s worth belabouring this point since the adjoint method involves a lot of algebra with Jacobians.
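This matches the convention JAX itself uses. As a quick illustration (with a made-up function from $$\mathbb{R}^3$$ to $$\mathbb{R}^2$$, purely to show the shapes), the Jacobian returned by `jacfwd`, which we’ll use throughout, has one row per output and one column per input:

```python
import jax.numpy as np
from jax import jacfwd

# Toy function from R^3 to R^2, just to illustrate the shape convention
def b(a):
    return np.array([a[0] * a[1], a[2] ** 2])

print(jacfwd(b)(np.ones(3)).shape)  # (2, 3): one row per output, one column per input
```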

We first define an adjoint vector $$\lambda \in \mathbb{R}^{d_x}$$ as the solution to the equation

\begin{align*} (\partial_x g)^\top \lambda = -(\partial_x f)^\top. \end{align*}

Then, the adjoint method says we can write the gradient of $$f$$ with respect to $$\theta$$ as

\begin{align*} \nabla_\theta f = \lambda^\top \partial_\theta g + \partial_\theta f. \end{align*}

The intuition behind this is that $$g(x, \theta) = 0$$ implicitly defines the $$x$$ which solves it, $$x^*$$, as a function of $$\theta$$. The implicit function theorem then gives us this result. This is useful if we have to run some complicated procedure to find the $$x$$ satisfying $$g(x, \theta) = 0$$: to compute $$\nabla_\theta f$$, we don’t need to differentiate through that complicated procedure, and can just use the equation $$g$$ directly. This even allows us to use non-differentiable methods to find $$x$$, as long as we can evaluate and differentiate $$g$$ itself.
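To see where these equations come from, write $$x^*(\theta)$$ for the solution of $$g(x, \theta) = 0$$. Differentiating $$f(x^*(\theta), \theta)$$ and the constraint with respect to $$\theta$$ gives

\begin{align*} \frac{d f}{d \theta} = \partial_\theta f + \partial_x f \, \partial_\theta x^*, \qquad \partial_x g \, \partial_\theta x^* + \partial_\theta g = \boldsymbol{0}. \end{align*}

Solving the second equation for $$\partial_\theta x^*$$ and substituting into the first,

\begin{align*} \frac{d f}{d \theta} = \partial_\theta f - \partial_x f \, (\partial_x g)^{-1} \partial_\theta g = \partial_\theta f + \lambda^\top \partial_\theta g, \end{align*}

where $$\lambda^\top = -\partial_x f \, (\partial_x g)^{-1}$$ is exactly the solution of the adjoint equation above.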

## Examples

### Pen-and-paper

As a sanity check, let’s solve this for a simple example, $$f(x, \theta) = x + \theta$$, with $$g(x, \theta) = \theta^2 + \theta - x$$. Solving $$g(x, \theta) = 0$$ gives $$x = \theta^2 + \theta$$; substituting this into $$f$$ gives $$f = \theta^2 + 2\theta$$, so $$\nabla_\theta f = 2\theta + 2$$.

Alternatively, using the adjoint method, we have to solve $$-\lambda = -1$$, giving $$\lambda = 1$$, and then find $$\nabla_\theta f = 1 \cdot (2\theta + 1) + 1 = 2\theta + 2$$. So we do indeed get the same result.
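As a quick numerical sanity check (just a sketch, with an arbitrarily chosen $$\theta = 1.5$$), we can verify the same numbers with `jax.grad`:

```python
from jax import grad

theta = 1.5

# Adjoint method pieces, all scalars in this example
lam = 1.0                         # solves (∂ₓg)ᵀλ = -(∂ₓf)ᵀ, i.e. -λ = -1
partial_theta_g = 2 * theta + 1   # ∂_θ g = 2θ + 1
partial_theta_f = 1.0             # ∂_θ f = 1
print(lam * partial_theta_g + partial_theta_f)  # 2θ + 2 = 5.0

# Compare with differentiating through the substituted expression directly
f_substituted = lambda t: (t ** 2 + t) + t      # f(x*(θ), θ) = θ² + 2θ
print(grad(f_substituted)(theta))               # also 5.0
```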

### Circles

Let’s consider a more involved example, although still a bit contrived. As we run through this on paper, we’ll also use JAX to double-check everything. In this problem, $$\theta = [r, \varphi, {\hat x}_0, {\hat x}_1]$$, $$f(x, \theta) = \|x - w\|^2$$ for some fixed $$w \in \mathbb{R}^2$$, and $$g(x, \theta) = \left[x_0 - (r\cos\varphi + {\hat x}_0), x_1 - (r\sin\varphi + {\hat x}_1)\right]$$. In other words, $$\theta$$ defines a circle with centre $$({\hat x}_0, {\hat x}_1)$$, and $$g(x, \theta) = \boldsymbol{0}$$ says that $$x$$ is the point on that circle at polar coordinates $$(r, \varphi)$$.

```python
import jax.numpy as np
from jax import random as rnd

optimal_point = np.array([1., 3.])
theta_0 = rnd.normal(rnd.PRNGKey(1), shape=(4,))

def f(x, theta):
    return np.linalg.norm(x - optimal_point) ** 2

def g(x, theta):
    r, phi, x_hat_0, x_hat_1 = theta
    return np.array((x[0] - (r * np.cos(phi) + x_hat_0),
                     x[1] - (r * np.sin(phi) + x_hat_1)))
```



By hand, we can see that $$\partial_x g$$ is the identity matrix, and $$\partial_x f$$ is the vector $$(2(x_0 - w_0), 2(x_1 - w_1))$$. Note that we still have to know what $$x$$ is to evaluate $$\partial_x f$$. We find this by solving $$g(x, \theta) = 0$$. We can check this with JAX, using the jacfwd function to get the Jacobian of a function.

```python
from jax import jacfwd

r, phi, x_hat_0, x_hat_1 = theta_0
# Find the point solving g(x, θ) = 0
our_x = np.array([r * np.cos(phi) + x_hat_0, r * np.sin(phi) + x_hat_1])

g_x = lambda x: g(x, theta_0)  # Get rid of theta-dependence
partial_x_g = jacfwd(g_x)      # Returns a function giving the Jacobian at a point
print(partial_x_g(our_x))      # Evaluate the Jacobian

f_x = lambda x: f(x, theta_0)  # Get rid of theta-dependence
partial_x_f_fn = jacfwd(f_x)
partial_x_f = partial_x_f_fn(our_x)
print(partial_x_f)
# Double-check this is the same as what we derived by hand:
print(np.array((2 * (our_x[0] - optimal_point[0]),
                2 * (our_x[1] - optimal_point[1]))))
```


Next we need to compute $$\partial_\theta g$$. A bit of algebra gives

\begin{align*} \partial_\theta g = \begin{pmatrix} -\cos \varphi & r\sin \varphi & -1 & 0 \\ -\sin \varphi & -r\cos \varphi & 0 & -1 \end{pmatrix}. \end{align*}

Now, we can double-check this in JAX:

```python
g_theta = lambda theta: g(our_x, theta)  # Get rid of x-dependence
partial_theta_g_fn = jacfwd(g_theta)
partial_theta_g = partial_theta_g_fn(theta_0)
print(partial_theta_g)
# Double-check this is the same as what we derived by hand:
print(np.array([[-np.cos(phi), r * np.sin(phi), -1., 0.],
                [-np.sin(phi), -r * np.cos(phi), 0., -1.]]))
```



Now we can plug in to get the gradient derived from the adjoint method, and compare to the gradient of a function that includes solving $$g$$:

```python
def f_g_combined(theta):
    r, phi, x_hat_0, x_hat_1 = theta
    x = np.array((r * np.cos(phi) + x_hat_0, r * np.sin(phi) + x_hat_1))
    return np.linalg.norm(x - optimal_point) ** 2
```
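With those definitions in place, here is a minimal sketch of how the comparison can go (the names `lam`, `f_theta`, and `adjoint_gradient` below are just illustrative): solve the adjoint system with `np.linalg.solve`, assemble $$\lambda^\top \partial_\theta g + \partial_\theta f$$, and check it against `jax.grad` of `f_g_combined`.

```python
from jax import grad

# Solve (∂ₓg)ᵀ λ = -(∂ₓf)ᵀ for the adjoint vector
lam = np.linalg.solve(partial_x_g(our_x).T, -partial_x_f)

# ∂_θ f, evaluated at the solution point (zero here, since f ignores θ)
f_theta = lambda theta: f(our_x, theta)  # Get rid of x-dependence
partial_theta_f = jacfwd(f_theta)(theta_0)

# Assemble λᵀ ∂_θ g + ∂_θ f
adjoint_gradient = lam @ partial_theta_g + partial_theta_f
print(adjoint_gradient)

# Compare with differentiating straight through the solve-then-evaluate function
print(grad(f_g_combined)(theta_0))
```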