For reference, a derivation of a variational approximation for Bayesian linear regression, following Ormerod (2008) and Ormerod and Wand (2010).
Recall that the evidence lower bound is given by \[ \begin{aligned} \ln p(y) &\geq \mathcal{L}(y;q) \\ &= \mathbb E_q[\ln p(y,\theta) - \ln q(\theta)] \\ &= \mathbb E_q[\ln p(y|\theta)] + \mathbb E_q[\ln p(\theta)] + \mathbb H_q[\theta]. \end{aligned} \]
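For completeness, the bound follows from the usual decomposition of the log evidence into the ELBO and a Kullback–Leibler term, \[ \ln p(y) = \mathbb E_q\left[\ln\frac{p(y,\theta)}{q(\theta)}\right] + \mathbb E_q\left[\ln\frac{q(\theta)}{p(\theta|y)}\right] = \mathcal{L}(y;q) + \text{KL}\left(q(\theta)\,\|\,p(\theta|y)\right) \geq \mathcal{L}(y;q), \] so maximising \(\mathcal{L}(y;q)\) over the variational family is equivalent to minimising the KL divergence from \(q(\theta)\) to the posterior.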
The model we are interested in is \[ \begin{aligned} y|\beta,\sigma^2 &\sim N(X\beta, \sigma^2 I_n) \\ \beta &\sim N(\mu_0,\Sigma_0) \\ \sigma^2 &\sim IG(a_0, b_0). \end{aligned} \]
Suppose we specify a parametric approximating density \(q(\beta,\sigma^2;\xi) = q(\beta;\xi_\beta)q(\sigma^2;\xi_\sigma)\) where \[ \begin{aligned} q(\beta;\xi_\beta) &= N(\beta|\mu_\beta, \Sigma_\beta) \\ q(\sigma^2;\xi_\sigma) &= IG(\sigma^2|a_{\sigma^2}, b_{\sigma^2}). \end{aligned} \]
We then have the following lower bound \[ \mathcal{L}(y;q) = \mathbb E_q[\ln p(y|\beta,\sigma^2)] + \mathbb E_q[\ln p(\beta)] + \mathbb E_q[\ln p(\sigma^2)] + \mathbb H_q[\beta] + \mathbb H_q[\sigma^2]. \]
The components are \[ \begin{aligned} \mathbb E_q[\ln p(y|\beta,\sigma^2)] &= -\frac{1}{2}\left\{n\ln(2\pi) + n\,\mathbb E_q[\ln\sigma^2] + \mathbb E_q[\sigma^{-2}]\,\mathbb E_q[(y - X\beta)^\top(y - X\beta)]\right\} \\ &= -\frac{1}{2}\left\{n\ln(2\pi) + n\left[\ln(b_{\sigma^2}) - \psi(a_{\sigma^2})\right] + \frac{a_{\sigma^2}}{b_{\sigma^2}}\left[(y-X\mu_\beta)^\top(y-X\mu_\beta)+\text{tr}(X^\top X\Sigma_\beta)\right]\right\}\\ \mathbb E_q[\ln p(\beta)] &= -\frac{1}{2}\left\{d\ln(2\pi) + \ln|\Sigma_0| + \mathbb E_q[(\beta - \mu_0)^\top\Sigma_0^{-1}(\beta-\mu_0)]\right\}\\ &= -\frac{1}{2}\left\{d\ln(2\pi)+\ln|\Sigma_0| + (\mu_\beta-\mu_0)^\top\Sigma_0^{-1}(\mu_\beta-\mu_0) + \text{tr}(\Sigma_0^{-1}\Sigma_\beta)\right\} \\ \mathbb E_q[\ln p(\sigma^2)] &= a_0\ln(b_0)-\ln\Gamma(a_0)+(a_0+1)\mathbb E_q[\ln(\sigma^{-2})] - b_0\mathbb E_q[\sigma^{-2}] \\ &= a_0\ln(b_0)-\ln\Gamma(a_0)-(a_0+1)\left[\ln(b_{\sigma^2}) - \psi(a_{\sigma^2})\right]- b_0\frac{a_{\sigma^2}}{b_{\sigma^2}} \\ \mathbb H_q[\beta] &= \frac{1}{2}\left[d(1 + \ln(2\pi)) + \ln|\Sigma_\beta|\right]\\ \mathbb H_q[\sigma^2] &= a_{\sigma^2} + \ln(b_{\sigma^2}) + \ln\Gamma(a_{\sigma^2}) - (a_{\sigma^2} + 1)\psi(a_{\sigma^2}), \end{aligned} \]
where we have used the following identities
\[ \begin{aligned} \mathbb E_q[\beta^\top\Sigma_0^{-1}\beta] &= \mu_\beta^\top\Sigma_0^{-1}\mu_\beta + \text{tr}\left(\Sigma_0^{-1}\Sigma_\beta\right)\quad\text{(expectation of a quadratic form)}\\ \mathbb E_q[\sigma^{-2}] &= \frac{a_{\sigma^2}}{b_{\sigma^2}} \quad \text{(mean of a gamma r.v.)}\\ \mathbb E_q[\ln(\sigma^{-2})] &= \psi(a_{\sigma^2}) - \ln(b_{\sigma^2}) \quad\text{(expectation of the log of a gamma r.v.)}. \end{aligned} \]
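These follow from \(\sigma^{-2}\sim\text{Gamma}(a_{\sigma^2}, b_{\sigma^2})\) (rate parameterisation) under \(q\). A quick Monte Carlo check in R, with arbitrary values \(a_{\sigma^2}=3\) and \(b_{\sigma^2}=2\) that play no role in the derivation:
# Monte Carlo check of the gamma identities above (arbitrary a = 3, b = 2)
a <- 3; b <- 2
prec <- rgamma(1e6, shape = a, rate = b)  # sigma^{-2} ~ Gamma(a, b)
c(mean(prec), a / b)                      # E[sigma^{-2}] = a / b
c(mean(log(prec)), digamma(a) - log(b))   # E[log sigma^{-2}] = psi(a) - log(b)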
For optimisation, the derivatives are \[ \begin{aligned} D_{\mu_\beta}\mathcal{L}(y;q) &= \frac{a_{\sigma^2}}{b_{\sigma^2}}X^\top(y - X\mu_\beta) - \Sigma^{-1}_0(\mu_\beta-\mu_0) \\ D_{\Sigma_\beta}\mathcal{L}(y;q) &= -\frac{1}{2}\left(\frac{a_{\sigma^2}}{b_{\sigma^2}}X^\top X + \Sigma_0^{-1} - \Sigma_\beta^{-1}\right)\\ D_{a_{\sigma^2}}\mathcal{L}(y;q) &= 1 + \left(a_0 + \frac{n}{2} - a_{\sigma^2}\right)\psi^\prime(a_{\sigma^2}) - \frac{1}{b_{\sigma^2}}\left(b_0 + \frac{(y - X\mu_\beta)^\top(y - X\mu_\beta) + \text{tr}(X^\top X\Sigma_\beta)}{2}\right)\\ D_{b_{\sigma^2}}\mathcal{L}(y;q) &= \left(b_0 + \frac{(y-X\mu_\beta)^\top(y-X\mu_\beta) + \text{tr}(X^\top X \Sigma_\beta)}{2}\right)\frac{a_{\sigma^2}}{b_{\sigma^2}^2} - \frac{a_0 + \frac{n}{2}}{b_{\sigma^2}}, \end{aligned} \] using the matrix calculus identities \[ \begin{aligned} D_X \ln|X| &= (X^{-1})^\top\\ D_X \text{tr}(AXB) &= A^\top B^\top. \end{aligned} \]
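As a quick sanity check on the last expression, the derivative with respect to \(b_{\sigma^2}\) can be compared against a central finite difference of the bound. The snippet below is only a sketch: `elbo_b` is a throwaway helper (not part of the implementation further down) that evaluates \(\mathcal{L}(y;q)\) as a function of \(b_{\sigma^2}\), with all other quantities held at arbitrary test values.
# Finite-difference check of D_b L(y; q) at arbitrary test values
set.seed(1)
n <- 20; d <- 2
X <- cbind(1, rnorm(n)); y <- rnorm(n)
mu0 <- rep(0, d); Sigma0 <- diag(10, d); a0 <- b0 <- 1e-2
mu <- rep(0.5, d); Sigma <- diag(0.1, d); a <- a0 + n / 2
S <- sum((y - X %*% mu)^2) + sum(diag(crossprod(X) %*% Sigma))
# L(y; q) as a function of b_sigma2 only; terms free of b are constants
elbo_b <- function(b) {
  -0.5 * (n * log(2 * pi) + n * (log(b) - digamma(a)) + a / b * S) -
    0.5 * (d * log(2 * pi) + log(det(Sigma0)) +
           drop(crossprod(mu - mu0, solve(Sigma0) %*% (mu - mu0))) +
           sum(diag(solve(Sigma0) %*% Sigma))) +
    a0 * log(b0) - lgamma(a0) - (a0 + 1) * (log(b) - digamma(a)) - b0 * a / b +
    (d * (1 + log(2 * pi)) + log(det(Sigma))) / 2 +
    a + log(b) + lgamma(a) - (a + 1) * digamma(a)
}
b <- 2; h <- 1e-6
c(analytic  = (b0 + S / 2) * a / b^2 - (a0 + n / 2) / b,
  numerical = (elbo_b(b + h) - elbo_b(b - h)) / (2 * h))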
Setting the derivatives to zero and solving results in the fixed-point updates \[ \begin{cases} \Sigma_\beta \leftarrow \left(\frac{a_{\sigma^2}}{b_{\sigma^2}}X^\top X + \Sigma_0^{-1}\right)^{-1}\\ \mu_\beta \leftarrow \Sigma_\beta\left(\frac{a_{\sigma^2}}{b_{\sigma^2}}X^\top y + \Sigma_0^{-1}\mu_0\right)\\ a_{\sigma^2} \leftarrow a_0 + \frac{n}{2} \\ b_{\sigma^2} \leftarrow b_0 + \frac{(y-X\mu_\beta)^\top(y-X\mu_\beta) + \text{tr}(X^\top X \Sigma_\beta)}{2}, \end{cases} \] which are iterated until the lower bound converges.
Example
An R implementation and comparisons with results from Stan are given below.
# Entropy for multivariate normal
mvn_ent <- function(Sigma) {
  (ncol(Sigma)*(1 + log(2*pi)) + log(det(Sigma)))/2
}
# Entropy for inverse gamma
ig_ent <- function(a, b) {
  a + log(b) + lgamma(a) - (a + 1)*digamma(a)
}
# VB for linear regression model
vb_lin_reg <- function(
  X, y,
  mu0 = rep(0, ncol(X)), Sigma0 = diag(10, ncol(X)),
  a0 = 1e-2, b0 = 1e-2,
  maxiter = 100, tol = 1e-5, verbose = TRUE) {
  d <- ncol(X)
  n <- nrow(X)
  invSigma0 <- solve(Sigma0)
  invSigma0_x_mu0 <- invSigma0 %*% mu0
  XtX <- crossprod(X)
  Xty <- crossprod(X, y)
  mu <- mu0
  Sigma <- Sigma0
  a <- a0 + n / 2   # shape update is fixed, so set it once
  b <- b0
  lb <- numeric(maxiter)
  i <- 0
  converged <- FALSE
  while (i < maxiter && !converged) {
    i <- i + 1
    a_div_b <- a / b
    # Fixed-point updates for q(beta) and q(sigma^2)
    Sigma <- solve(a_div_b * XtX + invSigma0)
    mu <- Sigma %*% (a_div_b * Xty + invSigma0_x_mu0)
    y_m_Xmu <- y - X %*% mu
    b <- b0 + 0.5*(crossprod(y_m_Xmu) + sum(diag(Sigma %*% XtX)))[1]
    # Calculate L(q)
    lb[i] <- mvn_ent(Sigma) + ig_ent(a, b) +
      a0 * log(b0) - lgamma(a0) - (a0 + 1) * (log(b) - digamma(a)) - b0 * a / b -
      0.5*(d * log(2*pi) + log(det(Sigma0)) + crossprod(mu - mu0, invSigma0 %*% (mu - mu0)) + sum(diag(invSigma0 %*% Sigma))) -
      0.5*(n * log(2*pi) + n * (log(b) - digamma(a)) + a / b * (crossprod(y_m_Xmu) + sum(diag(XtX %*% Sigma))))
    if (verbose) cat(sprintf("Iteration %3d, ELBO = %5.10f\n", i, lb[i]))
    if (i > 1 && abs(lb[i] - lb[i - 1]) < tol) converged <- TRUE
  }
  return(list(lb = lb[1:i], mu = mu, Sigma = Sigma, a = a, b = b))
}
// lin_reg
data {
  int<lower=1> n;
  int<lower=1> d;
  vector[n] y;
  matrix[n, d] X;
  vector[d] mu0;
  matrix[d, d] Sigma0;
  real<lower=0> a0;
  real<lower=0> b0;
}
parameters {
  vector[d] beta;
  real<lower=0> sigmasq;
}
model {
  real sigma = sqrt(sigmasq);
  sigmasq ~ inv_gamma(a0, b0);
  beta ~ multi_normal(mu0, Sigma0);
  y ~ normal(X * beta, sigma);
}
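The `lin_reg` object passed to `sampling()` below is the compiled model; assuming the Stan program above is saved as `lin_reg.stan` (the file name here is an assumption), it can be compiled with `rstan::stan_model`:
library(rstan)
# Compile the Stan program above; "lin_reg.stan" is an assumed file name
lin_reg <- stan_model("lin_reg.stan")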
library(rstan)
library(MCMCpack)
n <- 50
beta <- c(1, 2, 3)
sigma <- 0.5
x1 <- runif(n)
x2 <- sample(c(0, 1), n, replace = TRUE)
X <- cbind(1, x1, x2)
y <- drop(X %*% beta + rnorm(n, 0, sigma))
D <- data.frame(x1 = x1, x2 = x2, y = y)
mu0 <- rep(0, length(beta))
Sigma0 <- diag(1, length(beta))
a0 <- 1e-2
b0 <- 1e-2
vb_fit <- vb_lin_reg(X, y, mu0 = mu0, Sigma0 = Sigma0, a0 = a0, b0 = b0, tol = 1e-8)
Iteration 1, ELBO = -97.3262155814
Iteration 2, ELBO = -87.2253181794
Iteration 3, ELBO = -87.1217605969
Iteration 4, ELBO = -87.1151041597
Iteration 5, ELBO = -87.1146632789
Iteration 6, ELBO = -87.1146340212
Iteration 7, ELBO = -87.1146320794
Iteration 8, ELBO = -87.1146319505
Iteration 9, ELBO = -87.1146319419
mc_fit <- sampling(lin_reg, refresh = 0, iter = 1e4,
data = list(n = n, d = 3, X = X, y = y, mu0 = mu0, Sigma0 = Sigma0, a0 = a0, b0 = b0))
draws <- as.matrix(mc_fit)
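A quick way to compare the two fits is to tabulate posterior means and standard deviations of \(\beta\); the sketch below assumes the first three columns of `draws` are `beta[1]`, `beta[2]` and `beta[3]`, in that order.
# Posterior means and standard deviations for beta: VB vs MCMC
rbind(vb_mean = drop(vb_fit$mu),
      mc_mean = colMeans(draws[, 1:3]),
      vb_sd   = sqrt(diag(vb_fit$Sigma)),
      mc_sd   = apply(draws[, 1:3], 2, sd))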
Same data under a different prior.
mu0 <- rep(10, length(beta))
Sigma0 <- diag(0.1, length(beta))
a0 <- 1e-2
b0 <- 1e-2
vb_fit <- vb_lin_reg(X, y, mu0 = mu0, Sigma0 = Sigma0, a0 = a0, b0 = b0, tol = 1e-8)
Iteration 1, ELBO = -96.6398329245
Iteration 2, ELBO = -86.0777933979
Iteration 3, ELBO = -83.4850550434
Iteration 4, ELBO = -78.8496581897
Iteration 5, ELBO = -73.7360282502
Iteration 6, ELBO = -73.0800738730
Iteration 7, ELBO = -73.0558414617
Iteration 8, ELBO = -73.0545704858
Iteration 9, ELBO = -73.0544993784
Iteration 10, ELBO = -73.0544953843
Iteration 11, ELBO = -73.0544951599
Iteration 12, ELBO = -73.0544951473
Iteration 13, ELBO = -73.0544951465
mc_fit <- sampling(lin_reg, refresh = 0, iter = 1e4,
data = list(n = n, d = 3, X = X, y = y, mu0 = mu0, Sigma0 = Sigma0, a0 = a0, b0 = b0))
draws <- as.matrix(mc_fit)
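The approximation for \(\sigma^2\) can be inspected in the same way, for example by overlaying the variational inverse-gamma density on a histogram of the MCMC draws; this sketch uses `MCMCpack::dinvgamma` for the density and assumes the parameter is named `sigmasq`, as declared in the Stan program.
# Overlay the VB inverse-gamma density for sigma^2 on the MCMC draws
hist(draws[, "sigmasq"], breaks = 50, freq = FALSE,
     xlab = expression(sigma^2), main = "Posterior of sigma^2")
curve(MCMCpack::dinvgamma(x, shape = vb_fit$a, scale = vb_fit$b),
      add = TRUE, lwd = 2)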
References
Ormerod, John T. 2008. “On Semiparametric Regression and Data Mining.” Ph.D. thesis, School of Mathematics and Statistics, The University of New South Wales.
Ormerod, John T., and Matt P. Wand. 2010. “Explaining Variational Approximations.” The American Statistician 64 (2): 140–53.