Hand-Coded C++ Gradients for Stan
Bypassing the expression graph with make_callback_var
1 When Stan is already fast — and you need it faster
Suppose you have a Stan model that works. You have vectorized your operations, replaced loops with matrix algebra where possible, and followed the efficiency advice in the Stan User’s Guide. You are running on good hardware with enough cores. The model produces valid inferences.
And it is still too slow.
This happens. Not because Stan is slow — Stan compiles your model to C++, and the compiled code uses Eigen for linear algebra, Intel TBB for threading, and the Stan Math library for numerics. When you write dot_product(x, w) in Stan, the code that executes is already optimized C++ calling into production-grade libraries. You are not going to beat it by rewriting the arithmetic.
What you can beat is the bookkeeping. Stan’s automatic differentiation system is general-purpose: every arithmetic operation on a parameter records a node in an internal data structure called the expression graph. For most models this overhead is negligible. But when a function performs millions of repetitive operations — looping over hundreds of thousands of observations, gathering parameters by index, computing the same formula at each step — the per-node bookkeeping can dominate the actual arithmetic. The math is fast. The tracking machinery around it is not.
This tutorial teaches you how to eliminate that overhead for targeted sections of your model. The idea is not to replace Stan’s C++ with different C++ — the arithmetic stays the same. It is to replace Stan’s general-purpose gradient tracking with a hand-crafted gradient computation that exploits the specific structure of your function. You tell Stan: “I know the gradient of this function analytically — skip the graph, here is the answer.”
The technique uses a single entry point in the Stan Math library called make_callback_var (defined in callback_vari.hpp in the Stan Math source). It is not documented in the Stan User’s Guide. It is not complicated. But getting it right requires understanding a few ideas clearly, and the existing documentation consists of source-code comments and scattered forum posts.
We will build up from the simplest possible example — a weighted sum — to indexed gather operations, and then to parallel execution. By the end, you will understand the pattern well enough to apply it to your own models.
1.1 What this tutorial assumes
You are comfortable writing Stan models. You have a working CmdStan, CmdStanR, or CmdStanPy installation. You have encountered the idea of automatic differentiation and know, at least roughly, that Stan computes gradients for you. You do not need to have written C++ before, though you will be writing some here.
This tutorial is not a general introduction to automatic differentiation. It is not about replacing Stan’s autodiff everywhere. It is about replacing it in specific, expensive functions that you have already identified as bottlenecks — through profiling, through timing, or through the experience of watching a sampler crawl.
1.2 Why doesn’t Stan already do this?
If hand-coding the gradient can be so much faster, a natural question: why doesn’t Stan do it automatically?
The short answer is that Stan does — for its own built-in functions. When you call normal_lpdf or log_sum_exp, you are not hitting the general-purpose autodiff system. Stan’s developers have written hand-coded gradients for hundreds of built-in functions, using exactly the technique described in this tutorial. Those functions create a single node in the expression graph and compute their gradients analytically in a callback. This is why Stan’s built-in functions are fast.
What Stan cannot do is write those hand-coded gradients for your functions. When you write a custom computation in the Stan language — a loop over observations, a series of indexed lookups, a particular formula — Stan’s only option is the general-purpose approach: record every operation as a node, build the graph, and traverse it later. Stan does not know the mathematical structure of your computation. You do.
Could a sufficiently smart compiler figure it out? In principle, partially. A compiler could recognize common patterns — “this is a dot product over gathered indices” or “this is a log-sum-exp reduction” — and replace them with optimized, hand-coded gradient implementations. Some of this happens already: Stan’s compiler performs expression optimization and can fuse certain operations. But the space of possible user-written functions is vast, and recognizing that a particular 50-line loop is mathematically equivalent to a function with a known analytic gradient is, in the general case, an unsolved problem in compiler design.
So the situation is: Stan gives you the general-purpose system (correct for any differentiable function, no calculus required from the user) and an escape hatch (make_callback_var) for the cases where you know enough about your function’s math to do better. The escape hatch exists because Stan’s developers use it themselves — they are just giving you access to the same tool. (The Stan Math contributor guide on adding new gradients describes this pattern from the developer’s perspective.)
1.3 When this technique applies
This technique requires that you can write down the gradient of your function analytically — that is, as an explicit formula involving the inputs and any intermediates you computed during the forward pass. You will implement that formula in C++ and hand it to Stan. If you cannot express the gradient as a formula, you cannot use this approach.
The good news is that most functions encountered in statistical modeling have analytic gradients. If your function is built from standard operations — addition, multiplication, exp, log, pow, dot products, matrix multiplies, sums, reductions — then its gradient can be derived by applying the chain rule (remember that term from high school calculus?) step by step. That is exactly what Stan’s autodiff does automatically; you are simply doing it yourself on paper.
A practical test: can you write the function as a composition of operations that each have a known derivative? If so, the chain rule gives you the gradient of the whole function. For example:
- \(f(\mathbf{x}) = \sum_i x_i w_i\) — each term’s derivative with respect to \(x_i\) is \(w_i\). Analytic gradient: yes.
- \(f(\mathbf{x}) = \log \sum_i \exp(x_i)\) — the derivative with respect to \(x_i\) is the softmax weight \(\exp(x_i) / \sum_j \exp(x_j)\). Analytic gradient: yes.
- A function involving if statements on parameter values (like a piecewise linear function) — the gradient exists everywhere except at the breakpoints, and is a simple formula on each piece. Analytic gradient: yes, in practice.
Functions where this technique does not readily apply:
- Functions defined by iterative algorithms where the number of iterations depends on the input (e.g., root-finding by Newton’s method, optimization loops). The gradient of the output with respect to the input exists in principle (via the implicit function theorem), but it is not a simple formula you can write in a callback.
- Functions involving discrete combinatorial operations (sorting, argmax) where the output changes discontinuously with the input.
If you can derive the gradient on paper — even if the derivation takes some work — this technique applies. If the derivation requires specialized mathematical machinery (implicit differentiation, adjoint methods for ODEs), you are likely better served by Stan’s built-in support for those cases (algebra_solver, ode_*), which already implements the specialized math.
2 How Stan computes gradients
To understand what we are replacing, we need to understand what Stan does by default. This requires a few ideas that we will build up one at a time.
2.1 Two kinds of numbers
Stan works with two kinds of numbers internally. The first is the ordinary floating-point number — a double in C++. When you declare something in a data block or transformed data block, Stan stores it as a double. It is just a number. Nothing special happens when you do arithmetic with it.
The second is Stan’s gradient-tracking type, called var (short for “variable”). When you declare a parameter — anything in the parameters block, or derived from parameters in transformed parameters or model — Stan wraps it in a var. A var holds a double value inside, but it also carries machinery for tracking how that value was computed, so that Stan can later figure out how the final answer (the log-posterior) depends on it.
The distinction matters because the entire technique in this tutorial rests on it: we strip the var wrappers off, do the arithmetic on plain double values (fast, no tracking), and then tell Stan the gradient ourselves.
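To make the distinction concrete, here is a minimal standalone sketch against the Stan Math library directly, outside any model (the variable names are illustrative):

#include <stan/math.hpp>
#include <iostream>

int main() {
  double d = 2.0;                          // plain number: no tracking
  stan::math::var v = 2.0;                 // value plus gradient bookkeeping

  double v_val = stan::math::value_of(v);  // strip the wrapper: just 2.0
  double plain = v_val * d;                // plain arithmetic, no graph node

  stan::math::var tracked = v * d;         // same arithmetic, but records a node
  std::cout << plain << " " << tracked.val() << "\n";  // both print 4

  stan::math::recover_memory();            // free the autodiff arena
  return 0;
}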
2.2 The expression graph
When your model evaluates an expression involving var values — that is, parameters — Stan does not just compute the answer. It also builds a graph, called the expression graph (sometimes called the “tape”), that records every operation and its inputs.
Consider a simple expression where \(a\), \(b\), and \(c\) are parameters:
\[ y = a \cdot b + c \]
Stan computes the value of \(y\), but it also builds something like this:
[a]   [b]   [c]
  \   /      |
   [×]       |
     \      /
      [+] → y
Each box is a node. Each node remembers its inputs and a rule for computing how sensitive the output is to changes in that node’s result. (We will make this precise shortly.)
For small expressions, the overhead of building this graph is negligible. The trouble starts when the graph grows large.
2.3 When the graph becomes the bottleneck
Imagine a model with 300,000 observations. Each observation requires gathering a handful of parameters by index, computing a log-likelihood contribution, and adding it to the target. If each iteration of that loop creates 30 nodes, the graph contains 9 million nodes — each one allocated, linked, and later traversed during the gradient computation.
The actual floating-point arithmetic might take microseconds per observation. The graph bookkeeping can take significantly longer.
This is the situation where hand-coded C++ helps. Instead of 9 million nodes, you create one. That single node stores a callback function that you write — a function that knows how to compute the gradients for the entire operation. Stan calls your callback during the gradient computation. The arithmetic is the same. The bookkeeping vanishes.
2.4 The forward and reverse passes
Stan’s gradient computation happens in two phases.
The forward pass evaluates the model to get a numeric answer — the log-posterior density at the current parameter values. This is the number that appears as lp__ in Stan’s output. During the forward pass, the expression graph is built.
The reverse pass walks the graph backward from the output to the inputs. At each node, it figures out how much the final output would change if that node’s value changed by a tiny amount. This is the information the sampler needs: NUTS explores parameter space by following the gradient of the log-posterior, the way a hiker might follow the slope of a hillside. Without knowing which direction is “uphill” — which parameter adjustments increase the log-posterior and by how much — the sampler cannot take efficient steps. The gradient tells it where to go.
The quantity computed at each node during the reverse pass is called the adjoint of that node.
Let us make this concrete. Suppose the log-posterior is currently \(L = 10.0\), and somewhere in the computation there is an intermediate variable \(t = 3.0\) (a var in Stan’s type system). If nudging \(t\) from \(3.0\) to \(3.001\) would change \(L\) from \(10.0\) to \(10.005\), then the adjoint of \(t\) is \(0.005 / 0.001 = 5\). In calculus notation, we write this as a partial derivative — the rate of change of \(L\) with respect to \(t\) while holding everything else fixed:
\[ \frac{\partial L}{\partial t} = 5 \]
The symbol \(\partial L / \partial t\) is read “the partial derivative of \(L\) with respect to \(t\).” It answers a specific question: if I nudge \(t\) by a tiny amount, how much does \(L\) change, per unit of nudge? The adjoint is just a name for this quantity in the context of reverse-mode autodiff.
2.5 The chain rule: how adjoints propagate
The reverse pass computes adjoints by working backward through the graph, applying a rule from calculus called the chain rule. The chain rule says: if \(L\) depends on \(t\), and \(t\) depends on \(a\), then the rate at which \(L\) changes with \(a\) equals the rate at which \(L\) changes with \(t\), multiplied by the rate at which \(t\) changes with \(a\):
\[ \frac{\partial L}{\partial a} = \frac{\partial L}{\partial t} \cdot \frac{\partial t}{\partial a} \]
A concrete example makes this mechanical. Take \(y = a \cdot b + c\) with \(a = 3\), \(b = 2\), \(c = 1\), so \(y = 7\). Suppose the rest of the model is such that \(\partial L / \partial y = 4\) — that is, the upstream adjoint arriving at our expression is 4.
The intermediate product is \(t_1 = a \cdot b = 6\). The addition \(y = t_1 + c\) has the property that nudging either input by \(\epsilon\) nudges \(y\) by \(\epsilon\), so \(\partial y / \partial t_1 = 1\) and \(\partial y / \partial c = 1\).
Now apply the chain rule:
- Adjoint of \(t_1\): \(\frac{\partial L}{\partial t_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial t_1} = 4 \cdot 1 = 4\)
- Adjoint of \(c\): \(\frac{\partial L}{\partial c} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial c} = 4 \cdot 1 = 4\)
For the multiplication \(t_1 = a \cdot b\): nudging \(a\) from 3 to 3.001 changes \(t_1\) from 6 to 6.002, so \(\partial t_1 / \partial a = b = 2\). Similarly, \(\partial t_1 / \partial b = a = 3\). The chain rule gives:
- Adjoint of \(a\): \(\frac{\partial L}{\partial a} = \frac{\partial L}{\partial t_1} \cdot \frac{\partial t_1}{\partial a} = 4 \cdot 2 = 8\)
- Adjoint of \(b\): \(\frac{\partial L}{\partial b} = \frac{\partial L}{\partial t_1} \cdot \frac{\partial t_1}{\partial b} = 4 \cdot 3 = 12\)
You can verify: nudging \(a\) from 3 to 3.001 changes \(y\) from 7 to 7.002, and if \(\partial L / \partial y = 4\), then \(L\) changes by \(0.002 \times 4 = 0.008\), giving a rate of \(0.008 / 0.001 = 8\). It checks out.
Stan does this computation automatically for every node in the graph. When you write custom C++, you do it yourself — once, for the entire function — and hand the results directly to the inputs. The arithmetic is the same. You are just doing it by hand instead of asking Stan to build a graph and traverse it.
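If you want to check this arithmetic yourself, a few lines against the Stan Math library reproduce it (a standalone sketch; the factor of 4 stands in for the rest of the model):

#include <stan/math.hpp>
#include <iostream>

int main() {
  using stan::math::var;
  var a = 3.0, b = 2.0, c = 1.0;
  var y = a * b + c;   // forward pass: y = 7
  var L = 4.0 * y;     // stand-in for the rest of the model, so dL/dy = 4

  L.grad();            // reverse pass: fills in the adjoints
  std::cout << a.adj() << " " << b.adj() << " " << c.adj() << "\n";  // 8 12 4

  stan::math::recover_memory();
  return 0;
}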
3 A weighted sum
We begin with the simplest possible case: a function that takes two vectors and returns their weighted sum (a dot product).
\[ f(\mathbf{x}, \mathbf{w}) = \sum_{i=1}^{N} x_i \, w_i \]
In Stan, you would write dot_product(x, w) or sum(x .* w) and the built-in autodiff would handle the gradient. We are reimplementing it in C++ not because it needs reimplementing, but because the pattern is clear enough to teach the mechanics without distraction.
3.1 The Stan declaration
In your .stan file, declare the function with no body (see the Stan Reference Manual on user-defined functions for the full syntax):
functions {
  real weighted_sum(vector x, vector w);
}

The missing body tells the Stan compiler to look for a C++ implementation in a user-provided header file. You will supply this header at compile time. (The CmdStan guide on external C++ covers the mechanics of how this linkage works.)
3.2 The C++ implementation
Create a C++ header file named something informative; here we’ll call it user_header.hpp. This is the complete file for our first example:
#ifndef USER_HEADER_HPP
#define USER_HEADER_HPP

#include <stan/math.hpp>
#include <ostream>

namespace stan_model {

template <typename EigVec1, typename EigVec2>
auto weighted_sum(const EigVec1& x,
                  const EigVec2& w,
                  std::ostream* pstream__) {
  // Extract the scalar types: var (parameter) or double (data)
  using T_x = stan::scalar_type_t<EigVec1>;
  using T_w = stan::scalar_type_t<EigVec2>;

  // If both inputs are data (double), no gradient needed
  if constexpr (!stan::is_var_v<T_x> && !stan::is_var_v<T_w>) {
    const auto& x_val = stan::math::value_of(x);
    const auto& w_val = stan::math::value_of(w);
    return x_val.dot(w_val);
  } else {
    // --- Forward pass: compute on plain doubles ---
    const auto& x_val = stan::math::value_of(x);
    const auto& w_val = stan::math::value_of(w);
    double result = x_val.dot(w_val);

    // --- Keep var inputs alive until the reverse pass ---
    auto x_arena = stan::math::to_arena(x);
    auto w_arena = stan::math::to_arena(w);

    // --- Return a var with a custom reverse pass ---
    return stan::math::make_callback_var(
        result,
        [x_arena, w_arena](auto& ret) {
          double upstream = ret.adj();
          if constexpr (stan::is_var_v<T_x>) {
            x_arena.adj().array() += upstream * stan::math::value_of(w_arena).array();
          }
          if constexpr (stan::is_var_v<T_w>) {
            w_arena.adj().array() += upstream * stan::math::value_of(x_arena).array();
          }
        }
    );
  }
}

}  // namespace stan_model

#endif

That is a lot of code for a dot product. Let us walk through it piece by piece.
3.2.1 The function signature
template <typename EigVec1, typename EigVec2>
auto weighted_sum(const EigVec1& x,
                  const EigVec2& w,
                  std::ostream* pstream__) {

Stan’s compiler generates calls to your C++ function using template types. A vector argument in Stan might arrive as an Eigen::Matrix<var, -1, 1> (when the vector contains parameters) or an Eigen::Matrix<double, -1, 1> (when it contains data). The templates let a single function definition handle both cases — the compiler generates the appropriate version for each combination of input types.
The pstream__ argument is Stan’s output stream. Every user-defined function receives it as the last argument. You will rarely use it, but it must be there. (The CmdStan external C++ guide documents the full signature conventions, including the pstream__ parameter and the required namespace.)
The return type is auto because it depends on the inputs: double when all inputs are data, var when any input is a parameter.
3.2.2 The data-only shortcut
if constexpr (!stan::is_var_v<T_x> && !stan::is_var_v<T_w>) {
  const auto& x_val = stan::math::value_of(x);
  const auto& w_val = stan::math::value_of(w);
  return x_val.dot(w_val);
}

This is the first thing to understand, and the most common source of confusion. When Stan calls your function with data-only arguments — in the transformed data block, for example — the inputs are plain double values, not var. There are no parameters, no gradients to compute, and no expression graph. You just compute the answer and return a double.
The if constexpr is a compile-time check. It asks: “are any of the inputs var?” If not, the compiler generates a simple version of the function that does plain arithmetic. If so, it generates the version with gradient tracking. Both versions exist in the compiled code; the compiler discards the one that does not apply for each call site. It costs nothing at runtime.
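If if constexpr is new to you, a tiny example outside the Stan context (purely illustrative) shows the idea: only the branch matching the compile-time condition is instantiated.

#include <string>
#include <type_traits>

template <typename T>
auto describe(const T& value) {
  if constexpr (std::is_floating_point_v<T>) {
    return value * 2.0;                   // compiled only for float/double
  } else {
    return std::string("not a number");   // compiled only for other types
  }
}

// describe(3.0) returns 6.0; describe(std::string("x")) returns "not a number".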
value_of() extracts the numeric value from whatever type it receives. On a double, it returns the double unchanged. On a var, it returns the double stored inside. We use it uniformly so the same code works in both branches.
3.2.3 The forward pass
const auto& x_val = stan::math::value_of(x);
const auto& w_val = stan::math::value_of(w);
double result = x_val.dot(w_val);

We strip the var wrappers and compute on plain double values. No nodes are created. No graph is built. This is just arithmetic.
3.2.4 Keeping inputs alive
auto x_arena = stan::math::to_arena(x);
auto w_arena = stan::math::to_arena(w);

The reverse pass does not happen now. It happens later — after the entire forward pass of the model is complete. By that time, the local variables x and w may have gone out of scope, and their memory may have been reused.
to_arena() copies the data into Stan’s arena memory — a fast allocator that is freed all at once after the reverse pass completes. This guarantees that x_arena and w_arena are still valid when the callback fires.
3.2.5 The callback
return stan::math::make_callback_var(
    result,
    [x_arena, w_arena](auto& ret) {
      double upstream = ret.adj();
      if constexpr (stan::is_var_v<T_x>) {
        x_arena.adj().array()
            += upstream * stan::math::value_of(w_arena).array();
      }
      if constexpr (stan::is_var_v<T_w>) {
        w_arena.adj().array()
            += upstream * stan::math::value_of(x_arena).array();
      }
    }
);

make_callback_var does two things:
- It creates a var whose value is result.
- It registers a callback (a lambda function) that Stan will call during the reverse pass.
When the reverse pass reaches this node, Stan calls the lambda. The argument ret is the var that was returned — our single node in the graph. ret.adj() is the upstream adjoint: how much the log-posterior would change per unit change in our function’s return value.
Now we need the gradient of our function — the partial derivative of the output with respect to each input element. Our function is \(f(\mathbf{x}, \mathbf{w}) = \sum_i x_i w_i\). To find how the output changes when we nudge one input, say \(x_1\), hold everything else fixed: if \(\mathbf{x} = (3, 5)\) and \(\mathbf{w} = (2, 4)\), then \(f = 3 \cdot 2 + 5 \cdot 4 = 26\). Nudging \(x_1\) from 3 to 3.001 gives \(f = 3.001 \cdot 2 + 5 \cdot 4 = 26.002\). The rate of change is \(0.002 / 0.001 = 2\), which is \(w_1\). In general:
\[ \frac{\partial f}{\partial x_i} = w_i \qquad\qquad \frac{\partial f}{\partial w_i} = x_i \]
The chain rule then gives us the adjoint of each input — multiply the upstream adjoint (how much \(L\) changes per unit change in \(f\)) by the local derivative (how much \(f\) changes per unit change in the input):
\[ \text{adj}(x_i) \mathrel{+}= \frac{\partial L}{\partial f} \cdot w_i \qquad\qquad \text{adj}(w_i) \mathrel{+}= \frac{\partial L}{\partial f} \cdot x_i \]
That is exactly what the code does. The += is important — this input might be used elsewhere in the model, so we accumulate into the adjoint rather than overwriting it.
The if constexpr guards ensure we only touch adjoints of var inputs. If w is data (a double vector), it has no .adj() to write to — data does not participate in gradients — and attempting to access it would fail at compile time.
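Before wiring the function into a Stan model, you can exercise it from a small standalone program. This is a sketch; it assumes the header above compiles on its own and uses the namespace shown there:

#include "user_header.hpp"
#include <iostream>

int main() {
  Eigen::Matrix<stan::math::var, -1, 1> x(3);
  x << 1.0, 2.0, 3.0;
  Eigen::VectorXd w(3);
  w << 0.5, -1.0, 2.0;

  stan::math::var f = stan_model::weighted_sum(x, w, nullptr);
  f.grad();   // reverse pass: fires our callback

  std::cout << "f     = " << f.val() << "\n";               // 4.5
  std::cout << "df/dx = " << x.adj().transpose() << "\n";   // 0.5 -1 2

  stan::math::recover_memory();
  return 0;
}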
3.3 Compiling and running
In CmdStanR:

model <- cmdstan_model(
  'model.stan',
  stanc_options = list("allow-undefined" = TRUE),
  user_header = normalizePath('user_header.hpp')
)

In CmdStanPy:

model = CmdStanModel(
    stan_file='model.stan',
    stanc_options={'allow-undefined': True},
    user_header='user_header.hpp'
)

With plain CmdStan, add to your make/local file:
STANCFLAGS += --allow-undefined
USER_HEADER = /full/path/to/user_header.hpp

The --allow-undefined flag tells the Stan compiler to accept function declarations without bodies. The user_header option tells CmdStan to include your C++ file during compilation (see the CmdStanR and CmdStanPy documentation for the interface-specific arguments). Stan expects a single header file — but if your implementation grows large enough to split across files, you can use one main header that #includes the others:
// user_header.hpp
#ifndef USER_HEADER_HPP
#define USER_HEADER_HPP
#include "weighted_sum.hpp"
#include "gathered_lse.hpp"
#include "my_likelihood.hpp"
#endifEach included file contains one or more function implementations. Stan only sees the single user_header.hpp; you organize the internals however you like.
3.4 The recipe
Every custom C++ function for Stan follows the same structure:
1. Signature: template over input types, return auto, accept pstream__ as the last argument.
2. Data shortcut: if no input is a var, compute and return a double.
3. Forward pass: strip inputs to double with value_of(), compute the answer.
4. Arena copy: save var inputs with to_arena().
5. Callback: return make_callback_var(value, lambda), where the lambda reads the upstream adjoint and distributes it to each input’s .adj() using the chain rule.
This recipe does not change. The only thing that varies between functions is the arithmetic in steps 3 and 5.
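As a compact reference, here is the recipe applied to a toy function \(f(\mathbf{x}) = \sum_i x_i^2\), whose gradient is \(2 x_i\). The name sum_of_squares is illustrative, not part of any existing model:

template <typename EigVec>
auto sum_of_squares(const EigVec& x, std::ostream* pstream__) {
  using T_x = stan::scalar_type_t<EigVec>;           // step 1: templated signature
  const auto& x_val = stan::math::value_of(x);
  double result = x_val.squaredNorm();               // step 3: forward pass on doubles
  if constexpr (!stan::is_var_v<T_x>) {
    return result;                                   // step 2: data-only shortcut
  } else {
    auto x_arena = stan::math::to_arena(x);          // step 4: arena copy
    return stan::math::make_callback_var(            // step 5: callback
        result, [x_arena](auto& ret) {
          // df/dx_i = 2 * x_i, scaled by the upstream adjoint
          x_arena.adj().array() += 2.0 * ret.adj() * x_arena.val().array();
        });
  }
}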
4 A log-sum-exp gather
The weighted sum was clean because the gradient was trivial. Now we tackle a function with nonlinear operations and index arrays — the kind of pattern that actually appears in real models.
Consider a model where you need the log-sum-exp of a subset of parameters, selected by an index array:
\[ f(\mathbf{x}, \text{idx}) = \log \sum_{i=1}^{K} \exp\!\big(x_{\text{idx}[i]}\big) \]
This is common in mixture models, hierarchical models with group-level indexing, and any model where observations map to subsets of parameters.
4.1 The Stan declaration
functions {
  real gathered_lse(vector x, array[] int idx);
}

The vector x holds the full parameter vector. The integer array idx selects which elements to include, using Stan’s 1-based indexing.
4.2 The math
The forward pass computes the log-sum-exp. For numerical stability, we use the standard trick of subtracting the maximum:
\[ m = \max_i\, x_{\text{idx}[i]} \qquad f = m + \log \sum_{i=1}^{K} \exp\!\big(x_{\text{idx}[i]} - m\big) \]
For the reverse pass, we need the partial derivative of the log-sum-exp with respect to each gathered element — how much does the output change when we nudge one input?
A concrete example helps. Suppose \(\mathbf{x}_{\text{idx}} = (1, 3)\). Then \(f = \log(e^1 + e^3) = \log(2.718 + 20.086) = \log(22.804) \approx 3.1269\). If we nudge the second element from 3 to 3.001: \(f' = \log(e^1 + e^{3.001}) \approx 3.1278\). The rate of change is about \(0.88\), which is \(e^3 / (e^1 + e^3)\) — the softmax weight of the second element.
In general, the partial derivative turns out to be the softmax — the exponential of each element divided by the sum of all exponentials:
\[ \frac{\partial f}{\partial x_{\text{idx}[i]}} = \frac{\exp(x_{\text{idx}[i]})}{\sum_{j=1}^{K} \exp(x_{\text{idx}[j]})} = \text{softmax}(\mathbf{x}_{\text{idx}})_i \]
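For completeness, the derivation is a single application of the chain rule to the logarithm of the sum:
\[ \frac{\partial f}{\partial x_{\text{idx}[i]}} = \frac{1}{\sum_{j=1}^{K} \exp\!\big(x_{\text{idx}[j]}\big)} \cdot \frac{\partial}{\partial x_{\text{idx}[i]}} \sum_{j=1}^{K} \exp\!\big(x_{\text{idx}[j]}\big) = \frac{\exp\!\big(x_{\text{idx}[i]}\big)}{\sum_{j=1}^{K} \exp\!\big(x_{\text{idx}[j]}\big)} \]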
We will need these softmax weights during the reverse pass, so we compute and store them during the forward pass.
4.3 The C++ implementation
template <typename EigVec>
auto gathered_lse(const EigVec& x,
                  const std::vector<int>& idx,
                  std::ostream* pstream__) {
  using T_x = stan::scalar_type_t<EigVec>;
  const int K = idx.size();

  // Extract doubles
  const auto& x_val = stan::math::value_of(x);

  // Forward pass: numerically stable log-sum-exp
  double max_val = -std::numeric_limits<double>::infinity();
  for (int i = 0; i < K; ++i) {
    max_val = std::max(max_val, x_val[idx[i] - 1]);  // 1-based to 0-based
  }
  Eigen::VectorXd weights(K);
  double sum_exp = 0.0;
  for (int i = 0; i < K; ++i) {
    weights[i] = std::exp(x_val[idx[i] - 1] - max_val);
    sum_exp += weights[i];
  }
  double result = max_val + std::log(sum_exp);

  // Softmax weights (needed for reverse pass)
  weights /= sum_exp;

  if constexpr (!stan::is_var_v<T_x>) {
    return result;
  } else {
    auto x_arena = stan::math::to_arena(x);
    // Move weights and indices to the arena so the callback can read them
    // (captured values must be trivially destructible; see Section 4.6)
    auto weights_arena = stan::math::to_arena(weights);
    auto idx_arena = stan::math::to_arena(idx);

    return stan::math::make_callback_var(
        result,
        [x_arena, weights_arena, idx_arena, K](auto& ret) {
          double upstream = ret.adj();
          for (int i = 0; i < K; ++i) {
            x_arena.adj()[idx_arena[i] - 1] += upstream * weights_arena[i];
          }
        }
    );
  }
}

The structure is the same recipe from the weighted sum. The only differences are the arithmetic (log-sum-exp instead of dot product) and the gradient (softmax weights instead of the other vector’s values). Let us look at the parts that are new.
4.4 Index translation
Stan uses 1-based indexing. C++ uses 0-based. Every access into a C++ Eigen vector with a Stan index requires subtracting 1:
x_val[idx[i] - 1]

Getting this wrong produces silent garbage — the function reads from the wrong memory location, returns a plausible-looking number, and the sampler wanders off without complaint. Check your indices carefully.
4.5 Repeated indices and the += rule
Suppose idx = {3, 3, 5}. The parameter x[3] appears twice. During the reverse pass, it must receive the adjoint contributions from both appearances:
x_arena.adj()[idx_arena[0] - 1] += upstream * weights_arena[0];  // x[3]
x_arena.adj()[idx_arena[1] - 1] += upstream * weights_arena[1];  // x[3] again

Both lines write to x_arena.adj()[2] — and the += ensures the contributions accumulate. If you used = instead, the second write would overwrite the first. The gradient would be wrong, the sampler would silently produce biased results, and nothing would warn you.
This is the most common bug in hand-coded gradients with index arrays. Use += for adjoints. Always.
4.6 Storing intermediates
The softmax weights are computed during the forward pass and consumed during the reverse pass. These two events happen at different times — the forward pass runs now, and the reverse pass runs after the entire model has been evaluated.
We solve this by storing the weights in arena memory:
auto weights_arena = stan::math::to_arena(weights);

The lambda captures weights_arena by value (it captures the arena pointer, which is cheap). When the reverse pass eventually calls the lambda, the weights are still there.
This is a general principle: anything the reverse pass needs that is not already stored in a var must be copied to the arena. Intermediate doubles, precomputed gradient factors, scratch buffers — all of it. (The Stan Math source warns explicitly: “All captured values must be trivially destructible or they will leak memory. to_arena() function can be used to ensure that.” The Stan Math common pitfalls guide elaborates on these lifetime issues.)
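In code, the rule looks like this. The scalar here is illustrative; the point is that Eigen vectors and std::vector both have to_arena overloads, while trivially destructible scalars can be captured as-is:

// Inside the var branch, before make_callback_var:
auto weights_arena = stan::math::to_arena(weights);  // Eigen::VectorXd intermediate
auto idx_arena     = stan::math::to_arena(idx);      // std::vector<int> of indices
double scale       = 2.0;                            // plain scalar: safe to capture directly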
4.7 Defensive checks
Stan models that segfault are painful to debug. A few cheap checks at the top of your function can save hours:
// Check that indices are within bounds
for (int i = 0; i < K; ++i) {
  if (idx[i] < 1 || idx[i] > x_val.size()) {
    throw std::out_of_range("gathered_lse: index out of bounds");
  }
}

Stan also provides stan::math::check_* functions (check_matching_sizes, check_positive, etc.) for common validations. These produce informative error messages that include the function name and the offending value.
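For example, the bounds loop above could lean on these helpers instead (a sketch; substitute your own function and argument names):

// Size agreement between two containers
stan::math::check_matching_sizes("weighted_sum", "x", x, "w", w);

// Each entry of idx must lie in [1, size of x]
for (int i = 0; i < K; ++i) {
  stan::math::check_bounded("gathered_lse", "idx", idx[i], 1,
                            static_cast<int>(x_val.size()));
}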
5 Adding parallelism with reduce_sum
You have a custom C++ function that eliminates the expression graph overhead. The model is faster. But the function runs on a single core, and the computation is embarrassingly parallel — each observation or group contributes independently to the log-posterior.
The instinct might be to reach for Intel TBB directly, writing parallel_reduce bodies with thread-local adjoint buffers. That works, and we will cover it in Section 6. But for most problems, there is a much simpler approach: let Stan handle the threading.
Stan’s reduce_sum function partitions a computation across threads. It handles work distribution, thread management, and — critically — the safe aggregation of adjoints across threads. You do not touch TBB. You do not manage thread-local buffers. You write a single-threaded partial sum function, and reduce_sum does the rest. (The Stan case study on reduce_sum walks through a minimal example.)
The key insight is that reduce_sum and make_callback_var solve different problems. reduce_sum solves parallelism. make_callback_var solves expression graph overhead. You can — and often should — use both.
5.1 The pattern
In your .stan file:
functions {
  real my_fast_lp(vector theta, array[] int obs_idx, ...);

  real partial_sum_lpmf(array[] int slice,
                        int start, int end,
                        vector theta, ...) {
    return my_fast_lp(theta, slice, ...);
  }
}
model {
  target += reduce_sum(partial_sum_lpmf, obs_group_ids,
                       grainsize, theta, ...);
}
}The function my_fast_lp is your custom C++ implementation — it uses make_callback_var to bypass the expression graph, exactly as in the previous two examples. The function partial_sum_lpmf is a thin Stan wrapper that reduce_sum calls for each slice of work.
Stan partitions obs_group_ids into slices, sends each slice to a thread, and combines the results. Your C++ function does the heavy arithmetic on each slice without building any graph. The combination of both techniques gives you parallel execution and minimal graph overhead.
5.2 Why reduce_sum handles the hard parts
When multiple threads modify the same parameter’s adjoint simultaneously, you get a race condition — a class of bug where the result depends on the unpredictable timing of threads. Thread A reads the adjoint, adds its contribution, and writes it back — but Thread B read the same value before A wrote, so B’s write overwrites A’s contribution. The adjoint ends up wrong, and the error changes from run to run.
reduce_sum avoids this by giving each thread its own copy of the parameters’ adjoint storage. When all threads finish, it merges the results. You do not need to think about any of this — it is handled internally.
If you wrote TBB code yourself, you would need to implement this merging manually. For most models, there is no reason to.
5.3 Grainsize tuning
The grainsize parameter controls how reduce_sum partitions the work. Each slice contains at least grainsize elements. The tradeoffs:
- Too small: thread overhead dominates. Each slice does little work, but the cost of dispatching and merging is nontrivial.
- Too large: threads sit idle. If you have 8 cores and a grainsize that creates only 3 slices, 5 cores do nothing.
A good starting point: divide the number of observations (or whatever natural unit you are slicing) by the number of threads per chain. If you have 10,000 observations and 4 threads, start with grainsize = 2500. The reasoning: if the workload is roughly balanced across observations, this gives each thread exactly one slice, which minimizes the overhead of dispatching and combining results.
Then halve the grainsize and run again. Halve it once more. Compare the wall-clock times. If the observations are well balanced — roughly equal work per observation — the first value (one slice per thread) is usually fastest. If the workload is unbalanced — some observations are much more expensive than others — smaller grainsizes let the scheduler even things out by giving threads that finish early another slice to work on.
Setting grainsize = 1 tells Stan to tune automatically at runtime. This is a reasonable fallback if you do not want to experiment, but in practice the manual approach above often wins, especially once the expression graph overhead is gone and the per-observation cost is predictable.
To benchmark grainsize, fix the random seed (seed), fix the initial values (init), and run a fixed number of iterations (e.g., iter_warmup = 200, iter_sampling = 200). This gives you comparable wall-clock times across runs without noise from different trajectories.
5.4 Compiling with threading
In CmdStanR:

model <- cmdstan_model(
  'model.stan',
  cpp_options = list(stan_threads = TRUE),
  stanc_options = list("allow-undefined" = TRUE),
  user_header = normalizePath('user_header.hpp')
)

fit <- model$sample(
  data = stan_data,
  threads_per_chain = 4
)

In CmdStanPy:

model = CmdStanModel(
    stan_file='model.stan',
    cpp_options={'stan_threads': True},
    stanc_options={'allow-undefined': True},
    user_header='user_header.hpp'
)

fit = model.sample(
    data=stan_data,
    threads_per_chain=4
)

With plain CmdStan, add to make/local:

STANCFLAGS += --allow-undefined
USER_HEADER = /full/path/to/user_header.hpp
STAN_THREADS = true

Then set the environment variable before running:

export STAN_NUM_THREADS=4
./model sample data file=data.json

6 Raw TBB parallelism
In rare cases, reduce_sum is not the right tool. Perhaps the computation does not decompose into independent additive terms. Perhaps you need fine-grained control over how the forward and reverse passes are parallelized separately. Perhaps the partial sums share expensive intermediate computations that you want to cache.
In these cases, you can use Intel TBB directly from within your make_callback_var callback. This is harder to get right and harder to maintain. Try reduce_sum first. Come here only if you have a specific reason. (For more discussion of advanced patterns, including vector-valued returns, see this Discourse thread.)
6.2 The parallel_reduce pattern
TBB’s parallel_reduce is designed for exactly this pattern. You define a “body” class with three capabilities: process a range of work, accumulate results into local storage, and merge two bodies together.
The skeleton looks like this:
struct ReverseBody {
  // Shared (read-only) data
  const Eigen::VectorXd& weights;
  const std::vector<int>& idx;
  double upstream;

  // Thread-local adjoint buffer
  Eigen::VectorXd local_adj;

  ReverseBody(const Eigen::VectorXd& weights_, const std::vector<int>& idx_,
              double upstream_, int n_params)
      : weights(weights_), idx(idx_), upstream(upstream_),
        local_adj(Eigen::VectorXd::Zero(n_params)) {}

  // Splitting constructor — TBB calls this to create a new thread's body
  ReverseBody(ReverseBody& other, tbb::split)
      : weights(other.weights), idx(other.idx), upstream(other.upstream),
        local_adj(Eigen::VectorXd::Zero(other.local_adj.size())) {}

  // Process a range of work
  void operator()(const tbb::blocked_range<int>& range) {
    for (int i = range.begin(); i < range.end(); ++i) {
      local_adj[idx[i] - 1] += upstream * weights[i];
    }
  }

  // Merge another body's results into this one
  void join(const ReverseBody& other) {
    local_adj += other.local_adj;
  }
};

Each thread gets its own ReverseBody with a zeroed-out local_adj buffer. It writes only to its own buffer — no conflicts. After all threads finish, TBB calls join() to merge the buffers pairwise until a single body holds all the contributions. You then add those into the var inputs’ actual adjoints:
x_arena.adj() += body.local_adj;

This final step is single-threaded, which is fine — it is a single vector addition.
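Putting the pieces together, the call site inside the reverse-pass callback might look like this. It is a sketch: it assumes the concrete constructor above, and that weights (Eigen::VectorXd), idx (std::vector<int>), ret, and x_arena are all visible to the callback, with the TBB headers available through Stan Math:

// Build one body, let TBB split it across threads, then fold the merged
// thread-local buffer into the real adjoints.
ReverseBody body(weights, idx, ret.adj(), static_cast<int>(x_arena.size()));
tbb::parallel_reduce(tbb::blocked_range<int>(0, static_cast<int>(idx.size())),
                     body);
x_arena.adj() += body.local_adj;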
6.4 When to use raw TBB
Use raw TBB only when:
- The computation cannot be expressed as a sum of independent terms (so reduce_sum does not apply).
- You need to parallelize the reverse pass differently from the forward pass.
- You have verified, with profiling, that reduce_sum leaves performance on the table for your specific case.
For the vast majority of models, the combination of make_callback_var (for graph elimination) and reduce_sum (for parallelism) is sufficient.
7 A real-world case study
The techniques in this tutorial were developed for a model of global football — estimating player abilities across more than 85 leagues and a decade of seasons.
7.1 The model
The model is a large Bayesian hierarchical model with several likelihoods, each contributing to the joint posterior. The data includes match-level and player-level observations across a decade of competition. The two most expensive likelihoods — and the ones we optimized — both follow the gather-scatter pattern that makes this technique effective.
Match segments: Nearly 1.5 million segments, where a segment is a stretch of play between match events. Each segment gathers multi-correlated ability parameters for every player on the pitch — all 22 players across both teams — plus league-level and match-level covariates, and computes a segment-level log-likelihood. The Stan implementation involved two stages: building the segment data structures (gathering parameters by index, computing team-level quantities) and evaluating the log-probability. Both stages create expression graph nodes — hundreds per segment, hundreds of millions in total.
Manager decisions: An observation at every match start and every intra-match decision, modeling the sequential decisions a manager makes — think player selection or usage — as a function of player abilities and match context. Each decision gathers parameters from the available pool and computes a probability. The same gather-scatter structure: loop over decisions, gather indexed parameters, evaluate, accumulate.
The model includes several other likelihoods, but these two dominated the wall time. The Stan code was already well optimized — both likelihoods were parallelized with reduce_sum, the code was vectorized where possible, and unnecessary allocations had been eliminated. The model was, in other words, already at the point this tutorial begins: Stan’s standard optimizations had been applied, and the bottleneck was the expression graph itself. With 1.5 million segments each creating hundreds of nodes, the graph grew to hundreds of millions of nodes per gradient evaluation.
Running this already-optimized model in the cloud on Intel Granite Rapids hardware — 72 physical cores per machine, one machine per chain, four chains — the wall time could exceed four days. The model was under active development (priors, covariates, model structure were all still evolving), but these two likelihoods had stabilized. Waiting four days for feedback on a structural change to other model components slowed development.
7.2 The optimization — in stages
The optimization was done incrementally, one likelihood at a time, with profiling at each stage. This is the approach I recommend: do not rewrite everything at once. Start with the most expensive component, verify it, then move to the next.
Stage 1: Match segments. Both stages of the Stan implementation — gathering parameters and evaluating the log-probability — were replaced with a single fused C++ function using make_callback_var with TBB parallelism. The forward pass gathers all player and match-level parameters on plain double values, computes the segment log-likelihood, and stores intermediates in a per-segment cache. The reverse pass computes analytic gradients and scatters them back to the parameter adjoints via thread-local buffers.
The pattern, in pseudocode:
Forward (per segment s):
gather: player ability parameters, league/match covariates
compute: segment log-likelihood on doubles
cache: intermediates for reverse pass
Reverse (per segment s):
compute: analytic gradient w.r.t. each gathered parameter
scatter: adj_theta[player_idx[s]] += gradient * upstream adjoint
Stage 2: Manager decisions. The same pattern was applied — gather player parameters for each decision point, compute the log-probability on doubles, scatter analytic gradients back. A second .hpp file was created, and a main custom_functions.hpp header was used to #include both implementations.
7.3 Profiling results
All three versions — pure Stan, custom C++ for the match segments only, and custom C++ for both likelihoods — were profiled on the same data (85 leagues) with identical seeds for initialization and sampling, using 24 threads on a single chain and 5 warmup + 5 sampling iterations. The table below compares the pure-Stan and fully-custom versions. The proportion of time each likelihood consumes is more informative than raw seconds here, since the original model’s four-day wall time makes the absolute scale clear.
| Likelihood | Pure Stan | % of total | Custom C++ (both) | % of total | Speedup |
|---|---|---|---|---|---|
| Match segments | 19.9 s | 64.4% | 1.9 s | 34.7% | 10.6x |
| Manager decisions | 8.4 s | 27.2% | 0.9 s | 16.9% | 9.2x |
| Other likelihoods (5) | 2.5 s | 8.4% | 2.5 s | 47.4% | 1.0x |
| Total profiled | 30.8 s | | 5.4 s | | 5.7x |
Ten iterations (5 warmup + 5 sampling) is far too few for inference, but it is enough to measure the relative cost of model components. Each gradient evaluation exercises every likelihood, and the proportion of time each one consumes stabilizes quickly — it does not depend on how far along the sampler is in adaptation.
You cannot, however, multiply these seconds by the number of iterations in a production run to estimate total wall time. During warmup, NUTS adapts the step size and mass matrix, which changes the number of leapfrog steps per iteration — often dramatically. Early warmup iterations may build small trajectory trees (cheap), while post-adaptation iterations build larger trees (expensive). The cost per iteration is not constant across phases. What is roughly constant is the proportion each component contributes to a single gradient evaluation, which is why the proportions in Table 1 are informative even from a short run.
The proportion shift tells the story. In pure Stan, the two gather-scatter likelihoods consumed 91.6% of profiled time — match segments alone took nearly two-thirds. The custom C++ eliminated most of that overhead: the two likelihoods dropped from 28.3 seconds to 2.8 seconds combined, a 10x reduction.
The five remaining likelihoods were unchanged. They were never dominated by graph overhead, so custom C++ wouldn’t cut the tape. But with the dominant likelihoods out of the way, these untouched components now account for nearly half the profiled time — not because they got slower, but because everything around them got faster.
The practical effect: a model that previously required over four days per run became feasible for iterative development (or production pipeline automation).
7.4 What made these likelihoods good candidates
Both likelihoods shared the same structural properties:
- Repetitive gather-scatter pattern: Each observation loops over a set of indexed parameters. The per-observation graph is identical in structure — only the indices and values change. This makes the analytic gradient straightforward to derive once and apply across all observations.
- Massive observation count: 1.5 million match segments and hundreds of thousands of manager decisions. The per-node overhead, small in isolation, multiplied across this many observations to dominate the computation.
- Moderate per-observation complexity: Each observation gathers a set of indexed parameters and performs a series of standard operations (e.g., exp, log, weighted sums). The analytic gradient is a page of algebra — work, but not a research problem.
- Stable likelihood code: Both likelihoods had stabilized in development. The rest of the model was still evolving. This meant the investment in hand-coded C++ would not be wasted on code that was about to change.
Other likelihoods, by contrast, were expensive for different reasons — they involve expensive special functions and fewer but more complex per-observation computations. Graph overhead was not their bottleneck, so custom C++ would not have helped as much.
8 Verifying your gradients
A custom gradient that is wrong in a subtle way is worse than no custom gradient at all. The sampler will run. It will produce numbers. Those numbers will be wrong in ways that are not obvious from diagnostics. You must verify.
8.1 Method 1: Finite-difference checking
The most direct test. The idea is simple: instead of computing the gradient analytically, approximate it by actually nudging each parameter and measuring how the output changes — the same “nudge and measure” reasoning we used to build intuition earlier, but now used as a verification tool.
For each parameter \(\theta_j\), evaluate the function at two nearby points — one slightly above, one slightly below — and compute the slope between them:
\[ \frac{\partial f}{\partial \theta_j} \approx \frac{f(\theta_j + h) - f(\theta_j - h)}{2h} \]
This is called a central difference approximation. The \(2h\) in the denominator is because we moved a total distance of \(2h\) (from \(-h\) to \(+h\)). A good choice for \(h\) is around \(10^{-7}\) — small enough to approximate the true rate of change, large enough to avoid floating-point rounding issues.
A concrete example: suppose your function returns \(f = 3.7\) at the current parameters. You nudge \(\theta_1\) up by \(h = 10^{-7}\) and get \(f = 3.7000002\), then nudge it down by \(h\) and get \(f = 3.6999998\). The approximate gradient is \((3.7000002 - 3.6999998) / (2 \times 10^{-7}) = 2.0\). If your analytic gradient says the adjoint of \(\theta_1\) should be \(2.0\), you are in good shape. If it says \(1.5\), you have a bug.
Compare the finite-difference approximation against your analytic gradient. They should agree to several decimal places. This test requires no full model run and catches most bugs. Run it on a few random parameter vectors to build confidence.
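Here is what such a check might look like as a small standalone program, run against the gathered_lse implementation from Section 4 (a sketch; it assumes that implementation lives in user_header.hpp under the namespace used earlier):

#include "user_header.hpp"
#include <iostream>
#include <vector>

int main() {
  Eigen::VectorXd x0(4);
  x0 << 0.2, -1.0, 0.7, 1.5;
  std::vector<int> idx = {3, 3, 1};   // note the repeated index

  // Analytic gradient from the callback implementation
  Eigen::Matrix<stan::math::var, -1, 1> x = stan::math::to_var(x0);
  stan::math::var f = stan_model::gathered_lse(x, idx, nullptr);
  f.grad();

  // Central finite differences on the double-only code path
  const double h = 1e-7;
  for (int j = 0; j < x0.size(); ++j) {
    Eigen::VectorXd xp = x0, xm = x0;
    xp[j] += h;
    xm[j] -= h;
    double fd = (stan_model::gathered_lse(xp, idx, nullptr)
                 - stan_model::gathered_lse(xm, idx, nullptr)) / (2 * h);
    std::cout << "x[" << j + 1 << "]: analytic = " << x(j).adj()
              << ", finite difference = " << fd << "\n";
  }
  stan::math::recover_memory();
  return 0;
}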
Stan’s own test framework automates this kind of check for library functions: stan::test::expect_ad tests an autodiff implementation against finite differences, and stan::test::expect_near_rel (in test/unit/math/expect_near_rel.hpp) handles the relative-tolerance comparisons it relies on.
8.2 Method 2: Trajectory matching
This is the strongest end-to-end check. The idea: if two implementations of the same function produce identical values and identical gradients, then NUTS — which is deterministic given a value and gradient function — will make identical sampling decisions. The lp__ trace will match exactly across iterations.
- Write two Stan models: one using native Stan code, one using your custom C++ function. They must implement the same mathematical function.
- Run both with identical seed, initial values, and data.
- Compare the lp__ column from the output.
If the traces match exactly (not approximately — exactly, to full floating-point precision), your gradient is correct. Any discrepancy, even in the last decimal place, indicates a bug.
This works because NUTS makes discrete decisions (accept/reject, tree-building direction) based on the value and gradient. If both are identical, every decision is identical, and the entire trajectory is reproduced.
Trajectory matching relies on NUTS being deterministic given value and gradient. This holds for Stan’s NUTS implementation. It does not necessarily hold for all samplers or for future sampler implementations.
8.3 Method 3: log_prob and grad_log_prob
CmdStanR and CmdStanPy expose methods to evaluate the log-density and its gradient at a specific point in parameter space. This lets you compare the two implementations directly, without running the sampler:
In CmdStanR, log_prob() and grad_log_prob() are methods of the fit objects (they may require a one-time fit$init_model_methods() call) and take a parameter vector on the unconstrained scale:

fit_native <- native_model$sample(...)
fit_custom <- custom_model$sample(...)

# Compare at a specific point on the unconstrained scale.
# n_unconstrained (the number of unconstrained parameters) is a placeholder;
# any vector of the right length works, or unconstrain a posterior draw
# with fit_native$unconstrain_variables().
theta <- rep(0.1, n_unconstrained)

lp_native <- fit_native$log_prob(theta)
lp_custom <- fit_custom$log_prob(theta)

grad_native <- fit_native$grad_log_prob(theta)
grad_custom <- fit_custom$grad_log_prob(theta)

# These should be identical (or very nearly so)
max(abs(grad_native - grad_custom))

In CmdStanPy:

# After fitting, use log_prob / grad_log_prob methods
# to compare at specific parameter values

This complements finite-difference checking by operating at the full-model level rather than the function level.
9 Summary
The recipe for a custom C++ gradient in Stan:
1. Identify the bottleneck. Profile your model. Find the function where the expression graph overhead dominates.
2. Derive the gradient analytically. Work out the partial derivative of your function’s output with respect to each input — on paper, before you write any code.
3. Write the C++ function following the five-step pattern: signature, data shortcut, forward pass on doubles, arena copies, callback with make_callback_var.
4. Validate. Finite-difference check the gradient. Then run trajectory matching against the native Stan implementation.
5. Add parallelism if needed. Use reduce_sum in your Stan model with the custom C++ function handling each slice. Resort to raw TBB only if reduce_sum does not fit.
The technique may seem narrow in scope — it applies to specific, expensive functions, not to entire models. But when it applies, the performance gains can be substantial. A 4x speedup means overnight runs finish before lunch. It means you can iterate on model structure instead of waiting for results. It means the hardware you already have goes further.
The code is more work to write, more work to verify, and more work to maintain than native Stan. That is the tradeoff. Make it for the functions that earn it.
10 Appendix: API reference
A quick reference for the Stan Math functions used in this tutorial.
| Function | Purpose |
|---|---|
| make_callback_var(double val, F&& cb) | Create a var with value val and custom reverse-pass callback cb |
| value_of(const T& x) | Extract the double value from a var or Eigen var container |
| to_arena(const T& x) | Copy x to arena memory so it survives until the reverse pass |
| forward_as<var>(x) | Cast a scalar to var (useful when the type is ambiguous) |
| is_var_v<T> | Compile-time constant: true if T is var, false if double |
| scalar_type_t<T> | Type trait: extracts the scalar type (var or double) from an Eigen type |
| check_matching_sizes(...) | Runtime check that two containers have equal size |
References and further reading
Official documentation
- CmdStan: Using External C++ Code — the primary reference for the USER_HEADER mechanism, function signature conventions, namespace rules, and gradient specialization.
- Stan User’s Guide: Parallelization — covers reduce_sum, map_rect, and OpenCL parallelization, including grainsize tuning.
- Stan User’s Guide: Efficiency Tuning — vectorization, reparameterization, and other standard optimization techniques to apply before reaching for custom C++.
- Stan Functions Reference: reduce_sum — formal specification of reduce_sum and reduce_sum_static signatures and semantics.
- Stan Reference Manual: User-Defined Functions — syntax and type system for the functions {} block.
- Profiling Stan Programs (CmdStanR vignette) — how to add profile() blocks and interpret timing output.
- CmdStanR: compile() method — documents user_header, cpp_options, and stanc_options arguments.
- CmdStanPy: Using External C++ Functions — step-by-step guide for the Python interface.
- Reduce Sum: A Minimal Example (Stan case study) — a hands-on reduce_sum tutorial with practical grainsize guidance.
Stan Math source code
- callback_vari.hpp — defines make_callback_var. Key comment: “All captured values must be trivially destructible or they will leak memory.”
- reverse_pass_callback.hpp — defines reverse_pass_callback, an alternative that registers a callback without creating a new var.
- to_arena.hpp — defines to_arena() with overloads for scalars, Eigen types, and std::vector.
- Common Pitfalls (Stan Math contributor docs) — explains memory lifetime issues with make_callback_var, reverse_pass_callback, to_arena, and make_chainable_ptr.
- Adding New Functions with Known Gradients (Stan Math contributor docs) — how Stan’s developers add hand-coded gradients to the library itself.
Stan Discourse
- Documentation of external C++ with precomputed gradients needs update? — clarifies the relationship between make_callback_var and the older precomputed_gradients() API.
- External C++ Code with Precomputed Gradients — practical discussion of compiling models with custom gradient code.
- Modifying external C++ function example to return analytical gradient — a user works through writing a custom C++ gradient step by step.
- Vector-valued functions with manual gradients in external C++ — extends the pattern to functions returning vectors rather than scalars.
- Custom function with known gradients — early thread on the general approach.
- Custom C++ using CmdStanR — covers namespace requirements when using CmdStanR.
- Includes in USER_HEADER — what can be #include-ed inside a user header file.