Implementing GP-UCB in Jax

Here I’ll showcase the capabilities of the Jax framework on an interesting problem. We will be implementing the Gaussian Process-Upper Confidence Bound (GP-UCB) algorithm, which is often used to solve the Bayesian Optimization problem.

In the Bayesian Optimization problem, we have some function $f : \mathcal{X} \to \mathbb{R}$. At each timestep $t$ we choose a point $x_t \in \mathcal{X}$ and observe a noisy realization of the function value, $y_t = f(x_t) + \epsilon_t$.

We want to minimize the cumulative regret

$$ R_T = \sum_{t=1}^{T} \left( f(x_t) - \min_{x \in \mathcal{X}} f(x) \right). $$

The objective means that we are interested not only in finding a global minimum of $f$, but also in not incurring a lot of cost while doing so. Even if we have a pretty good idea of where the minimum is, we don’t want to choose all our remaining $x$s to be there, as we could be missing an even better minimum.

An application of this could be an idealized medical trial, where $x$ represents the dosage of a drug and $f(x)$ is the patient’s risk of adverse outcomes. We want to find a dosage which is good, but at the same time we don’t want to cause many adverse outcomes while doing so.

GP-UCB is a popular method for doing Bayesian Optimization, and it has theoretical guarantees upper-bounding the regret (although, as is often the case, these guarantees depend on knowing parameters of $f$ that you’re unlikely to know a priori). In this setting we want to minimize the function rather than maximize it, so technically we should perhaps call the algorithm the GP-Lower Confidence Bound algorithm; however, GP-UCB seems to be the name used for both the minimization and maximization settings.

One advantage of GP-UCB is that it is very straightforward to describe. Given the points observed so far, you fit a Gaussian process to serve as a prediction of the objective function. The Gaussian process gives us a mean function $\mu_t(x)$ and a variance function $\sigma_t^2(x)$. GP-UCB chooses the next point as

$$ x_{t+1} = \operatorname*{arg\,min}_{x} \; \mu_t(x) - \sqrt{\beta_t}\, \sigma_t(x), $$

where $\beta_t$ is a parameter which is chosen to grow roughly logarithmically with $t$. The growth of $\beta_t$ with $t$ ensures that the algorithm will eventually sample all the points infinitely often. If the algorithm doesn’t sample in a region for many iterations, then $\sigma_t(x)$ there will remain the same while $\beta_t$ slowly grows, until the algorithm eventually samples at that point.

Often this algorithm is implemented by discretizing the input space into a grid of points. In that setting the minimization above simply becomes a linear scan over the UCB values at all the grid points, and is very easily implemented. The computation of the GP can be handled by any of many packages (I like Scikit-Learn’s GaussianProcessRegressor). But a shortcoming of this approach is that it scales very poorly with dimension: even a grid with 10 points per side contains $10^{15}$ points in 15 dimensions, which is completely intractable.
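
Here’s a minimal sketch of that discretized version in one dimension. `xs_observed` and `ys_observed` are stand-in names for the points sampled so far, and the grid size and value of $\beta$ are arbitrary illustrative choices:

{% highlight python %}
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

# xs_observed: (n, 1) array of points sampled so far (stand-in name)
# ys_observed: (n,) array of observed values (stand-in name)
grid = np.linspace(0.0, 1.0, 1000).reshape(-1, 1)  # discretized 1-D domain
beta = 2.0  # arbitrary illustrative choice

gp = GaussianProcessRegressor().fit(xs_observed, ys_observed)
mu, sigma = gp.predict(grid, return_std=True)

# Linear scan: choose the grid point with the smallest lower confidence bound
x_next = grid[np.argmin(mu - np.sqrt(beta) * sigma)]
{% endhighlight %}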

Instead, we can compute the GP objective function implicitly and use gradient-based optimization to solve the optimization problem above. It’d be nice to use more sophisticated optimization techniques (such as Newton’s method), as they converge much faster. However, computing the gradient of the GP objective by hand is not completely straightforward, and computing the Hessian is even more annoying. With the built-in autodiff capabilities of Jax, this becomes as easy as applying the `grad` (and `hessian`) transformations to our GP objective!
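
To illustrate what those transformations look like, here they are applied to a toy function (`toy_objective` is just an illustration, not the GP objective itself):

{% highlight python %}
import jax.numpy as jaxnp
from jax import grad, hessian

def toy_objective(x):
    # A simple scalar-valued function of a d-dimensional input
    return jaxnp.sum(x ** 2) + jaxnp.sin(x[0])

toy_grad = grad(toy_objective)      # function returning the d-dimensional gradient
toy_hess = hessian(toy_objective)   # function returning the d x d Hessian

x0 = jaxnp.ones(3)
print(toy_grad(x0), toy_hess(x0))
{% endhighlight %}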

For simplicity, we are going to use the domain $[0, 1]^d$. This means that we need a constrained optimization method to deal with what happens if the gradient iterates take us out of the domain. We will use a log-barrier method to solve this constrained optimization problem. Since the gradient changes very rapidly near the log barrier, it’s important that we use a second-order method to ensure that the optimization remains well-conditioned.
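
Here’s a rough sketch of the kind of barrier-augmented objective and Newton step we have in mind for the box domain. `barrier_objective`, `newton_step`, and the barrier weight are illustrative names and choices rather than a full implementation (a complete barrier method would also shrink the barrier weight over successive solves):

{% highlight python %}
import jax.numpy as jaxnp
from jax import grad, hessian

def barrier_objective(x, objective, barrier_weight=0.1):
    # -log barriers blow up as any coordinate approaches 0 or 1,
    # keeping the iterates inside the box [0, 1]^d.
    barrier = -jaxnp.sum(jaxnp.log(x) + jaxnp.log(1.0 - x))
    return objective(x) + barrier_weight * barrier

def newton_step(x, objective):
    # One damped-free Newton step on the barrier-augmented objective,
    # with gradient and Hessian supplied by Jax autodiff.
    f = lambda z: barrier_objective(z, objective)
    g = grad(f)(x)
    H = hessian(f)(x)
    return x - jaxnp.linalg.solve(H, g)
{% endhighlight %}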

Let’s start by defining a function to compute the GP-UCB values. The Gaussian Process mean function at a predictive point $x$, conditioning on $n$ points stored in an $n \times d$ matrix $X$ with observed values $y$, is given by

$$ \mu(x) = k(x, X) \left[ K + \sigma_n^2 I \right]^{-1} y, $$

while the predictive variance is given by

$$ \sigma^2(x) = k(x, x) - k(x, X) \left[ K + \sigma_n^2 I \right]^{-1} k(X, x). $$

In both cases, the matrix $K$ is the Gram matrix given by $K_{ij} = k(x_i, x_j)$. We will use the standard choice of a squared exponential kernel with a given length scale hyperparameter $\ell$:

$$ k(x, x') = \exp\left( -\frac{\lVert x - x' \rVert^2}{2 \ell^2} \right). $$

Since we re-use the matrix $A = \left[ K + \sigma_n^2 I \right]^{-1}$ and it is independent of $x$, we can save time by computing it once and passing it as an argument to our function. Writing this all out, we have

{% highlight python %}
from __future__ import print_function, division

import jax.numpy as jaxnp
import jax.random as jaxrand
import numpy as np
from jax import grad, jit, vmap
import jax.lax as jaxlax

def ucb_vals(x, xs, ys, length_scale, beta, A, interrupt_flag=False):
    """Given a kernel k, a set of points xs, and objective values ys, return
    the values of mu(x) and sqrt(beta) * sigma(x), where mu and sigma are the
    mean and standard deviation functions of the GP with kernel k conditioned
    on xs, ys. For convenience we pass in the matrix
    A = (k(xs, xs) + sigma^2 I)^{-1}."""
    # xs = n x d array
    # x = 1 x d array
    # We assume that k has been vectorized so that it can deal with a whole
    # input batch and return batchwise outputs:
    # k is a function b x d -> b, where b is the batch dimension.
    # Here we'll only consider stationary kernels.

    # If we want to jit this function we can't pass functions (e.g. k) as
    # arguments, so we need to set the kernel as a global function (k_vmap).
    n = xs.shape[0]
    d = xs.shape[1]

    # See e.g. eq 2.22 in Rasmussen and Williams for details:
    # mu = k(x, xs) [k(xs, xs) + sigma^2 I]^{-1} y
    # sigma^2 = k(x, x) - k(x, xs) [k(xs, xs) + sigma^2 I]^{-1} k(xs, x)
    x_new = jaxnp.reshape(x, (1, d))
    k_x_xs = k_vmap(xs - jaxnp.repeat(x_new, n, axis=0), length_scale)
    mu_x = jaxnp.dot(jaxnp.dot(k_x_xs, A), ys)

    k_x_x = k_vmap(jaxnp.zeros_like(x_new), length_scale)[0]
    sigma_sq_x = k_x_x - jaxnp.dot(jaxnp.dot(k_x_xs, A), k_x_xs.T)
    return mu_x, jaxnp.sqrt(beta) * jaxnp.sqrt(sigma_sq_x)

{% endhighlight %}
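
The function above refers to a global `k_vmap`, the batched kernel. Here is one possible way to define it, applying the squared exponential kernel to a batch of difference vectors; the name and exact form are my sketch, and only the batched, stationary signature (b x d -> b) is fixed by the code above:

{% highlight python %}
# Squared exponential kernel on a single difference vector...
def k_single(diff, length_scale):
    return jaxnp.exp(-jaxnp.sum(diff ** 2) / (2.0 * length_scale ** 2))

# ...vectorized over the batch dimension, broadcasting the length scale.
k_vmap = vmap(k_single, in_axes=(0, None))
{% endhighlight %}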