Hand-Coded C++ Gradients for Stan

Bypassing the expression graph with make_callback_var

Author

Dr. Scott Spencer

Published

February 15, 2026

1 When Stan is already fast — and you need it faster

Suppose you have a Stan model that works. You have vectorized your operations, replaced loops with matrix algebra where possible, and followed the efficiency advice in the Stan User’s Guide. You are running on good hardware with enough cores. The model produces valid inferences.

And it is still too slow.

This happens. Not because Stan is slow — Stan compiles your model to C++, and the compiled code uses Eigen for linear algebra, Intel TBB for threading, and the Stan Math library for numerics. When you write dot_product(x, w) in Stan, the code that executes is already optimized C++ calling into production-grade libraries. You are not going to beat it by rewriting the arithmetic.

What you can beat is the bookkeeping. Stan’s automatic differentiation system is general-purpose: every arithmetic operation on a parameter records a node in an internal data structure called the expression graph. For most models this overhead is negligible. But when a function performs millions of repetitive operations — looping over hundreds of thousands of observations, gathering parameters by index, computing the same formula at each step — the per-node bookkeeping can dominate the actual arithmetic. The math is fast. The tracking machinery around it is not.

This tutorial teaches you how to eliminate that overhead for targeted sections of your model. The idea is not to replace Stan’s C++ with different C++ — the arithmetic stays the same. It is to replace Stan’s general-purpose gradient tracking with a hand-crafted gradient computation that exploits the specific structure of your function. You tell Stan: “I know the gradient of this function analytically — skip the graph, here is the answer.”

The technique uses a single entry point in the Stan Math library called make_callback_var (defined in callback_vari.hpp in the Stan Math source). It is not documented in the Stan User’s Guide. It is not complicated. But getting it right requires understanding a few ideas clearly, and the existing documentation consists of source-code comments and scattered forum posts.

We will build up from the simplest possible example — a weighted sum — to indexed gather operations, and then to parallel execution. By the end, you will understand the pattern well enough to apply it to your own models.

1.1 What this tutorial assumes

You are comfortable writing Stan models. You have a working CmdStan, CmdStanR, or CmdStanPy installation. You have encountered the idea of automatic differentiation and know, at least roughly, that Stan computes gradients for you. You do not need to have written C++ before, though you will be writing some here.

This tutorial is not a general introduction to automatic differentiation. It is not about replacing Stan’s autodiff everywhere. It is about replacing it in specific, expensive functions that you have already identified as bottlenecks — through profiling, through timing, or through the experience of watching a sampler crawl.

1.2 Why doesn’t Stan already do this?

If hand-coding the gradient can be so much faster, a natural question: why doesn’t Stan do it automatically?

The short answer is that Stan does, for its own built-in functions. When you call normal_lpdf or log_sum_exp, you are not hitting the general-purpose autodiff system. Stan’s developers have written hand-coded gradients for hundreds of built-in functions using the same idea this tutorial describes: some call make_callback_var directly, and others use closely related internal helpers. Each of those functions adds a single node to the expression graph and supplies its gradient analytically. This is why Stan’s built-in functions are fast.

What Stan cannot do is write those hand-coded gradients for your functions. When you write a custom computation in the Stan language — a loop over observations, a series of indexed lookups, a particular formula — Stan’s only option is the general-purpose approach: record every operation as a node, build the graph, and traverse it later. Stan does not know the mathematical structure of your computation. You do.

Could a sufficiently smart compiler figure it out? In principle, partially. A compiler could recognize common patterns — “this is a dot product over gathered indices” or “this is a log-sum-exp reduction” — and replace them with optimized, hand-coded gradient implementations. Some of this happens already: Stan’s compiler performs expression optimization and can fuse certain operations. But the space of possible user-written functions is vast, and recognizing that a particular 50-line loop is mathematically equivalent to a function with a known analytic gradient is, in the general case, an unsolved problem in compiler design.

So the situation is: Stan gives you the general-purpose system (correct for any differentiable function, no calculus required from the user) and an escape hatch (make_callback_var) for the cases where you know enough about your function’s math to do better. The escape hatch exists because Stan’s developers use it themselves — they are just giving you access to the same tool. (The Stan Math contributor guide on adding new gradients describes this pattern from the developer’s perspective.)

1.3 When this technique applies

This technique requires that you can write down the gradient of your function analytically — that is, as an explicit formula involving the inputs and any intermediates you computed during the forward pass. You will implement that formula in C++ and hand it to Stan. If you cannot express the gradient as a formula, you cannot use this approach.

The good news is that most functions encountered in statistical modeling have analytic gradients. If your function is built from standard operations (addition, multiplication, exp, log, pow, dot products, matrix multiplies, sums, reductions), then its gradient can be derived by applying the chain rule (remember that term from high school calculus?) step by step. That is exactly what Stan’s autodiff does automatically; you are simply doing it yourself on paper.

A practical test: can you write the function as a composition of operations that each have a known derivative? If so, the chain rule gives you the gradient of the whole function. For example:

  • \(f(\mathbf{x}) = \sum_i x_i w_i\) — each term’s derivative with respect to \(x_i\) is \(w_i\). Analytic gradient: yes.
  • \(f(\mathbf{x}) = \log \sum_i \exp(x_i)\) — the derivative with respect to \(x_i\) is the softmax weight \(\exp(x_i) / \sum_j \exp(x_j)\). Analytic gradient: yes.
  • A function involving if statements on parameter values (like a piecewise linear function) — the gradient exists everywhere except at the breakpoints, and is a simple formula on each piece. Analytic gradient: yes, in practice.

Functions where this technique does not readily apply:

  • Functions defined by iterative algorithms where the number of iterations depends on the input (e.g., root-finding by Newton’s method, optimization loops). The gradient of the output with respect to the input exists in principle (via the implicit function theorem), but it is not a simple formula you can write in a callback.
  • Functions involving discrete operations such as argmax or ranking, where the output jumps discontinuously as the input changes.

If you can derive the gradient on paper — even if the derivation takes some work — this technique applies. If the derivation requires specialized mathematical machinery (implicit differentiation, adjoint methods for ODEs), you are likely better served by Stan’s built-in support for those cases (algebra_solver, ode_*), which already implements the specialized math.

2 How Stan computes gradients

To understand what we are replacing, we need to understand what Stan does by default. This requires a few ideas that we will build up one at a time.

2.1 Two kinds of numbers

Stan works with two kinds of numbers internally. The first is the ordinary floating-point number, a double in C++. When you declare a real-valued quantity in the data or transformed data block, Stan stores it as a double (integers are stored as int). It is just a number. Nothing special happens when you do arithmetic with it.

The second is Stan’s gradient-tracking type, called var (short for “variable”). When you declare a parameter — anything in the parameters block, or derived from parameters in transformed parameters or model — Stan wraps it in a var. A var holds a double value inside, but it also carries machinery for tracking how that value was computed, so that Stan can later figure out how the final answer (the log-posterior) depends on it.

The distinction matters because the entire technique in this tutorial rests on it: we strip the var wrappers off, do the arithmetic on plain double values (fast, no tracking), and then tell Stan the gradient ourselves.

2.2 The expression graph

When your model evaluates an expression involving var values — that is, parameters — Stan does not just compute the answer. It also builds a graph, called the expression graph (sometimes called the “tape”), that records every operation and its inputs.

Consider a simple expression where \(a\), \(b\), and \(c\) are parameters:

\[ y = a \cdot b + c \]

Stan computes the value of \(y\), but it also builds something like this:

[a]  [b]   [c]
  \  /      |
   [×]      |
     \     /
      [+] → y

Each box is a node. Each node remembers its inputs and a rule for passing sensitivity back to them, that is, for computing how much the node’s result changes when each input changes. (We will make this precise shortly.)

For small expressions, the overhead of building this graph is negligible. The trouble starts when the graph grows large.

2.3 When the graph becomes the bottleneck

Imagine a model with 300,000 observations. Each observation requires gathering a handful of parameters by index, computing a log-likelihood contribution, and adding it to the target. If each iteration of that loop creates 30 nodes, the graph contains 9 million nodes — each one allocated, linked, and later traversed during the gradient computation.

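The actual floating-point arithmetic might take well under a microsecond per observation. The graph bookkeeping can take significantly longer: each of those 30 nodes is allocated and linked during the forward pass, then visited again during the reverse pass.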

This is the situation where hand-coded C++ helps. Instead of 9 million nodes, you create one. That single node stores a callback function that you write — a function that knows how to compute the gradients for the entire operation. Stan calls your callback during the gradient computation. The arithmetic is the same. The bookkeeping vanishes.

2.4 The forward and reverse passes

Stan’s gradient computation happens in two phases.

The forward pass evaluates the model to get a numeric answer: the log-posterior density at the current parameter values, up to an additive constant. This is the number that appears as lp__ in Stan’s output. During the forward pass, the expression graph is built.

The reverse pass walks the graph backward from the output to the inputs. At each node, it figures out how much the final output would change if that node’s value changed by a tiny amount. This is the information the sampler needs: NUTS explores parameter space by following the gradient of the log-posterior, the way a hiker might follow the slope of a hillside. Without knowing which direction is “uphill” — which parameter adjustments increase the log-posterior and by how much — the sampler cannot take efficient steps. The gradient tells it where to go.

The quantity computed at each node during the reverse pass is called the adjoint of that node.

Let us make this concrete. Suppose the log-posterior is currently \(L = 10.0\), and somewhere in the computation there is an intermediate variable \(t = 3.0\) (a var in Stan’s type system). If nudging \(t\) from \(3.0\) to \(3.001\) would change \(L\) from \(10.0\) to \(10.005\), then the adjoint of \(t\) is \(0.005 / 0.001 = 5\). In calculus notation, we write this as a partial derivative — the rate of change of \(L\) with respect to \(t\) while holding everything else fixed:

\[ \frac{\partial L}{\partial t} = 5 \]

The symbol \(\partial L / \partial t\) is read “the partial derivative of \(L\) with respect to \(t\).” It answers a specific question: if I nudge \(t\) by a tiny amount, how much does \(L\) change, per unit of nudge? The adjoint is just a name for this quantity in the context of reverse-mode autodiff.
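The nudge definition translates directly into a numerical estimate. The sketch below is plain C++ with no Stan dependency. forward_difference and example_L are hypothetical names, and example_L is a made-up linear function chosen only so that L(3) = 10 and L(3.001) = 10.005, matching the numbers above.

```cpp
#include <cassert>
#include <cmath>
#include <functional>

// Estimate dL/dt at t by nudging t by a small step h:
// (L(t + h) - L(t)) / h, the "nudge" definition of the adjoint.
double forward_difference(const std::function<double(double)>& L, double t,
                          double h = 1e-3) {
  return (L(t + h) - L(t)) / h;
}

// Hypothetical log-posterior, chosen so that L(3) = 10 and L(3.001) = 10.005.
double example_L(double t) { return 5.0 * t - 5.0; }
```

forward_difference(example_L, 3.0) returns 5 up to rounding error. Because example_L is linear, the ratio is exact; for a curved function it only approaches the derivative as h shrinks, which is why the definition speaks of a tiny nudge.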

2.5 The chain rule: how adjoints propagate

The reverse pass computes adjoints by working backward through the graph, applying a rule from calculus called the chain rule. The chain rule says: if \(L\) depends on \(t\), and \(t\) depends on \(a\), then the rate at which \(L\) changes with \(a\) equals the rate at which \(L\) changes with \(t\), multiplied by the rate at which \(t\) changes with \(a\):

\[ \frac{\partial L}{\partial a} = \frac{\partial L}{\partial t} \cdot \frac{\partial t}{\partial a} \]

A concrete example makes this mechanical. Take \(y = a \cdot b + c\) with \(a = 3\), \(b = 2\), \(c = 1\), so \(y = 7\). Suppose the rest of the model is such that \(\partial L / \partial y = 4\) — that is, the upstream adjoint arriving at our expression is 4.

The intermediate product is \(t_1 = a \cdot b = 6\). The addition \(y = t_1 + c\) has the property that nudging either input by \(\epsilon\) nudges \(y\) by \(\epsilon\), so \(\partial y / \partial t_1 = 1\) and \(\partial y / \partial c = 1\).

Now apply the chain rule:

  • Adjoint of \(t_1\): \(\frac{\partial L}{\partial t_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial t_1} = 4 \cdot 1 = 4\)
  • Adjoint of \(c\): \(\frac{\partial L}{\partial c} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial c} = 4 \cdot 1 = 4\)

For the multiplication \(t_1 = a \cdot b\): nudging \(a\) from 3 to 3.001 changes \(t_1\) from 6 to 6.002, so \(\partial t_1 / \partial a = b = 2\). Similarly, \(\partial t_1 / \partial b = a = 3\). The chain rule gives:

  • Adjoint of \(a\): \(\frac{\partial L}{\partial a} = \frac{\partial L}{\partial t_1} \cdot \frac{\partial t_1}{\partial a} = 4 \cdot 2 = 8\)
  • Adjoint of \(b\): \(\frac{\partial L}{\partial b} = \frac{\partial L}{\partial t_1} \cdot \frac{\partial t_1}{\partial b} = 4 \cdot 3 = 12\)

You can verify: nudging \(a\) from 3 to 3.001 changes \(y\) from 7 to 7.002, and if \(\partial L / \partial y = 4\), then \(L\) changes by \(0.002 \times 4 = 0.008\), giving a rate of \(0.008 / 0.001 = 8\). It checks out.
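To make the bookkeeping concrete, here is a deliberately tiny reverse-mode tape that reproduces this example. It is a hypothetical sketch for illustration, not how Stan Math stores its graph: every operation appends a node plus a rule for passing its adjoint back to its inputs, and the reverse pass replays those rules from last to first.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy tape. The rules capture `this`, so the tape must not be
// copied or moved after operations have been recorded.
struct ToyTape {
  std::vector<double> val;                      // value of each node
  std::vector<double> adj;                      // adjoint of each node
  std::vector<std::function<void()>> backward;  // one rule per operation

  std::size_t leaf(double v) {
    val.push_back(v);
    adj.push_back(0.0);
    return val.size() - 1;
  }

  std::size_t mul(std::size_t i, std::size_t j) {
    std::size_t out = leaf(val[i] * val[j]);
    // d(out)/d(i) = val[j] and d(out)/d(j) = val[i]
    backward.push_back([this, i, j, out] {
      adj[i] += adj[out] * val[j];
      adj[j] += adj[out] * val[i];
    });
    return out;
  }

  std::size_t add(std::size_t i, std::size_t j) {
    std::size_t out = leaf(val[i] + val[j]);
    // d(out)/d(i) = d(out)/d(j) = 1
    backward.push_back([this, i, j, out] {
      adj[i] += adj[out];
      adj[j] += adj[out];
    });
    return out;
  }

  // Seed the output with the upstream adjoint, then replay the rules
  // from the last operation back to the first.
  void reverse(std::size_t output, double upstream) {
    adj[output] = upstream;
    for (auto it = backward.rbegin(); it != backward.rend(); ++it) (*it)();
  }
};
```

Recording t1 = mul(a, b) and y = add(t1, c) with a = 3, b = 2, c = 1, then calling reverse(y, 4.0), leaves adjoints 8, 12, and 4 on a, b, and c, matching the hand calculation. Each call to mul or add is one box in the diagram; a loop that performs 30 such operations per observation records 30 rules per observation.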

Stan does this computation automatically for every node in the graph. When you write custom C++, you do it yourself — once, for the entire function — and hand the results directly to the inputs. The arithmetic is the same. You are just doing it by hand instead of asking Stan to build a graph and traverse it.

3 A weighted sum

We begin with the simplest possible case: a function that takes two vectors and returns their weighted sum (a dot product).

\[ f(\mathbf{x}, \mathbf{w}) = \sum_{i=1}^{N} x_i \, w_i \]

In Stan, you would write dot_product(x, w) or sum(x .* w) and the built-in autodiff would handle the gradient. We are reimplementing it in C++ not because it needs reimplementing, but because the pattern is clear enough to teach the mechanics without distraction.

3.1 The Stan declaration

In your .stan file, declare the function with no body (see the Stan Reference Manual on user-defined functions for the full syntax):

functions {
  real weighted_sum(vector x, vector w);
}

The missing body tells the Stan compiler to look for a C++ implementation in a user-provided header file. You will supply this header at compile time. (The CmdStan guide on external C++ covers the mechanics of how this linkage works.)

3.2 The C++ implementation

Create a C++ header file with an informative name; here we’ll call it user_header.hpp. This is the complete file for our first example:

#ifndef USER_HEADER_HPP
#define USER_HEADER_HPP

#include <stan/math.hpp>
#include <ostream>

namespace stan_model {

template <typename EigVec1, typename EigVec2>
auto weighted_sum(const EigVec1& x,
                  const EigVec2& w,
                  std::ostream* pstream__) {

  // Extract the scalar types: var (parameter) or double (data)
  using T_x = stan::scalar_type_t<EigVec1>;
  using T_w = stan::scalar_type_t<EigVec2>;

  // If both inputs are data (double), no gradient needed
  if constexpr (!stan::is_var_v<T_x> && !stan::is_var_v<T_w>) {
    const auto& x_val = stan::math::value_of(x);
    const auto& w_val = stan::math::value_of(w);
    return x_val.dot(w_val);
  } else {
    // --- Forward pass: compute on plain doubles ---
    const auto& x_val = stan::math::value_of(x);
    const auto& w_val = stan::math::value_of(w);
    double result = x_val.dot(w_val);

    // --- Keep var inputs alive until the reverse pass ---
    auto x_arena = stan::math::to_arena(x);
    auto w_arena = stan::math::to_arena(w);

    // --- Return a var with a custom reverse pass ---
    return stan::math::make_callback_var(
      result,
      [x_arena, w_arena](auto& ret) mutable {  // mutable: the callback writes adjoints through its captured copies
        double upstream = ret.adj();

        if constexpr (stan::is_var_v<T_x>) {
          x_arena.adj().array() += upstream * stan::math::value_of(w_arena).array();
        }
        if constexpr (stan::is_var_v<T_w>) {
          w_arena.adj().array() += upstream * stan::math::value_of(x_arena).array();
        }
      }
    );
  }
}

} // namespace stan_model
#endif

That is a lot of code for a dot product. Let us walk through it piece by piece.

3.2.1 The function signature

template <typename EigVec1, typename EigVec2>
auto weighted_sum(const EigVec1& x,
                  const EigVec2& w,
                  std::ostream* pstream__) {

Stan’s compiler generates calls to your C++ function using template types. A vector argument in Stan might arrive as an Eigen::Matrix<var, -1, 1> (when the vector contains parameters) or an Eigen::Matrix<double, -1, 1> (when it contains data). The templates let a single function definition handle both cases — the compiler generates the appropriate version for each combination of input types.

The pstream__ argument is Stan’s output stream. Every user-defined function receives it as the last argument. You will rarely use it, but it must be there. (The CmdStan external C++ guide documents the full signature conventions, including the pstream__ parameter and the required namespace.)

The return type is auto because it depends on the inputs: double when all inputs are data, var when any input is a parameter.

3.2.2 The data-only shortcut

if constexpr (!stan::is_var_v<T_x> && !stan::is_var_v<T_w>) {
  const auto& x_val = stan::math::value_of(x);
  const auto& w_val = stan::math::value_of(w);
  return x_val.dot(w_val);
}

This is the first thing to understand, and the most common source of confusion. When Stan calls your function with data-only arguments — in the transformed data block, for example — the inputs are plain double values, not var. There are no parameters, no gradients to compute, and no expression graph. You just compute the answer and return a double.

The if constexpr is a compile-time check. It asks: “are any of the inputs var?” If not, the compiler generates a simple version of the function that does plain arithmetic. If so, it generates the version with gradient tracking. The compiler instantiates the function separately for each combination of argument types, and each instantiation keeps only the branch that applies, so the check costs nothing at runtime.

value_of() extracts the numeric value from whatever type it receives. On a double, it returns the double unchanged. On a var, it returns the double stored inside. We use it uniformly so the same code works in both branches.
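The same compile-time dispatch can be seen without Stan. In this sketch, Tracked is a hypothetical stand-in for var, is_tracked_v a stand-in for stan::is_var_v, and value_of_toy a stand-in for value_of(). Each instantiation of toy_product keeps only one branch, so even the deduced return type differs between a data-only call and a tracked call.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for Stan's var: a value plus an adjoint slot.
struct Tracked {
  double val;
  double adj = 0.0;
};

// Hypothetical trait playing the role of stan::is_var_v.
template <typename T>
inline constexpr bool is_tracked_v = std::is_same_v<T, Tracked>;

// Mirrors value_of(): pull a plain double out of either type.
inline double value_of_toy(double x) { return x; }
inline double value_of_toy(const Tracked& x) { return x.val; }

// Each instantiation keeps exactly one branch, so the return type is
// double for data-only calls and Tracked when any argument is tracked.
template <typename T1, typename T2>
auto toy_product(const T1& a, const T2& b) {
  if constexpr (!is_tracked_v<T1> && !is_tracked_v<T2>) {
    return value_of_toy(a) * value_of_toy(b);           // plain double
  } else {
    return Tracked{value_of_toy(a) * value_of_toy(b)};  // tracked result
  }
}
```

toy_product(2.0, 3.0) is a double, while toy_product(Tracked{2.0}, 3.0) is a Tracked; no runtime test decides between them.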

3.2.3 The forward pass

const auto& x_val = stan::math::value_of(x);
const auto& w_val = stan::math::value_of(w);
double result = x_val.dot(w_val);

We strip the var wrappers and compute on plain double values. No nodes are created. No graph is built. This is just arithmetic.

3.2.4 Keeping inputs alive

auto x_arena = stan::math::to_arena(x);
auto w_arena = stan::math::to_arena(w);

The reverse pass does not happen now. It happens later, after the entire forward pass of the model is complete. The var nodes themselves live in Stan’s arena, but the Eigen containers that x and w refer to (possibly temporaries in the generated model code) may have gone out of scope by then, and their memory may have been reused.

to_arena() copies the data into Stan’s arena memory — a fast allocator that is freed all at once after the reverse pass completes. This guarantees that x_arena and w_arena are still valid when the callback fires.
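The lifetime problem is ordinary C++, so it can be shown without Stan. The sketch below uses a hypothetical ToyArena, which is not Stan’s allocator but shares the relevant property: it owns every block it hands out and frees them all together. make_deferred_sum, also hypothetical, copies the caller’s data into the arena (the counterpart of to_arena()) and returns a callback that still reads valid data after the caller’s vector is gone.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical toy arena (not Stan's allocator): owns every block it hands
// out and frees them all together in release().
class ToyArena {
 public:
  double* alloc_doubles(std::size_t n) {
    blocks_.push_back(std::make_unique<double[]>(n));
    return blocks_.back().get();
  }
  void release() { blocks_.clear(); }

 private:
  std::vector<std::unique_ptr<double[]>> blocks_;
};

// Copy the caller's values into the arena and return a callback that
// reads them later, after the caller's vector may have been destroyed.
std::function<double()> make_deferred_sum(const std::vector<double>& x,
                                          ToyArena& arena) {
  double* x_arena = arena.alloc_doubles(x.size());
  std::copy(x.begin(), x.end(), x_arena);
  const std::size_t n = x.size();
  // Capture only a pointer and a size; the data itself lives in the arena.
  return [x_arena, n] {
    double s = 0.0;
    for (std::size_t i = 0; i < n; ++i) s += x_arena[i];
    return s;
  };
}
```

Capturing the caller’s vector by reference instead would also compile, but the callback would then read a destroyed object, which is exactly the failure that copying to the arena prevents.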

3.2.5 The callback

return stan::math::make_callback_var(
  result,
  [x_arena, w_arena](auto& ret) mutable {
    double upstream = ret.adj();

    if constexpr (stan::is_var_v<T_x>) {
      x_arena.adj().array()
        += upstream * stan::math::value_of(w_arena).array();
    }
    if constexpr (stan::is_var_v<T_w>) {
      w_arena.adj().array()
        += upstream * stan::math::value_of(x_arena).array();
    }
  }
);

make_callback_var does two things:

  1. It creates a var whose value is result.
  2. It registers a callback (a lambda function) that Stan will call during the reverse pass.

When the reverse pass reaches this node, Stan calls the lambda. The argument ret is our single node in the graph, the internal object behind the var that make_callback_var returned. ret.adj() is the upstream adjoint: how much the log-posterior would change per unit change in our function’s return value.

Now we need the gradient of our function — the partial derivative of the output with respect to each input element. Our function is \(f(\mathbf{x}, \mathbf{w}) = \sum_i x_i w_i\). To find how the output changes when we nudge one input, say \(x_1\), hold everything else fixed: if \(\mathbf{x} = (3, 5)\) and \(\mathbf{w} = (2, 4)\), then \(f = 3 \cdot 2 + 5 \cdot 4 = 26\). Nudging \(x_1\) from 3 to 3.001 gives \(f = 3.001 \cdot 2 + 5 \cdot 4 = 26.002\). The rate of change is \(0.002 / 0.001 = 2\), which is \(w_1\). In general:

\[ \frac{\partial f}{\partial x_i} = w_i \qquad\qquad \frac{\partial f}{\partial w_i} = x_i \]

The chain rule then gives us the adjoint of each input — multiply the upstream adjoint (how much \(L\) changes per unit change in \(f\)) by the local derivative (how much \(f\) changes per unit change in the input):

\[ \text{adj}(x_i) \mathrel{+}= \frac{\partial L}{\partial f} \cdot w_i \qquad\qquad \text{adj}(w_i) \mathrel{+}= \frac{\partial L}{\partial f} \cdot x_i \]

That is exactly what the code does. The += is important — this input might be used elsewhere in the model, so we accumulate into the adjoint rather than overwriting it.

The if constexpr guards ensure we only touch adjoints of var inputs. If w is data (a double vector), it has no .adj() to write to — data does not participate in gradients — and attempting to access it would fail at compile time.
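Because a wrong adjoint fails silently, it pays to check the hand-coded rule numerically before wiring it into Stan. A minimal sketch in plain C++, assuming nothing from Stan: the forward pass and the callback’s += logic are rewritten on std::vector<double> (all function names are hypothetical), and the hand-coded partial with respect to each x_i is compared with a central finite difference.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Plain-double forward pass: f(x, w) = sum_i x_i * w_i.
double weighted_sum_value(const std::vector<double>& x,
                          const std::vector<double>& w) {
  double s = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) s += x[i] * w[i];
  return s;
}

// Plain-double copy of the callback: accumulate upstream * partial with +=.
void weighted_sum_reverse(double upstream, const std::vector<double>& x,
                          const std::vector<double>& w,
                          std::vector<double>& x_adj,
                          std::vector<double>& w_adj) {
  for (std::size_t i = 0; i < x.size(); ++i) {
    x_adj[i] += upstream * w[i];  // df/dx_i = w_i
    w_adj[i] += upstream * x[i];  // df/dw_i = x_i
  }
}

// Central-difference estimate of df/dx_i (x is taken by value and perturbed).
double fd_dx(std::vector<double> x, const std::vector<double>& w,
             std::size_t i, double h = 1e-6) {
  x[i] += h;
  double up = weighted_sum_value(x, w);
  x[i] -= 2.0 * h;
  double down = weighted_sum_value(x, w);
  return (up - down) / (2.0 * h);
}
```

With x = (3, 5), w = (2, 4), and an upstream adjoint of 1, the reverse rule produces x adjoints (2, 4) and w adjoints (3, 5), and fd_dx agrees with the x adjoints to within rounding. The same comparison works for any function in this tutorial: perturb one input, rerun the forward pass, and compare.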

3.3 Compiling and running

In CmdStanR:

model <- cmdstan_model(
  'model.stan',
  stanc_options = list("allow-undefined" = TRUE),
  user_header = normalizePath('user_header.hpp')
)

In CmdStanPy:

model = CmdStanModel(
  stan_file='model.stan',
  stanc_options={'allow-undefined': True},
  user_header='user_header.hpp'
)

In CmdStan, add to your make/local file:

STANCFLAGS += --allow-undefined
USER_HEADER = /full/path/to/user_header.hpp

The --allow-undefined flag tells the Stan compiler to accept function declarations without bodies. The user_header option tells CmdStan to include your C++ file during compilation (see the CmdStanR and CmdStanPy documentation for the interface-specific arguments). Stan expects a single header file — but if your implementation grows large enough to split across files, you can use one main header that #includes the others:

// user_header.hpp
#ifndef USER_HEADER_HPP
#define USER_HEADER_HPP

#include "weighted_sum.hpp"
#include "gathered_lse.hpp"
#include "my_likelihood.hpp"

#endif

Each included file contains one or more function implementations. Stan only sees the single user_header.hpp; you organize the internals however you like.

3.4 The recipe

Every custom C++ function for Stan follows the same structure:

  1. Signature: template over input types, return auto, accept pstream__ as the last argument.
  2. Data shortcut: if no input is a var, compute and return a double.
  3. Forward pass: strip inputs to double with value_of(), compute the answer.
  4. Arena copy: save var inputs, and any intermediates the callback will need, with to_arena().
  5. Callback: return make_callback_var(value, lambda), where the lambda reads the upstream adjoint and distributes it to each input’s .adj() using the chain rule.

This recipe does not change. The only things that vary between functions are the arithmetic in steps 3 and 5 and, when the gradient needs them, the intermediates saved in step 4.
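To see the shape of the recipe without Stan’s machinery, the sketch below builds a toy graph in plain C++. ToyGraph, callback_node, and weighted_sum_node are hypothetical: callback_node plays the role of make_callback_var by creating one node with a user-supplied reverse rule, under the simplifying assumption that the reverse pass just replays rules in reverse order. Whatever the number of inputs, the weighted sum adds exactly one node.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy graph: values, adjoints, and one reverse rule per node.
struct ToyGraph {
  std::vector<double> val, adj;
  std::vector<std::function<void()>> callbacks;

  std::size_t leaf(double v) {
    val.push_back(v);
    adj.push_back(0.0);
    return val.size() - 1;
  }

  // Analogue of make_callback_var: one new node whose reverse rule receives
  // the node's own adjoint (the upstream value).
  template <typename F>
  std::size_t callback_node(double value, F reverse_rule) {
    std::size_t out = leaf(value);
    callbacks.push_back([this, out, reverse_rule] { reverse_rule(adj[out]); });
    return out;
  }

  void reverse(std::size_t output) {
    adj[output] = 1.0;  // seed: d(output)/d(output) = 1
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) (*it)();
  }
};

// Weighted sum of tracked inputs x (node ids) with data weights w.
std::size_t weighted_sum_node(ToyGraph& g, const std::vector<std::size_t>& x,
                              const std::vector<double>& w) {
  // Step 3: forward pass on plain doubles; no nodes are created here.
  double result = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) result += g.val[x[i]] * w[i];
  // Steps 4 and 5: keep what the reverse pass needs and add a single node
  // whose rule distributes the upstream adjoint with +=.
  return g.callback_node(result, [&g, x, w](double upstream) {
    for (std::size_t i = 0; i < x.size(); ++i) g.adj[x[i]] += upstream * w[i];
  });
}
```

In this toy the lambda captures x and w as std::vector by value, which is safe only because std::function destroys them; inside Stan the same captures would need to go through to_arena() instead (see Section 4.6).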

4 A log-sum-exp gather

The weighted sum was clean because the gradient was trivial. Now we tackle a function with nonlinear operations and index arrays — the kind of pattern that actually appears in real models.

Consider a model where you need the log-sum-exp of a subset of parameters, selected by an index array:

\[ f(\mathbf{x}, \text{idx}) = \log \sum_{i=1}^{K} \exp\!\big(x_{\text{idx}[i]}\big) \]

This is common in mixture models, hierarchical models with group-level indexing, and any model where observations map to subsets of parameters.

4.1 The Stan declaration

functions {
  real gathered_lse(vector x, array[] int idx);
}

The vector x holds the full parameter vector. The integer array idx selects which elements to include, using Stan’s 1-based indexing.

4.2 The math

The forward pass computes the log-sum-exp. For numerical stability, we use the standard trick of subtracting the maximum:

\[ m = \max_i\, x_{\text{idx}[i]} \qquad f = m + \log \sum_{i=1}^{K} \exp\!\big(x_{\text{idx}[i]} - m\big) \]

For the reverse pass, we need the partial derivative of the log-sum-exp with respect to each gathered element — how much does the output change when we nudge one input?

A concrete example helps. Suppose \(\mathbf{x}_{\text{idx}} = (1, 3)\). Then \(f = \log(e^1 + e^3) = \log(2.718 + 20.086) = \log(22.804) \approx 3.12693\). If we nudge the second element from 3 to 3.001: \(f' = \log(e^1 + e^{3.001}) \approx 3.12781\). The rate of change is about \((3.12781 - 3.12693) / 0.001 = 0.88\), which is \(e^3 / (e^1 + e^3)\), the softmax weight of the second element.

In general, the partial derivative turns out to be the softmax — the exponential of each element divided by the sum of all exponentials:

\[ \frac{\partial f}{\partial x_{\text{idx}[i]}} = \frac{\exp(x_{\text{idx}[i]})}{\sum_{j=1}^{K} \exp(x_{\text{idx}[j]})} = \text{softmax}(\mathbf{x}_{\text{idx}})_i \]

We will need these softmax weights during the reverse pass, so we compute and store them during the forward pass.
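The forward pass and the stored weights can be sketched in plain C++ before touching Stan. lse_with_weights is a hypothetical helper, not Stan’s log_sum_exp: it applies the max-shift from the formula above (every exponent is at most 0, so exp cannot overflow) and returns the softmax weights alongside the value, ready to be kept for the reverse pass.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Stable log-sum-exp that also fills `weights` with the softmax weights,
// i.e. the partial derivatives df/dv_i. Assumes v is non-empty and finite.
double lse_with_weights(const std::vector<double>& v,
                        std::vector<double>& weights) {
  double m = *std::max_element(v.begin(), v.end());
  weights.resize(v.size());
  double s = 0.0;
  for (std::size_t i = 0; i < v.size(); ++i) {
    weights[i] = std::exp(v[i] - m);  // exponent <= 0: no overflow
    s += weights[i];
  }
  for (double& wi : weights) wi /= s;  // exp(v_i - m) / sum_j exp(v_j - m)
  return m + std::log(s);
}
```

Called on (1, 3), it returns f ≈ 3.12693 with weights ≈ (0.1192, 0.8808), the same 0.88 found by nudging. Called on (1000, 1000), where a naive exp would overflow to infinity, it returns 1000 + log 2 with weights (0.5, 0.5).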

4.3 The C++ implementation

template <typename EigVec>
auto gathered_lse(const EigVec& x,
                  const std::vector<int>& idx,
                  std::ostream* pstream__) {

  using T_x = stan::scalar_type_t<EigVec>;
  const int K = idx.size();

  // Extract doubles
  const auto& x_val = stan::math::value_of(x);

  // Forward pass: numerically stable log-sum-exp
  double max_val = -std::numeric_limits<double>::infinity();
  for (int i = 0; i < K; ++i) {
    max_val = std::max(max_val, x_val[idx[i] - 1]);  // 1-based to 0-based
  }

  Eigen::VectorXd weights(K);
  double sum_exp = 0.0;
  for (int i = 0; i < K; ++i) {
    weights[i] = std::exp(x_val[idx[i] - 1] - max_val);
    sum_exp += weights[i];
  }

  double result = max_val + std::log(sum_exp);

  // Softmax weights (needed for reverse pass)
  weights /= sum_exp;

  if constexpr (!stan::is_var_v<T_x>) {
    return result;
  } else {
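    auto x_arena = stan::math::to_arena(x);

    // Copy weights to arena so the callback can read them
    auto weights_arena = stan::math::to_arena(weights);

    // idx is a std::vector (not trivially destructible), so copy it to the
    // arena too instead of capturing the std::vector itself
    auto idx_arena = stan::math::to_arena(idx);

    return stan::math::make_callback_var(
      result,
      [x_arena, weights_arena, idx_arena, K](auto& ret) mutable {
        double upstream = ret.adj();

        for (int i = 0; i < K; ++i) {
          x_arena.adj()[idx_arena[i] - 1] += upstream * weights_arena[i];
        }
      }
    );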
  }
}

The structure is the same recipe from the weighted sum. The only differences are the arithmetic (log-sum-exp instead of dot product) and the gradient (softmax weights instead of the other vector’s values). Let us look at the parts that are new.

4.4 Index translation

Stan uses 1-based indexing. C++ uses 0-based. Every access into a C++ Eigen vector with a Stan index requires subtracting 1:

x_val[idx[i] - 1]

Getting this wrong produces silent garbage — the function reads from the wrong memory location, returns a plausible-looking number, and the sampler wanders off without complaint. Check your indices carefully.

4.5 Repeated indices and the += rule

Suppose idx = {3, 3, 5}. The parameter x[3] appears twice. During the reverse pass, it must receive the adjoint contributions from both appearances:

x_arena.adj()[idx[0] - 1] += upstream * weights_arena[0];  // x[3]
x_arena.adj()[idx[1] - 1] += upstream * weights_arena[1];  // x[3] again

Both lines write to x_arena.adj()[2] — and the += ensures the contributions accumulate. If you used = instead, the second write would overwrite the first. The gradient would be wrong, the sampler would silently produce biased results, and nothing would warn you.

This is the most common bug in hand-coded gradients with index arrays. Use += for adjoints. Always.
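The difference between += and = is easy to demonstrate in isolation. In this plain C++ sketch, scatter_adjoints is a hypothetical helper that writes per-term contributions back to a full adjoint vector using Stan-style 1-based indices; the contributions are illustrative numbers, not the output of a real model.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scatter per-term contributions into an adjoint vector of length n,
// converting Stan's 1-based indices to C++'s 0-based ones.
std::vector<double> scatter_adjoints(std::size_t n, const std::vector<int>& idx,
                                     const std::vector<double>& contrib,
                                     bool accumulate) {
  std::vector<double> adj(n, 0.0);
  for (std::size_t i = 0; i < idx.size(); ++i) {
    if (accumulate) {
      adj[idx[i] - 1] += contrib[i];  // correct: repeated indices add up
    } else {
      adj[idx[i] - 1] = contrib[i];   // bug: a repeat overwrites the earlier write
    }
  }
  return adj;
}
```

With idx = {3, 3, 5} and contributions {0.25, 0.25, 0.5}, accumulating gives x[3] an adjoint of 0.5, while overwriting leaves only 0.25 and silently drops the first appearance. x[5] gets 0.5 either way, which is why this bug is easy to miss.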

4.6 Storing intermediates

The softmax weights are computed during the forward pass and consumed during the reverse pass. These two events happen at different times — the forward pass runs now, and the reverse pass runs after the entire model has been evaluated.

We solve this by storing the weights in arena memory:

auto weights_arena = stan::math::to_arena(weights);

The lambda captures weights_arena by value (it captures the arena pointer, which is cheap). When the reverse pass eventually calls the lambda, the weights are still there.

This is a general principle: anything the reverse pass needs that is not already stored in a var must be copied to the arena. Intermediate doubles, precomputed gradient factors, scratch buffers — all of it. (The Stan Math source warns explicitly: “All captured values must be trivially destructible or they will leak memory. to_arena() function can be used to ensure that.” The Stan Math common pitfalls guide elaborates on these lifetime issues.)
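The trivially-destructible requirement can be checked at compile time with the standard trait std::is_trivially_destructible_v. In this plain C++ sketch (hypothetical function names), a lambda that captures a std::vector by value fails the check, because the vector owns heap memory that only its destructor would free. A lambda that captures only a pointer and a length passes it.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Captures a std::vector by value: the closure has a non-trivial destructor.
inline auto make_bad_callback(const std::vector<int>& idx) {
  return [idx](double upstream) {
    double s = 0.0;
    for (int i : idx) s += i;
    return upstream * s;
  };
}

// Captures only a pointer and a size: the closure is trivially destructible.
inline auto make_ok_callback(const int* idx_data, std::size_t n) {
  return [idx_data, n](double upstream) {
    double s = 0.0;
    for (std::size_t i = 0; i < n; ++i) s += idx_data[i];
    return upstream * s;
  };
}

// Compile-time checks of the property the Stan Math comment asks for.
static_assert(!std::is_trivially_destructible_v<
              decltype(make_bad_callback(std::vector<int>{}))>);
static_assert(std::is_trivially_destructible_v<
              decltype(make_ok_callback(nullptr, 0))>);
```

Both callbacks compute the same value; only the second is safe to hand to code that, like Stan’s arena, never runs destructors.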

4.7 Defensive checks

Stan models that segfault are painful to debug. A few cheap checks at the top of your function can save hours:

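// Check that indices are within bounds (Stan indices run from 1 to size)
for (int i = 0; i < K; ++i) {
  if (idx[i] < 1 || idx[i] > x_val.size()) {
    throw std::out_of_range("gathered_lse: index " + std::to_string(idx[i])
                            + " is outside [1, "
                            + std::to_string(x_val.size()) + "]");
  }
}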

Stan also provides stan::math::check_* functions (check_matching_sizes, check_positive, etc.) for common validations. These produce informative error messages that include the function name and the offending value.

5 Adding parallelism with reduce_sum

You have a custom C++ function that eliminates the expression graph overhead. The model is faster. But the function runs on a single core, and the computation is embarrassingly parallel — each observation or group contributes independently to the log-posterior.

The instinct might be to reach for Intel TBB directly, writing parallel_reduce bodies with thread-local adjoint buffers. That works, and we will cover it in Section 6. But for most problems, there is a much simpler approach: let Stan handle the threading.

Stan’s reduce_sum function partitions a computation across threads. It handles work distribution, thread management, and — critically — the safe aggregation of adjoints across threads. You do not touch TBB. You do not manage thread-local buffers. You write a single-threaded partial sum function, and reduce_sum does the rest. (The Stan case study on reduce_sum walks through a minimal example.)

The key insight is that reduce_sum and make_callback_var solve different problems. reduce_sum solves parallelism. make_callback_var solves expression graph overhead. You can — and often should — use both.

5.1 The pattern

In your .stan file:

functions {
  real my_fast_ll(vector theta, array[] int obs_idx, ...);

  real partial_sum_lpmf(array[] int slice,
                        int start, int end,
                        vector theta, ...) {
    return my_fast_ll(theta, slice, ...);
  }
}

model {
  target += reduce_sum(partial_sum_lpmf, obs_group_ids,
                       grainsize, theta, ...);
}

The function my_fast_ll is your custom C++ implementation. It uses make_callback_var to bypass the expression graph, exactly as in the previous two examples. Its name deliberately avoids the _lp suffix. Stan reserves that suffix for functions that access target, and such functions cannot be called from inside a reduce_sum partial sum function. The function partial_sum_lpmf is a thin Stan wrapper that reduce_sum calls for each slice of work.

Stan partitions obs_group_ids into slices, sends each slice to a thread, and combines the results. Your C++ function does the heavy arithmetic on each slice without building any graph. The combination of both techniques gives you parallel execution and minimal graph overhead.
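reduce_sum only gives the right answer if the log-density is a sum of independent slice contributions. Evaluating the slices separately and adding them must reproduce the full sum.

The serial sketch below checks that contract for a hypothetical per-observation term, obs_term: a unit-scale normal log density, up to a constant, indexed by a 1-based group id. slice_log_lik plays the role of partial_sum_lpmf, and sliced_sum plays the role of reduce_sum. It omits the threads and the gradient merging that the real reduce_sum performs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical per-observation term: normal log density (up to a constant)
// of y with mean mu and unit scale.
double obs_term(double y, double mu) { return -0.5 * (y - mu) * (y - mu); }

// Analogue of partial_sum_lpmf: contribution of observations [start, end).
double slice_log_lik(const std::vector<double>& y, const std::vector<int>& group,
                     const std::vector<double>& mu, std::size_t start,
                     std::size_t end) {
  double total = 0.0;
  for (std::size_t i = start; i < end; ++i)
    total += obs_term(y[i], mu[group[i] - 1]);  // 1-based group ids, as in Stan
  return total;
}

// Serial stand-in for reduce_sum: cut [0, N) into slices of at most
// `grainsize` elements and add up the slice contributions.
double sliced_sum(const std::vector<double>& y, const std::vector<int>& group,
                  const std::vector<double>& mu, std::size_t grainsize) {
  assert(grainsize > 0);
  double total = 0.0;
  for (std::size_t start = 0; start < y.size(); start += grainsize) {
    const std::size_t end = std::min(y.size(), start + grainsize);
    total += slice_log_lik(y, group, mu, start, end);
  }
  return total;
}
```

Any slicing gives the same total, up to floating-point rounding from the changed summation order. That property is what lets the scheduler choose the slices freely.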

5.2 Why reduce_sum handles the hard parts

When multiple threads modify the same parameter’s adjoint simultaneously, you get a race condition — a class of bug where the result depends on the unpredictable timing of threads. Thread A reads the adjoint, adds its contribution, and writes it back — but Thread B read the same value before A wrote, so B’s write overwrites A’s contribution. The adjoint ends up wrong, and the error changes from run to run.
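The lost update can be replayed deterministically by writing out the interleaving by hand. This minimal sketch simulates the two threads' read-modify-write steps in the unlucky order described above. The contributions of 1.5 and 2.0 are made up. No real threads are involved, so the outcome is reproducible.

```cpp
#include <cassert>

// Simulates an unsynchronized read-modify-write on one adjoint in the order:
// A reads, B reads, A writes, B writes.
double lost_update_adjoint(double initial, double contrib_a, double contrib_b) {
  double adj = initial;
  const double a_read = adj;  // Thread A reads the adjoint
  const double b_read = adj;  // Thread B reads it before A writes
  adj = a_read + contrib_a;   // Thread A writes its sum
  adj = b_read + contrib_b;   // Thread B overwrites A's result
  return adj;
}

// The same two contributions applied one after the other.
double serial_adjoint(double initial, double contrib_a, double contrib_b) {
  double adj = initial;
  adj += contrib_a;
  adj += contrib_b;
  return adj;
}
```

Starting from 0, the interleaved version ends at 2.0 while the serial version ends at 3.5: Thread A's 1.5 simply disappears. With real threads, which ordering occurs varies from run to run.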

reduce_sum avoids this by evaluating each slice in its own nested autodiff scope, on local copies of the shared parameters, so a thread only ever writes to adjoints it owns. When all threads finish, it adds up the per-slice gradients. You do not need to think about any of this; it is handled internally.

If you wrote TBB code yourself, you would need to implement this merging manually. For most models, there is no reason to.

5.3 Grainsize tuning

The grainsize parameter controls how reduce_sum partitions the work. It is a suggested slice size: the scheduler splits the work into pieces of roughly grainsize elements and avoids splitting much below that. The tradeoffs:

  • Too small: thread overhead dominates. Each slice does little work, but the cost of dispatching and merging is nontrivial.
  • Too large: threads sit idle. If you have 8 cores and a grainsize that creates only 3 slices, 5 cores do nothing.

A good starting point: divide the number of observations (or whatever natural unit you are slicing) by the number of threads per chain. If you have 10,000 observations and 4 threads, start with grainsize = 2500. The reasoning: if the workload is roughly balanced across observations, this gives each thread exactly one slice, which minimizes the overhead of dispatching and combining results.

Then halve the grainsize and run again. Halve it once more. Compare the wall-clock times. If the observations are well balanced — roughly equal work per observation — the first value (one slice per thread) is usually fastest. If the workload is unbalanced — some observations are much more expensive than others — smaller grainsizes let the scheduler even things out by giving threads that finish early another slice to work on.
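The sketch below shows how the starting value and its halvings translate into slices, using recursive binary splitting. It assumes the splitting rule of TBB's blocked_range, which only splits a piece that is larger than grainsize. Stan's scheduler may stop splitting earlier, so treat the counts as an idealized upper bound rather than exactly what reduce_sum will do. Rounding the starting grainsize up is an added assumption so that the slices cover every observation.

```cpp
#include <cassert>
#include <cstddef>

// Number of leaf slices produced by halving a range of n elements,
// recursively, while a piece is larger than grainsize (blocked_range's
// divisibility rule). A real scheduler may split less.
std::size_t count_slices(std::size_t n, std::size_t grainsize) {
  assert(grainsize > 0);
  if (n <= grainsize) return 1;
  const std::size_t left = n / 2;
  return count_slices(left, grainsize) + count_slices(n - left, grainsize);
}

// Starting grainsize from the text: observations / threads per chain,
// rounded up.
std::size_t starting_grainsize(std::size_t n_obs, std::size_t threads) {
  assert(threads > 0);
  return (n_obs + threads - 1) / threads;
}
```

Here are the numbers for the example in the text: 10,000 observations and 4 threads. The starting grainsize of 2500 yields 4 slices, one per thread. The first halving (1250) yields 8 slices and the second (625) yields 16, giving the scheduler spare slices to hand to threads that finish early.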

Setting grainsize = 1 tells Stan to tune automatically at runtime. This is a reasonable fallback if you do not want to experiment, but in practice the manual approach above often wins, especially once the expression graph overhead is gone and the per-observation cost is predictable.

Tip

To benchmark grainsize, fix the random seed (seed) and the initial values (init), and run a fixed number of iterations (e.g., iter_warmup = 200, iter_sampling = 200). This gives you comparable wall-clock times across runs without noise from different trajectories.

5.4 Compiling with threading

CmdStanR:

library(cmdstanr)

model <- cmdstan_model(
  'model.stan',
  cpp_options = list(stan_threads = TRUE),
  stanc_options = list("allow-undefined" = TRUE),
  user_header = normalizePath('user_header.hpp')
)

fit <- model$sample(
  data = stan_data,
  threads_per_chain = 4
)

CmdStanPy:

from cmdstanpy import CmdStanModel

model = CmdStanModel(
  stan_file='model.stan',
  cpp_options={'stan_threads': True},
  stanc_options={'allow-undefined': True},
  user_header='user_header.hpp'
)

fit = model.sample(
  data=stan_data,
  threads_per_chain=4
)

CmdStan (make/local):

STANCFLAGS += --allow-undefined
USER_HEADER = /full/path/to/user_header.hpp
STAN_THREADS = true

Then set the environment variable before running:

export STAN_NUM_THREADS=4
./model sample data file=data.json

6 Raw TBB parallelism

In rare cases, reduce_sum is not the right tool. Perhaps the computation does not decompose into independent additive terms. Perhaps you need fine-grained control over how the forward and reverse passes are parallelized separately. Perhaps the partial sums share expensive intermediate computations that you want to cache.

In these cases, you can use Intel TBB directly from within your make_callback_var callback. This is harder to get right and harder to maintain. Try reduce_sum first. Come here only if you have a specific reason. (For more discussion of advanced patterns, including vector-valued returns, see this Discourse thread.)

6.1 The problem with shared adjoints

Consider a parallel reverse pass where 8 threads each need to add their gradient contributions to the same parameter vector’s adjoints. Without synchronization, threads overwrite each other’s work — the race condition described in the reduce_sum section. The standard solution is to give each thread its own adjoint buffer, then merge them after all threads finish.
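The per-thread buffer idea does not depend on TBB. This minimal sketch uses only std::thread and a hypothetical gather-scatter reverse pass. In it, observation i adds upstream * weights[i] to the parameter at 1-based index idx[i]. Each thread scatters into its own zeroed buffer, and the buffers are summed after every thread has joined. The fixed contiguous chunking is a simplification; TBB's parallel_reduce automates the splitting and merging.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Reverse-pass scatter with thread-local adjoint buffers.
// Observation i contributes upstream * weights[i] to parameter idx[i] - 1.
std::vector<double> scatter_with_local_buffers(
    const std::vector<double>& weights, const std::vector<int>& idx,
    double upstream, std::size_t n_params, unsigned n_threads) {
  assert(weights.size() == idx.size() && n_threads > 0);
  const std::size_t n = idx.size();
  // One zeroed buffer per thread: no two threads write to the same memory.
  std::vector<std::vector<double>> local(n_threads,
                                         std::vector<double>(n_params, 0.0));
  const std::size_t chunk = (n + n_threads - 1) / n_threads;
  std::vector<std::thread> workers;
  for (unsigned t = 0; t < n_threads; ++t) {
    workers.emplace_back([&, t] {
      const std::size_t begin = std::min(n, t * chunk);
      const std::size_t end = std::min(n, begin + chunk);
      for (std::size_t i = begin; i < end; ++i)
        local[t][idx[i] - 1] += upstream * weights[i];
    });
  }
  for (auto& w : workers) w.join();

  // Single-threaded merge of the per-thread buffers.
  std::vector<double> adj(n_params, 0.0);
  for (const auto& buf : local)
    for (std::size_t p = 0; p < n_params; ++p) adj[p] += buf[p];
  return adj;
}
```

The result does not depend on how many threads are used, apart from floating-point rounding in the final merge.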

6.2 The parallel_reduce pattern

TBB’s parallel_reduce is designed for exactly this pattern. You define a “body” class with three capabilities: process a range of work, accumulate results into local storage, and merge two bodies together.

The skeleton looks like this:

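#include <tbb/blocked_range.h>
#include <tbb/parallel_reduce.h>
#include <Eigen/Dense>
#include <vector>

struct ReverseBody {
  // Shared (read-only) data. Pass a real Eigen::VectorXd, not an arena map
  // or expression: those would bind the reference to a temporary copy.
  const Eigen::VectorXd& weights;
  const std::vector<int>& idx;  // 1-based parameter indices, as in Stan
  double upstream;              // adjoint of the callback's return value

  // Thread-local adjoint buffer
  Eigen::VectorXd local_adj;

  ReverseBody(const Eigen::VectorXd& weights, const std::vector<int>& idx,
              double upstream, int n_params)
    : weights(weights), idx(idx), upstream(upstream),
      local_adj(Eigen::VectorXd::Zero(n_params)) {}

  // Splitting constructor: TBB calls this to create a new body for a subrange
  ReverseBody(ReverseBody& other, tbb::split)
    : weights(other.weights), idx(other.idx), upstream(other.upstream),
      local_adj(Eigen::VectorXd::Zero(other.local_adj.size())) {}

  // Process a range of work
  void operator()(const tbb::blocked_range<int>& range) {
    for (int i = range.begin(); i < range.end(); ++i) {
      local_adj[idx[i] - 1] += upstream * weights[i];
    }
  }

  // Merge another body's results into this one
  void join(const ReverseBody& other) {
    local_adj += other.local_adj;
  }
};

// Inside the reverse-pass callback (sketch):
//   ReverseBody body(weights, idx, ret.adj(), n_params);
//   tbb::parallel_reduce(tbb::blocked_range<int>(0, n_obs), body);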

Each thread gets its own ReverseBody with a zeroed-out local_adj buffer. It writes only to its own buffer — no conflicts. After all threads finish, TBB calls join() to merge the buffers pairwise until a single body holds all the contributions. You then add those into the var inputs’ actual adjoints:

x_arena.adj() += body.local_adj;

This final step is single-threaded, which is fine — it is a single vector addition.

6.3 Cache lifetime with shared_ptr

The forward and reverse passes happen at different times. If the forward pass computes expensive intermediates (softmax weights, gradient factors, scratch arrays), the reverse pass callback needs access to them.

Arena memory works for Eigen vectors and scalars. For more complex structures — or when you want to share a cache between the forward body and the reverse body — std::shared_ptr is useful:

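auto cache = std::make_shared<ForwardCache>();
// ... populate cache during forward pass ...

return stan::math::make_callback_var(
  result,
  [x_arena, cache](auto& ret) mutable {
    // cache is still alive: the captured shared_ptr holds a reference
    // ... reverse pass using cache->weights, etc. ...
    cache.reset();  // release explicitly; the arena never destroys this lambda
  }
);

The shared_ptr is a reference-counted pointer: it keeps the underlying data alive as long as at least one reference exists. The lambda captures the pointer by value (cheap: it copies the pointer, not the data), so the cache survives until the reverse pass.

What happens afterwards needs care. Stan's arena frees memory without running destructors, so the lambda, and the shared_ptr inside it, is never destroyed. That is the leak the Stan Math source warns about in the quote above. Releasing the pointer at the end of the callback drops the reference count to zero and frees the cache, assuming the callback runs once per gradient evaluation. The lambda must be declared mutable so it can modify its captured copy.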
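A shared_ptr captured in a callback frees its data only if the lambda's destructor runs or the callback releases the pointer itself. The sketch below imitates an arena by constructing the lambda with placement new in a raw buffer and never destroying it during the simulated pass. This is a simplification, not Stan's allocator. A std::weak_ptr observes whether the cache is still alive after the simulated reverse pass. ForwardCache and cache_freed_after_reverse_pass are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <new>
#include <utility>
#include <vector>

// Hypothetical cache of forward-pass intermediates.
struct ForwardCache {
  std::vector<double> weights;
};

// Returns true if the cache has been freed once the simulated reverse pass
// has run, while the callback object still sits undestroyed in the "arena".
bool cache_freed_after_reverse_pass(bool reset_in_callback) {
  alignas(std::max_align_t) unsigned char arena[64];  // stand-in arena buffer

  auto cache = std::make_shared<ForwardCache>();
  cache->weights = {0.2, 0.3, 0.5};
  std::weak_ptr<ForwardCache> watcher = cache;  // observes the cache lifetime

  // Capture the shared_ptr by value, as in the make_callback_var example.
  auto callback = [cache, reset_in_callback]() mutable {
    double sum = 0.0;
    for (double w : cache->weights) sum += w;  // "reverse pass" reads cache
    if (reset_in_callback) cache.reset();      // explicit release
    return sum;
  };
  using Callback = decltype(callback);
  static_assert(sizeof(Callback) <= sizeof(arena), "arena buffer too small");
  static_assert(alignof(Callback) <= alignof(std::max_align_t), "misaligned");
  Callback* stored = new (arena) Callback(std::move(callback));

  cache.reset();  // the forward pass drops its own handle
  (*stored)();    // the reverse pass runs once

  // An arena would now recycle the buffer without calling ~Callback().
  const bool freed = watcher.expired();

  stored->~Callback();  // demo cleanup only; an arena never does this
  return freed;
}
```

Without the explicit release the function returns false: the cache outlives the reverse pass, and in a real arena it would never be freed. With the release it returns true.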

6.4 When to use raw TBB

Use raw TBB only when:

  • The computation cannot be expressed as a sum of independent terms (so reduce_sum does not apply).
  • You need to parallelize the reverse pass differently from the forward pass.
  • You have verified, with profiling, that reduce_sum leaves performance on the table for your specific case.

For the vast majority of models, the combination of make_callback_var (for graph elimination) and reduce_sum (for parallelism) is sufficient.

7 A real-world case study

The techniques in this tutorial were developed for a model of global football — estimating player abilities across more than 85 leagues and a decade of seasons.

7.1 The model

The model is a large Bayesian hierarchical model with several likelihoods, each contributing to the joint posterior. The data includes match-level and player-level observations across a decade of competition. The two most expensive likelihoods — and the ones we optimized — both follow the gather-scatter pattern that makes this technique effective.

  1. Match segments: Nearly 1.5 million segments, where a segment is a stretch of play between match events. Each segment gathers multi-correlated ability parameters for every player on the pitch — all 22 players across both teams — plus league-level and match-level covariates, and computes a segment-level log-likelihood. The Stan implementation involved two stages: building the segment data structures (gathering parameters by index, computing team-level quantities) and evaluating the log-probability. Both stages create expression graph nodes — hundreds per segment, hundreds of millions in total.

  2. Manager decisions: An observation at every match start and every intra-match decision, modeling the sequential decisions a manager makes — think player selection or usage — as a function of player abilities and match context. Each decision gathers parameters from the available pool and computes a probability. The same gather-scatter structure: loop over decisions, gather indexed parameters, evaluate, accumulate.

The model includes several other likelihoods, but these two dominated the wall time. The Stan code was already well optimized — both likelihoods were parallelized with reduce_sum, the code was vectorized where possible, and unnecessary allocations had been eliminated. The model was, in other words, already at the point this tutorial begins: Stan’s standard optimizations had been applied, and the bottleneck was the expression graph itself. With 1.5 million segments each creating hundreds of nodes, the graph grew to hundreds of millions of nodes per gradient evaluation.

Running this already-optimized model in the cloud on Intel Granite Rapids hardware — 72 physical cores per machine, one machine per chain, four chains — the wall time could exceed four days. The model was under active development (priors, covariates, model structure were all still evolving), but these two likelihoods had stabilized. Waiting four days for feedback on a structural change to other model components slowed development.

7.2 The optimization — in stages

The optimization was done incrementally, one likelihood at a time, with profiling at each stage. This is the approach I recommend: do not rewrite everything at once. Start with the most expensive component, verify it, then move to the next.

Stage 1: Match segments. Both stages of the Stan implementation — gathering parameters and evaluating the log-probability — were replaced with a single fused C++ function using make_callback_var with TBB parallelism. The forward pass gathers all player and match-level parameters on plain double values, computes the segment log-likelihood, and stores intermediates in a per-segment cache. The reverse pass computes analytic gradients and scatters them back to the parameter adjoints via thread-local buffers.

The pattern, in pseudocode:

Forward (per segment s):
    gather: player ability parameters, league/match covariates
    compute: segment log-likelihood on doubles
    cache: intermediates for reverse pass

Reverse (per segment s):
    compute: analytic gradient w.r.t. each gathered parameter
    scatter: adj_theta[player_idx[s]] += gradient * upstream adjoint
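Here is a runnable miniature of this pattern, under loudly hypothetical assumptions. This is not the football model's likelihood; it only shares the same gather, cache, and scatter shape.

Each segment \(s\) is a Bernoulli-logit outcome \(y_s\). Its linear predictor is the sum of the gathered abilities, \(\eta_s = \sum_k \theta_{\text{idx}[s][k]}\). The derivative of the log-likelihood with respect to \(\eta_s\) is \(y_s - \text{logit}^{-1}(\eta_s)\). So the forward pass caches that residual, and the reverse pass scatters it to every gathered index.

Plain std::vector stands in for Stan's var and arena types, and the upstream adjoint is fixed at 1.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct LogLikAndGrad {
  double log_lik;
  std::vector<double> grad;  // d log_lik / d theta
};

// Hypothetical gather-scatter likelihood: segment s gathers theta at the
// 1-based indices idx[s], sets eta_s to their sum, and contributes the
// Bernoulli-logit log-likelihood y_s * eta_s - log(1 + exp(eta_s)).
LogLikAndGrad gather_scatter_bernoulli(const std::vector<double>& theta,
                                       const std::vector<std::vector<int>>& idx,
                                       const std::vector<int>& y) {
  assert(idx.size() == y.size());
  const std::size_t S = y.size();
  std::vector<double> resid(S);  // cache for the reverse pass: y_s - p_s
  double log_lik = 0.0;

  // Forward pass on doubles: gather, compute, cache.
  for (std::size_t s = 0; s < S; ++s) {
    double eta = 0.0;
    for (int j : idx[s]) eta += theta[j - 1];
    // Numerically stable log(1 + exp(eta)).
    const double log1p_exp =
        eta > 0 ? eta + std::log1p(std::exp(-eta)) : std::log1p(std::exp(eta));
    log_lik += y[s] * eta - log1p_exp;
    resid[s] = y[s] - 1.0 / (1.0 + std::exp(-eta));
  }

  // Reverse pass: scatter upstream * (y_s - p_s) to every gathered index.
  const double upstream = 1.0;  // adjoint of log_lik when it is the output
  std::vector<double> grad(theta.size(), 0.0);
  for (std::size_t s = 0; s < S; ++s)
    for (int j : idx[s]) grad[j - 1] += upstream * resid[s];

  return {log_lik, grad};
}
```

A parameter gathered by several segments, or twice within one segment, simply receives several scattered contributions. That accumulation is exactly what the adjoint needs.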

Stage 2: Manager decisions. The same pattern was applied — gather player parameters for each decision point, compute the log-probability on doubles, scatter analytic gradients back. A second .hpp file was created, and a main custom_functions.hpp header was used to #include both implementations.

7.3 Profiling results

All three versions (pure Stan, custom C++ for match segments only, and custom C++ for both likelihoods) were profiled on the same data (85 leagues) with the same settings. Each run used 24 threads on a single chain, 5 warmup + 5 sampling iterations, and fixed seeds for both Stan and the initial values. Table 1 compares the pure Stan version with the version that uses custom C++ for both likelihoods. The proportion of time each likelihood consumes is more informative than raw seconds here, since the original model's four-day wall time already makes the absolute scale clear.

Table 1: Profiled time, 10 iterations, 24 threads, single chain.

Likelihood               Pure Stan          Custom (both)      Speedup
                         Time     Share     Time     Share
Match segments           19.9 s   64.4%     1.9 s    34.7%     10.6x
Manager decisions         8.4 s   27.2%     0.9 s    16.9%      9.2x
Other likelihoods (5)     2.5 s    8.4%     2.5 s    47.4%      1.0x
Total profiled           30.8 s             5.4 s               5.7x

Note

Why 5+5 iterations is enough for profiling

Ten iterations (5 warmup + 5 sampling) is far too few for inference, but it is enough to measure the relative cost of model components. Each gradient evaluation exercises every likelihood, and the proportion of time each one consumes stabilizes quickly — it does not depend on how far along the sampler is in adaptation.

You cannot, however, multiply these seconds by the number of iterations in a production run to estimate total wall time. During warmup, NUTS adapts the step size and mass matrix, which changes the number of leapfrog steps per iteration — often dramatically. Early warmup iterations may build small trajectory trees (cheap), while post-adaptation iterations build larger trees (expensive). The cost per iteration is not constant across phases. What is roughly constant is the proportion each component contributes to a single gradient evaluation, which is why the proportions in Table 1 are informative even from a short run.

The proportion shift tells the story. In pure Stan, the two gather-scatter likelihoods consumed 91.6% of profiled time — match segments alone took nearly two-thirds. The custom C++ eliminated most of that overhead: the two likelihoods dropped from 28.3 seconds to 2.8 seconds combined, a 10x reduction.

The five remaining likelihoods were unchanged. They were never dominated by graph overhead, so hand-coding their gradients would have removed little. But with the dominant likelihoods out of the way, these untouched components now account for nearly half the profiled time. That is not because they got slower, but because everything around them got faster.

The practical effect: a model that previously required over four days per run became feasible for iterative development (or production pipeline automation).

7.4 What made these likelihoods good candidates

Both likelihoods shared the same structural properties:

  • Repetitive gather-scatter pattern: Each observation loops over a set of indexed parameters. The per-observation graph is identical in structure — only the indices and values change. This makes the analytic gradient straightforward to derive once and apply across all observations.
  • Massive observation count: 1.5 million match segments and hundreds of thousands of manager decisions. The per-node overhead, small in isolation, multiplied across this many observations to dominate the computation.
  • Moderate per-observation complexity: Each observation gathers a set of indexed parameters and performs a series of standard operations (e.g., exp, log, weighted sums). The analytic gradient is a page of algebra — work, but not a research problem.
  • Stable likelihood code: Both likelihoods had stabilized in development. The rest of the model was still evolving. This meant the investment in hand-coded C++ would not be wasted on code that was about to change.

Other likelihoods, by contrast, were expensive for different reasons — they involve expensive special functions and fewer but more complex per-observation computations. Graph overhead was not their bottleneck, so custom C++ would not have helped as much.

8 Verifying your gradients

A custom gradient that is wrong in a subtle way is worse than no custom gradient at all. The sampler will run. It will produce numbers. Those numbers will be wrong in ways that are not obvious from diagnostics. You must verify.

8.1 Method 1: Finite-difference checking

The most direct test. The idea is simple: instead of computing the gradient analytically, approximate it by actually nudging each parameter and measuring how the output changes — the same “nudge and measure” reasoning we used to build intuition earlier, but now used as a verification tool.

For each parameter \(\theta_j\), evaluate the function at two nearby points — one slightly above, one slightly below — and compute the slope between them:

\[ \frac{\partial f}{\partial \theta_j} \approx \frac{f(\theta_j + h) - f(\theta_j - h)}{2h} \]

This is called a central difference approximation. The \(2h\) in the denominator is because we moved a total distance of \(2h\) (from \(-h\) to \(+h\)). A good choice for \(h\) is around \(10^{-7}\) — small enough to approximate the true rate of change, large enough to avoid floating-point rounding issues.
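The reason for that balance shows up in a Taylor expansion of \(f\) around \(\theta_j\), writing \(f', f'', f'''\) for derivatives with respect to \(\theta_j\). The even-order terms cancel, so the truncation error shrinks like \(h^2\):

\[
\begin{aligned}
f(\theta_j + h) &= f + h f' + \tfrac{h^2}{2} f'' + \tfrac{h^3}{6} f''' + O(h^4) \\
f(\theta_j - h) &= f - h f' + \tfrac{h^2}{2} f'' - \tfrac{h^3}{6} f''' + O(h^4) \\
\frac{f(\theta_j + h) - f(\theta_j - h)}{2h} &= f' + \tfrac{h^2}{6} f''' + O(h^4)
\end{aligned}
\]

Rounding error behaves the other way. Each function value carries an error of roughly \(\epsilon |f|\), with \(\epsilon \approx 2.2 \times 10^{-16}\) for doubles. The quotient therefore picks up an error of order \(\epsilon |f| / h\), which grows as \(h\) shrinks.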

A concrete example: suppose your function returns \(f = 3.7\) at the current parameters. You nudge \(\theta_1\) up by \(h = 10^{-7}\) and get \(f = 3.7000002\), then nudge it down by \(h\) and get \(f = 3.6999998\). The approximate gradient is \((3.7000002 - 3.6999998) / (2 \times 10^{-7}) = 2.0\). If your analytic gradient says the adjoint of \(\theta_1\) should be \(2.0\), you are in good shape. If it says \(1.5\), you have a bug.

Compare the finite-difference approximation against your analytic gradient. They should agree to several decimal places. This test requires no full model run and catches most bugs. Run it on a few random parameter vectors to build confidence.
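The check is short to write by hand. This minimal sketch uses plain C++ with no Stan Math and assumes the function under test takes a std::vector<double> of parameters. central_diff_grad nudges each coordinate by ±h. log_sum_exp, whose analytic gradient is the softmax, stands in for your own function and its hand-derived gradient.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Central-difference gradient: (f(theta + h e_j) - f(theta - h e_j)) / (2h).
template <typename F>
std::vector<double> central_diff_grad(F&& f, std::vector<double> theta,
                                      double h = 1e-7) {
  assert(h > 0.0);
  std::vector<double> grad(theta.size());
  for (std::size_t j = 0; j < theta.size(); ++j) {
    const double original = theta[j];
    theta[j] = original + h;
    const double f_plus = f(theta);
    theta[j] = original - h;
    const double f_minus = f(theta);
    theta[j] = original;  // restore before moving to the next coordinate
    grad[j] = (f_plus - f_minus) / (2.0 * h);
  }
  return grad;
}

// Largest absolute difference between two gradients.
double max_abs_diff(const std::vector<double>& a, const std::vector<double>& b) {
  assert(a.size() == b.size());
  double worst = 0.0;
  for (std::size_t j = 0; j < a.size(); ++j)
    worst = std::max(worst, std::fabs(a[j] - b[j]));
  return worst;
}

// Example function: log-sum-exp, computed stably.
double log_sum_exp(const std::vector<double>& x) {
  const double m = *std::max_element(x.begin(), x.end());
  double s = 0.0;
  for (double v : x) s += std::exp(v - m);
  return m + std::log(s);
}

// Its analytic gradient: the softmax of x.
std::vector<double> log_sum_exp_grad(const std::vector<double>& x) {
  const double lse = log_sum_exp(x);
  std::vector<double> g(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) g[i] = std::exp(x[i] - lse);
  return g;
}
```

To use it, compare max_abs_diff(central_diff_grad(f, x), your_grad(x)) against a tolerance such as 1e-6; your_grad is your hand-derived gradient. A deliberately corrupted gradient (e.g., one entry scaled by 1.5) fails by orders of magnitude. For gradients with large magnitudes, a relative tolerance is the better choice.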

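Stan Math's own unit tests use stan::test::expect_ad (in test/unit/math/test_ad.hpp), which checks a function's autodiff gradients against finite-difference approximations. The relative-tolerance comparison stan::test::expect_near_rel (in test/unit/math/expect_near_rel.hpp) is useful when writing such checks yourself.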

8.2 Method 2: Trajectory matching

This is the strongest end-to-end check. Suppose two implementations of the same function produce identical values and identical gradients. NUTS is deterministic given its random seed and the value and gradient functions, so it will make identical sampling decisions for both. The lp__ trace will then match exactly across iterations.

  1. Write two Stan models: one using native Stan code, one using your custom C++ function. They must implement the same mathematical function.
  2. Run both with identical seed, initial values, and data.
  3. Compare the lp__ column from the output.

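If the traces match exactly (to full floating-point precision, not approximately), that is strong evidence the gradient is correct. A mismatch needs a closer look. A wrong gradient produces one, but so can harmless rounding. A hand-coded gradient may add terms in a different order than Stan's autodiff, and threaded reductions such as reduce_sum can change that order from run to run. Use Methods 1 and 3 to tell the two apart.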

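This works because every discrete decision NUTS makes depends only on draws from the seeded random number generator and on the value and gradient. That includes which direction to extend the trajectory, which state to select, and when to stop. If both are identical, every decision is identical, and the entire trajectory is reproduced.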
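Comparing the lp__ columns amounts to finding the first draw at which the two traces differ. This small sketch assumes both columns have already been read into vectors; CSV parsing is omitted. It reports the first mismatching draw, if any, and the largest absolute difference, which tells you where and by how much the trajectories part ways.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <optional>
#include <vector>

struct TraceComparison {
  std::optional<std::size_t> first_mismatch;  // empty if traces are identical
  double max_abs_diff = 0.0;
};

// Compares two lp__ traces draw by draw, using exact equality.
TraceComparison compare_traces(const std::vector<double>& a,
                               const std::vector<double>& b) {
  assert(a.size() == b.size());
  TraceComparison out;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && !out.first_mismatch) out.first_mismatch = i;
    out.max_abs_diff = std::max(out.max_abs_diff, std::fabs(a[i] - b[i]));
  }
  return out;
}
```

An empty first_mismatch means the traces agree bit for bit.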

Note

Trajectory matching relies on NUTS being deterministic given value and gradient. This holds for Stan’s NUTS implementation. It does not necessarily hold for all samplers or for future sampler implementations.

8.3 Method 3: log_prob and grad_log_prob

CmdStanR and CmdStanPy expose methods to evaluate the log-density and its gradient at a specific point in parameter space. This lets you compare the two implementations directly, without running the sampler:

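fit_native <- native_model$sample(...)
fit_custom <- custom_model$sample(...)

# log_prob() and grad_log_prob() are methods of the fitted object
# and need the compiled model methods
fit_native$init_model_methods()
fit_custom$init_model_methods()

# Both methods take a point on the unconstrained scale. Build one from
# constrained values given as a named list (same format as `init`);
# theta_value is a placeholder for values you choose
theta_u <- fit_native$unconstrain_variables(list(theta = theta_value))

lp_native <- fit_native$log_prob(theta_u)
lp_custom <- fit_custom$log_prob(theta_u)

grad_native <- fit_native$grad_log_prob(theta_u)
grad_custom <- fit_custom$grad_log_prob(theta_u)

# These should be identical (or very nearly so)
max(abs(grad_native - grad_custom))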
# After fitting, use log_prob / grad_log_prob methods
# to compare at specific parameter values

This complements finite-difference checking by operating at the full-model level rather than the function level.

9 Summary

The recipe for a custom C++ gradient in Stan:

  1. Identify the bottleneck. Profile your model. Find the function where the expression graph overhead dominates.

  2. Derive the gradient analytically. Work out the partial derivative of your function’s output with respect to each input — on paper, before you write any code.

  3. Write the C++ function following the five-step pattern: signature, data shortcut, forward pass on doubles, arena copies, callback with make_callback_var.

  4. Validate. Finite-difference check the gradient. Then run trajectory matching against the native Stan implementation.

  5. Add parallelism if needed. Use reduce_sum in your Stan model with the custom C++ function handling each slice. Resort to raw TBB only if reduce_sum does not fit.

The technique may seem narrow in scope — it applies to specific, expensive functions, not to entire models. But when it applies, the performance gains can be substantial. A 4x speedup means overnight runs finish before lunch. It means you can iterate on model structure instead of waiting for results. It means the hardware you already have goes further.

The code is more work to write, more work to verify, and more work to maintain than native Stan. That is the tradeoff. Make it for the functions that earn it.

10 Appendix: API reference

A quick reference for the Stan Math functions used in this tutorial.

Table 2: Stan Math API functions for custom gradients

Function                                Purpose
make_callback_var(double val, F&& cb)   Create a var with value val and custom reverse-pass callback cb
value_of(const T& x)                    Extract the double value from a var or Eigen var container
to_arena(const T& x)                    Copy x to arena memory so it survives until the reverse pass
forward_as<var>(x)                      Treat x as a var inside a code branch that only runs when x is a var (throws at runtime if the assumption is wrong)
is_var_v<T>                             Compile-time constant: true if T is var, false if double
scalar_type_t<T>                        Type trait: extracts the scalar type (var or double) from an Eigen type
check_matching_sizes(...)               Runtime check that two containers have equal size

References and further reading

Official documentation

Stan Math source code

Stan Discourse