Austin David Brown

Batch Means for Variance Estimation in Langevin Monte Carlo


The motivation here is that it is difficult to know whether the Markov chain $(X_k)$ from an MCMC algorithm is any good, that is, how accurate the estimates computed from it are. The batch means approach is based upon the functional CLT from Kipnis and Varadhan.

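Theorem (Functional Markov Chain CLT). Let $(X_k)$ be an irreducible, reversible Markov chain on $\mathbb{R}^d$ with law $P$, initial distribution $\nu$, and stationary distribution $\pi$. Let $g : \mathbb{R}^d \to \mathbb{R}$ with $\mathrm{Var}_\pi(g(X)) < \infty$. Suppose $\sigma^2 = \lim_{n \to \infty} \frac{1}{n} \mathrm{Var}_P\left(\sum_{k=1}^{n} g(X_k)\right) < \infty$. Then the scaled sum $\frac{1}{\sqrt{n}}\left(S_{nt} - nt\, E_\pi g(X_1)\right) \to \sigma W_t$ in distribution, where $W_t$ is Brownian motion on $[0, T]$ and $S_{nt} = \sum_{k=1}^{\lfloor nt \rfloor} g(X_k)$.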

The multidimensional version holds as well. Now, since Brownian motion has independent increments, we can partition $[0, T]$ equally into intervals of size $\Delta t$ with $t_1 < t_2 < \cdots < t_n$ so that $\sigma W_{t_1}, \sigma (W_{t_2} - W_{t_1}), \ldots, \sigma (W_{t_n} - W_{t_{n-1}})$ are i.i.d. normal random variables with variance $\sigma^2 \Delta t$. Thus if the CLT holds, we can try to approximate these random variables with so-called batch means.

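Let $N$ be the "burn-in" length for the chain. Take $m$ non-overlapping batch means $B_0, \ldots, B_{m-1}$ of equal size $b$ from the chain after burn-in, where each batch mean is $B_k = \frac{1}{b} \sum_{i=1}^{b} g(X_{N + bk + i})$. Then the batch means will be approximately i.i.d. normal random variables with mean $E_\pi g(X_1)$ and variance $\sigma^2 / b$; equivalently, scaled appropriately, $\sqrt{b}\,(B_k - E_\pi g(X_1))$ is approximately $N(0, \sigma^2)$.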

The issue is that we do not know $\sigma^2$. We can try to estimate it with the batch means, but there is no guarantee that the estimate will be consistent just by construction. By the ergodic theorem, the mean of the batch means satisfies $\bar{B}_m = \frac{1}{m} \sum_{k=0}^{m-1} B_k \to E_\pi g(X_1)$. But the usual i.i.d.-like estimate for the variance, $S^2_{BM} = \frac{b}{m} \sum_{k=0}^{m-1} (B_k - \bar{B}_m)^2$, has no such guarantee. Intuitively, if we were to let the batch size get infinitely large, then the estimate should be consistent. This is indeed true and was proven by Jones et al. (2006).
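One way to get a feel for the estimator is to try it on a chain whose asymptotic variance is known in closed form. The sketch below is plain Python rather than the PyTorch code used later in this post, and everything in it is an assumption made for illustration: the Gaussian AR(1) chain, the helper names `ar1_chain` and `batch_means_variance`, and the choice of $m = b = \sqrt{n}$. For an AR(1) chain with coefficient $\phi$ and unit innovation variance, $\sigma^2 = 1/(1-\phi)^2$.

```python
import math
import random


def ar1_chain(n, phi, seed=0):
    """Simulate n states of a stationary Gaussian AR(1) chain (illustrative test chain)."""
    rng = random.Random(seed)
    # Start from the stationary distribution N(0, 1 / (1 - phi^2)).
    x = rng.gauss(0.0, 1.0 / math.sqrt(1.0 - phi ** 2))
    chain = []
    for _ in range(n):
        chain.append(x)
        x = phi * x + rng.gauss(0.0, 1.0)
    return chain


def batch_means_variance(chain, m):
    """Return S^2_BM = (b / m) * sum_k (B_k - mean of the B_k)^2 using m batches."""
    b = len(chain) // m  # batch size; leftover samples at the end are dropped
    means = [sum(chain[k * b:(k + 1) * b]) / b for k in range(m)]
    grand = sum(means) / m
    return b / m * sum((bk - grand) ** 2 for bk in means)


phi = 0.5
chain = ar1_chain(40_000, phi)
s2 = batch_means_variance(chain, m=200)  # 200 batches of size 200
true_sigma2 = 1.0 / (1.0 - phi) ** 2     # asymptotic variance, equal to 4 here
print(s2, true_sigma2)
```

With a fixed number of batches the estimate stays noisy however long the chain is, which is one reason to let both $m$ and $b$ grow with $n$.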

Regardless, we can use the batch means to estimate the variance and calculate a confidence interval for the mean with $\bar{X}_n \pm t_{.95, m} \sqrt{S^2_{BM} / m}$. I would think this underestimates the interval, but I am not sure. If anyone knows the answer, please let me know.
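For comparison, the batch means interval in Jones et al. (2006) is stated with the full chain length in the denominator, $\bar{X}_n \pm t_{m-1} \sqrt{\hat{\sigma}^2 / n}$ with $n = bm$ and $\hat{\sigma}^2 = \frac{b}{m-1} \sum_k (B_k - \bar{X}_n)^2$. Since $S^2_{BM}$ estimates $\sigma^2$ and $\mathrm{Var}(\bar{X}_n) \approx \sigma^2 / n$, dividing by $m$ instead of $n$, if I read the formula above that way, gives an interval wider by a factor of $\sqrt{b}$. Below is a minimal plain-Python sketch of the $n$-denominator form. The helper name `batch_means_ci` is hypothetical, and the normal quantile is used as a stand-in for the $t_{m-1}$ quantile, which is only reasonable when $m$ is large.

```python
import math
import random
from statistics import NormalDist


def batch_means_ci(chain, m, level=0.90):
    """Two-sided interval: mean +/- z * sqrt(s2 / n), with s2 = b / (m - 1) * sum (B_k - mean)^2."""
    b = len(chain) // m          # batch size
    n = b * m                    # number of samples actually used
    xs = chain[:n]
    mean = sum(xs) / n
    means = [sum(xs[k * b:(k + 1) * b]) / b for k in range(m)]
    s2 = b / (m - 1) * sum((bk - mean) ** 2 for bk in means)
    # Assumption: normal quantile in place of the t quantile with m - 1 degrees of freedom.
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    half = z * math.sqrt(s2 / n)
    return mean - half, mean + half


# Illustrative input: i.i.d. standard normal draws, for which sigma^2 = 1.
rng = random.Random(1)
draws = [rng.gauss(0.0, 1.0) for _ in range(10_000)]
lo, hi = batch_means_ci(draws, m=100)
print(lo, hi)
```

For i.i.d. draws the half-width should be close to $1.645 / \sqrt{n}$; for a correlated chain it grows with $\sigma$.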

Batch Means as a Stopping Criterion

One approach to determining when to stop an MCMC algorithm is to stop once the batch means standard error changes by less than a given tolerance between successive checks. We do this in the example below.
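The rule can be sketched independently of any particular sampler. The code below is plain Python under stated assumptions: the helper names `bm_sigma` and `run_until_stable`, the chunk size, the cap on the number of chunks, and the AR(1) update used as a stand-in chain are all hypothetical. It also recomputes the number of batches as $\lfloor \sqrt{n} \rfloor$ each time the chain grows.

```python
import math
import random


def bm_sigma(chain):
    """Batch means estimate of sigma with m = floor(sqrt(n)) batches of size b = n // m."""
    m = int(math.sqrt(len(chain)))
    b = len(chain) // m
    means = [sum(chain[k * b:(k + 1) * b]) / b for k in range(m)]
    grand = sum(means) / m
    return math.sqrt(b / (m - 1) * sum((bk - grand) ** 2 for bk in means))


def run_until_stable(step, x0, chunk=5_000, tol=0.1, max_chunks=50):
    """Extend the chain chunk by chunk until the sigma estimate moves by less than tol."""
    chain, x = [], x0
    sigma_old = None
    sigma = None
    for _ in range(max_chunks):
        for _ in range(chunk):
            x = step(x)          # continue from the current end of the chain
            chain.append(x)
        sigma = bm_sigma(chain)
        if sigma_old is not None and abs(sigma - sigma_old) < tol:
            break
        sigma_old = sigma
    return chain, sigma


# Stand-in chain: AR(1) with coefficient 0.5, whose true sigma is 2.
rng = random.Random(0)
chain, sigma = run_until_stable(lambda x: 0.5 * x + rng.gauss(0.0, 1.0), x0=0.0)
print(len(chain), sigma)
```

A small change between successive estimates does not by itself say the estimate is accurate. The fixed-width approach in Jones et al. (2006) stops instead when the confidence interval half-width falls below a target.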

Example: Batch Means for Ornstein–Uhlenbeck Process

The Ornstein-Uhlenbeck Markov semigroup $(P_t)$ on $\mathbb{R}^d$ with the standard normal distribution as its invariant measure can be derived from the SDE $dX_t = -X_t\,dt + \sqrt{2}\,dW_t$ starting at $X_0$. The SDE has an explicit solution, with $X_t$ distributed as $N(e^{-t} X_0, (1 - e^{-2t}) I_d)$, which we could simulate directly. Instead, we will simulate using the Euler discretization and a Metropolis-Hastings correction, which is known as MALA.
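For reference, the MALA step for this target can be written out. This is the standard Metropolis-Hastings ratio specialized to $\pi(x) \propto e^{-\lVert x \rVert^2 / 2}$ and the Euler proposal with step size $h$. The symbols $q$ and $\alpha$ are notation introduced here, not taken from the references.

```latex
\begin{aligned}
Y &= (1 - h) X_k + \sqrt{2h}\, Z, \qquad Z \sim N(0, I_d), \\
q(x, y) &\propto \exp\!\left( -\frac{\lVert y - (1 - h) x \rVert^2}{4h} \right), \\
\log \frac{\pi(y)\, q(y, x)}{\pi(x)\, q(x, y)}
  &= \frac{1}{2} \left( \lVert x \rVert^2 - \lVert y \rVert^2 \right)
   + \frac{1}{4h} \left( \lVert y - (1 - h) x \rVert^2 - \lVert x - (1 - h) y \rVert^2 \right), \\
\alpha(x, y) &= \min\left\{ 1,\ \frac{\pi(y)\, q(y, x)}{\pi(x)\, q(x, y)} \right\}.
\end{aligned}
```

The proposal $Y$ is accepted with probability $\alpha(X_k, Y)$; otherwise the chain stays at $X_k$.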

Does the CLT hold here? The chain is reversible due to the Metropolis-Hastings step and $N(0, I_d)$ is the invariant measure. All compact sets are small here and I am pretty sure we can establish a drift condition. Together these would give geometric ergodicity, which for a reversible chain is sufficient for the CLT whenever $g$ has a finite second moment under $\pi$. I did not work this out though.

We simulate this in $\mathbb{R}^2$ with a purposefully suboptimal discretization step size $h = 0.1$ and initial point $X_0 = (3, 3)$. The length of the chain is $10^4$ with $10^2$ batch means. We use PyTorch to simulate the chain and compute the batch means standard errors and confidence intervals.

import torch
from torch.distributions import Normal
from math import sqrt

torch.set_grad_enabled(False) # a bare torch.no_grad() call has no effect; this disables autograd globally
torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

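def OULMC(N, h, X_0):
  """Return N states of a MALA chain targeting the standard normal on R^2, started at X_0."""
  X = torch.zeros(N, 2)
  X[0] = X_0

  for i in range(0, N - 1):
    # Euler proposal for dX = -X dt + sqrt(2) dW with step size h
    Z = torch.normal(torch.zeros(2), 1)
    Y = X[i] - h * X[i] + torch.sqrt(torch.tensor(2 * h)) * Z

    u = torch.zeros(1).uniform_(0, 1).item()

    # Log Metropolis-Hastings ratio: log pi(Y) - log pi(X) + log q(Y, X) - log q(X, Y)
    log_target = 1/2.0 * (torch.norm(X[i], p = 2)**2 - torch.norm(Y, p = 2)**2)
    log_proposal = 1/(4.0 * h) * (torch.norm(Y - X[i] + h * X[i], p = 2)**2 - torch.norm(X[i] - Y + h * Y, p = 2)**2)
    R = log_target + log_proposal
    a = min(1, torch.exp(R).item())
    # Accept the proposal with probability a, otherwise stay at the current state
    if u <= a:
      X[i + 1] = Y
    else:
      X[i + 1] = X[i]
  return X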

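def bmSE(X, M):
  """Batch means estimate of sigma for each coordinate of X, using M batches."""
  N = X.size(0)
  B = N // M # batch size; leftover samples at the end are dropped

  S2s = []
  mu = torch.mean(X.narrow(0, 0, B * M), 0)
  for m in range(0, M):
    B_m = torch.mean(X.narrow(0, m * B, B), 0)
    S2_m = torch.pow(B_m - mu, 2)
    S2s.append(S2_m)
  # M * mean(...) is the sum of squared deviations; it equals S2_BM = (B / M) * sum
  # only when the batch size B equals M, as for N = 10**4 and M = 100
  se = torch.sqrt(M * torch.mean(torch.stack(S2s), 0))
  return se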

N = 10**4
h = .1
X_0 = torch.zeros(2).fill_(3)
X = OULMC(N, h, X_0)


# CI
t = 1.660234 # 0.95 quantile of the t distribution with 100 degrees of freedom
mu = torch.mean(X, 0)
M = int(sqrt(N)) # number of batches
se = bmSE(X, M)
print(mu + se * t * 1/sqrt(M))
print(mu - se * t * 1/sqrt(M))
tensor([0.7188, 0.6197])
tensor([-0.6153, -0.8158])
This worked out pretty well. The confidence interval is a bit wide though. Let us try to use this approach as a stopping criterion for the MCMC algorithm. We concatenate a Markov chain in lengths of $10^4$ and check the batch means standard error each time, until it changes by less than the tolerance $10^{-1}$.
###
# Stopping criteria
###
torch.manual_seed(0)

X = OULMC(N, h, X_0)
se = bmSE(X, int(sqrt(N)))

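for i in range(0, 100):
  # Generate another segment of length N, started from X[N - 1], the last state of the first segment
  X_new = OULMC(N, h, X[N - 1, :])
  X = torch.cat([X, X_new], 0)

  se_old = se
  # Still int(sqrt(N)) = 100 batches, so the batch size grows with the chain
  se = bmSE(X, int(sqrt(N)))
  print(se)
  # Stop once every coordinate of the estimate moved by less than the tolerance
  if torch.all(torch.abs(se - se_old) < .1):
    print("Converged at i =", i + 2)
    break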

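# CI
mu = torch.mean(X, 0)
se = bmSE(X, int(sqrt(N)))
M = 8 * int(sqrt(N)) # 8 segments of length N were generated before stopping
t = 1.646761 # 0.95 quantile of the t distribution with 800 degrees of freedom
print(mu + se * t * 1/sqrt(M))
print(mu - se * t * 1/sqrt(M))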
# SE
tensor([2.7668, 3.3073])
tensor([2.1642, 2.6246])
tensor([2.0069, 2.3338])
tensor([1.9427, 2.2211])
tensor([1.7297, 1.9507])
tensor([1.6299, 1.6546])
tensor([1.5377, 1.6986])
Converged at i = 8

# CI
tensor([0.0838, 0.0771])
tensor([-0.0953, -0.1207])
Well, this stopped at a Markov chain of length $8 \cdot 10^4$ and the confidence interval is better for sure. Still, this is quite a long chain for such a simple problem.

References.

1. Geyer, Charles J. Practical Markov Chain Monte Carlo. Statist. Sci. 7 (1992), no. 4, 473--483. doi:10.1214/ss/1177011137. https://projecteuclid.org/euclid.ss/1177011137

2. Jones, Galin L. On the Markov chain central limit theorem. Probab. Surveys 1 (2004), 299--320. doi:10.1214/154957804100000051. https://projecteuclid.org/euclid.ps/1104335301

3. Charlie Geyer's excellent notes. http://www.stat.umn.edu/geyer/f05/8931/n1998.pdf

4. Kipnis, C.; Varadhan, S. R. S. Central limit theorem for additive functionals of reversible Markov processes and applications to simple exclusions. Comm. Math. Phys. 104 (1986), no. 1, 1--19. https://projecteuclid.org/euclid.cmp/1104114929

5. Jones, G. L., Haran, M., Caffo, B. S. and Neath, R. (2006). Fixed-width output analysis for Markov chain Monte Carlo. J. Amer. Statist. Assoc. 101 1537–1547. MR2279478