Speedy splines in Stan part one
Milad Kharratzadeh provides a helpful case study on estimating splines in Stan. We can improve on his work with a few tricks that speed up estimation considerably. In particular, if we apply a QR decomposition to the spline basis matrix, Stan's fitting process gets faster by at least an order of magnitude: the columns of Q are orthogonal, so the coefficients on Q are far less correlated in the posterior than the coefficients on the raw basis, and the sampler can explore them much more efficiently.
Here's the altered model:
functions {
  // Build the ind-th B-spline basis function of the given order, evaluated at
  // the points t, via the Cox-de Boor recursion.
  vector build_b_spline(vector t, array[] real ext_knots, int ind, int order) {
    vector[size(t)] b_spline;
    vector[size(t)] w1 = rep_vector(0, size(t));
    vector[size(t)] w2 = rep_vector(0, size(t));
    if (order == 1) {
      // order-1 B-splines are piecewise constant
      for (i in 1:size(t))
        b_spline[i] = (ext_knots[ind] <= t[i]) && (t[i] < ext_knots[ind + 1]);
    } else {
      if (ext_knots[ind] != ext_knots[ind + order - 1])
        w1 = (t - rep_vector(ext_knots[ind], size(t))) /
             (ext_knots[ind + order - 1] - ext_knots[ind]);
      if (ext_knots[ind + 1] != ext_knots[ind + order])
        w2 = 1 - (t - rep_vector(ext_knots[ind + 1], size(t))) /
                 (ext_knots[ind + order] - ext_knots[ind + 1]);
      // linear interpolation of two lower-order B-splines
      b_spline = w1 .* build_b_spline(t, ext_knots, ind, order - 1) +
                 w2 .* build_b_spline(t, ext_knots, ind + 1, order - 1);
    }
    return b_spline;
  }
}
data {
  int<lower=1> T;                  // number of times measured
  vector[T] t;                     // time (or position) of each measurement
  vector[T] y;                     // the measurement at each time point t
  int<lower=1> K;                  // number of knots
  int<lower=1> degree;             // degree of the spline
  int<lower=0, upper=1> penalized; // whether to use the smoothing prior
}
transformed data {
  // knots at evenly-spaced quantiles of the data
  array[K] real p;
  for (i in 1:K) p[i] = (i - 1.0) / (K - 1.0);
  array[K] real k = quantile(t, p);
  // build the spline matrix B
  int n_basis = K + degree - 1;
  matrix[n_basis, T] B;
  array[2 * degree + K] real ext_knots =
    append_array(append_array(rep_array(k[1], degree), k), rep_array(k[K], degree));
  for (ind in 1:n_basis)
    B[ind, :] = to_row_vector(build_b_spline(t, ext_knots, ind, degree + 1));
  // the last data point falls outside the half-open knot intervals; patch it
  B[n_basis, T] = 1;
  // QR decomposition of B, thin and scale
  matrix[T, n_basis] Q_ast = qr_thin_Q(B') * sqrt(T - 1);
  matrix[n_basis, n_basis] R_ast = qr_thin_R(B') / sqrt(T - 1);
  matrix[n_basis, n_basis] R_ast_inverse = inverse(R_ast);
  // helper stuff
  vector[T] zeros_T = rep_vector(0, T);
}
parameters {
  vector[n_basis] theta_raw; // coefficients on Q_ast
  real<lower=0> sigma;       // scale of the variation
  real<lower=0> tau;         // penalization on wiggles
}
transformed parameters {
  vector[n_basis] theta;
  if (penalized) {
    // random walk on the coefficients; tau controls how much they can wiggle
    theta[1] = theta_raw[1];
    for (i in 2:n_basis) theta[i] = theta[i - 1] + theta_raw[i] * tau;
  } else {
    theta = theta_raw;
  }
}
model {
  theta_raw ~ normal(0, 1);
  tau ~ normal(0, 1);
  sigma ~ exponential(1);
  y ~ normal_id_glm(Q_ast, zeros_T, theta, sigma);
}
generated quantities {
  // map back to coefficients on the original basis and compute the fitted curve
  vector[n_basis] beta = R_ast_inverse * theta;
  vector[T] y_hat = B' * beta;
}
We can compile the model like so,
library(cmdstanr)
m = cmdstan_model('b_spline_QR.stan')
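This assumes CmdStan itself is already installed; if it isn't, cmdstanr's one-time setup helpers should get you there (a sketch, as toolchains vary by machine):
# one-time setup: verify the C++ toolchain, then install CmdStan
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan()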
The main change from the original case study is to decompose the spline matrix B with a thin, scaled QR decomposition in transformed data:
  // QR decomposition of B, thin and scale
  matrix[T, n_basis] Q_ast = qr_thin_Q(B') * sqrt(T - 1);
  matrix[n_basis, n_basis] R_ast = qr_thin_R(B') / sqrt(T - 1);
  matrix[n_basis, n_basis] R_ast_inverse = inverse(R_ast);
Then, instead of regressing y on B', we regress on Q_ast and recover the spline coefficients beta from theta in generated quantities.
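Written out, this is the standard QR reparameterization from the Stan user's guide (here Q* and R* stand for Q_ast and R_ast; the sqrt(T - 1) factors cancel in the product, so they only rescale theta):

y = B^\top \beta + \varepsilon, \qquad B^\top = Q^\ast R^\ast
\;\Longrightarrow\; y = Q^\ast \theta + \varepsilon, \qquad \theta = R^\ast \beta, \qquad \beta = (R^\ast)^{-1} \theta .

Since the columns of Q* are orthogonal, the components of theta are nearly independent a posteriori, which is what buys the sampling speed.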
Let’s simulate some data for the model and fit it:
library(splines)  # provides bs() for the true basis

set.seed(11)
num_knots <- 10
spline_degree <- 3
num_basis <- num_knots + spline_degree - 1
a0 <- 0.2
a <- rnorm(num_basis, 0, 1)
X <- seq(from = -10, to = 10, by = 0.1)
knots <- unname(quantile(X, probs = seq(from = 0, to = 1, length.out = num_knots)))
num_data <- length(X)
B_true <- t(bs(X, df = num_basis, degree = spline_degree, intercept = TRUE))
Y_true <- as.vector(a0 * X + a %*% B_true)
Y <- Y_true + rnorm(length(X), 0, 0.2)
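Before fitting anything, a quick base-R look at what we just simulated (optional sanity check):
# noisy observations with the true curve overlaid
plot(X, Y, pch = 16, col = "grey60", xlab = "time", ylab = "y")
lines(X, Y_true, lwd = 2)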
With that data, let’s fit it in Stan using our improved model:
# fit with more knots than were used to simulate the data
num_knots <- 20
spline_degree <- 3

dat <- list(
  T = length(X),
  K = num_knots,
  penalized = 1,  # turn on the random-walk smoothing prior
  degree = spline_degree,
  y = Y,
  t = X
)
fit <- m$sample(
  data = dat,
  parallel_chains = 4,
  iter_warmup = 500,
  iter_sampling = 500,
  adapt_delta = 0.8,
  max_treedepth = 9,
  refresh = 0
)
## Running MCMC with 4 parallel chains...
##
## Chain 1 finished in 0.8 seconds.
## Chain 3 finished in 0.8 seconds.
## Chain 4 finished in 0.8 seconds.
## Chain 2 finished in 1.0 seconds.
##
## All 4 chains finished successfully.
## Mean chain execution time: 0.9 seconds.
## Total execution time: 1.1 seconds.
Fast, nice! Here’s a look at our fit on the original simulated data:
library(ggplot2)

yhat <- fit$summary('y_hat', ~ quantile(.x, probs = c(0.025, 0.5, 0.975)))
x <- seq(nrow(yhat))

ggplot(yhat) +
  geom_ribbon(
    mapping = aes(x, ymin = `2.5%`, ymax = `97.5%`),
    alpha = 0.2
  ) +
  geom_line(
    mapping = aes(x, y = `50%`)
  ) +
  geom_point(
    mapping = aes(x, y = Y)
  ) +
  labs(x = 'time', y = 'y')
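Before leaning too hard on the speed claim, it's worth confirming the sampler behaved; cmdstanr's built-in helpers make this a couple of lines:
# divergences, max-treedepth hits, and E-BFMI per chain
fit$diagnostic_summary()
# wall-clock time per chain and overall, handy for comparing against the non-QR model
fit$time()
# posterior summaries for the scale parameters
fit$summary(c('sigma', 'tau'))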
2911470 @ 2023-08-31