How to Scan RNN Layers & Slices in JAX: 2025 Tutorial
Master `lax.scan` in JAX for building efficient Recurrent Neural Networks (RNNs). This 2025 tutorial covers scanning layers, handling slices, and a Flax example.
Dr. Alistair Finch
A computational scientist specializing in high-performance machine learning frameworks like JAX and PyTorch.
Introduction: Why JAX and RNNs Need a Special Approach
Welcome to 2025, where JAX has solidified its place as a cornerstone of high-performance machine learning research and production. Its power lies in its functional programming paradigm, enabling features like Just-In-Time (JIT) compilation, automatic differentiation, and vectorization (`jit`, `grad`, `vmap`). However, this functional, stateless approach presents a unique challenge for inherently stateful architectures like Recurrent Neural Networks (RNNs).
Unlike a simple feed-forward network, an RNN maintains a hidden state that evolves over a sequence. In traditional frameworks, this is often handled with a `for` loop and in-place state updates. In JAX, arrays are immutable and Python loops are simply unrolled at trace time, so they cannot be JIT-compiled effectively. So how do we process sequences efficiently? The answer is `jax.lax.scan`.
This tutorial will guide you through mastering `lax.scan` to build efficient, compilable, and powerful RNNs in JAX. We'll cover everything from the basic cell to scanning entire layers and handling input slices, culminating in a practical example using the popular Flax library.
What is `lax.scan` and Why Do We Need It?
At its core, `lax.scan` is JAX's canonical replacement for loops that carry state from one iteration to the next. It's a functional primitive that expresses a sequential computation in a way that JAX's compiler can understand, optimize, and accelerate.
The Problem with Standard Loops in JAX
Consider a standard Python `for` loop. When JAX's JIT compiler traces it, it doesn't see a high-level "loop" operation. Instead, it unrolls the loop, duplicating the loop body's operations for every iteration. For short loops, this is fine. For an RNN processing a sequence of 1000 tokens, it is a disaster, leading to massive compilation times and memory usage.
`lax.scan` solves this by defining the entire looping computation as a single, optimizable operation.
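To see the difference concretely, here is a minimal sketch (the function names `cumsum_loop` and `cumsum_scan` are just illustrative) that compares the jaxprs JAX traces for a plain Python loop and for the same computation written with `lax.scan`:

import jax
import jax.numpy as jnp

def cumsum_loop(xs):
    # A plain Python loop: tracing unrolls it into one indexed add per element.
    total = jnp.zeros(())
    for i in range(xs.shape[0]):
        total = total + xs[i]
    return total

def cumsum_scan(xs):
    # The same computation expressed as a single scan primitive.
    def step(carry, x):
        return carry + x, None
    total, _ = jax.lax.scan(step, jnp.zeros(()), xs)
    return total

xs = jnp.arange(5.0)
print(jax.make_jaxpr(cumsum_loop)(xs))  # many ops: one slice + add per element
print(jax.make_jaxpr(cumsum_scan)(xs))  # a single `scan` primitive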
The `lax.scan` Signature Explained
The function signature looks like this: `jax.lax.scan(f, init, xs)`. Let's break it down in the context of an RNN:
- `f` (the scan function): This is the workhorse. It's a function that defines the logic for a single step of the loop. For an RNN, this is the cell update function. It must have the signature `(carry, x) -> (new_carry, y)`.
- `init` (the initial carry): This is the initial state that gets passed to the first iteration of `f`. For an RNN, this is the initial hidden state, `h_0`.
- `xs` (the inputs): This is a JAX array containing the sequence of inputs to be iterated over. For an RNN, this is your input sequence, typically with shape `(time_steps, features)`. `lax.scan` automatically slices this array along the first axis for each step.
The function returns a tuple `(final_carry, ys)`, where `final_carry` is the final hidden state after the last time step, and `ys` is a stacked array of all the per-step outputs `y`.
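As a quick, self-contained illustration of this contract (a hedged sketch, not an RNN yet), the following scan computes an exponential moving average: the carry is the running average, and every per-step value is collected into `ys`:

import jax
import jax.numpy as jnp

def ema_step(carry, x):
    new_carry = 0.9 * carry + 0.1 * x   # state threaded to the next step
    y = new_carry                       # per-step output, collected into `ys`
    return new_carry, y

xs = jnp.arange(4.0)
final_carry, ys = jax.lax.scan(ema_step, jnp.zeros(()), xs)
print(final_carry)   # the state after the last step
print(ys.shape)      # (4,) -- one output per step, stacked along axis 0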
Anatomy of a JAX RNN Cell
To use `lax.scan`, we first need to define our single-step function `f`.
The RNN Update Rule: A Quick Refresher
A simple RNN cell updates its hidden state `h` at each time step `t` based on the previous hidden state `h_{t-1}` and the current input `x_t`. The formula is:
h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b)
Where `W_hh` and `W_xh` are weight matrices and `b` is a bias vector. The output at this step is often just the new hidden state, so `y_t = h_t`.
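To make the recurrence concrete, here is a tiny hand-rolled version of one update step with made-up 2-dimensional values (the numbers are purely illustrative):

import jax.numpy as jnp

W_hh = jnp.array([[0.5, 0.0], [0.0, 0.5]])   # hidden-to-hidden weights
W_xh = jnp.array([[1.0, 0.0], [0.0, 1.0]])   # input-to-hidden weights
b = jnp.zeros(2)
h_prev = jnp.array([0.1, -0.2])              # h_{t-1}
x_t = jnp.array([0.3, 0.4])                  # current input x_t

h_t = jnp.tanh(W_hh @ h_prev + W_xh @ x_t + b)
print(h_t)   # the new hidden state, which also serves as the output y_t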
Implementing the Cell Function in JAX
Let's translate this into a JAX function that matches the required `(carry, x) -> (new_carry, y)` signature. Here, `carry` is the hidden state `h_{t-1}` and `x` is the input `x_t`.
import jax
import jax.numpy as jnp

def rnn_cell_step(params, carry, x):
    """A single step of the RNN cell."""
    h_prev = carry
    h_new = jnp.tanh(
        jnp.dot(params['W_hh'], h_prev) +
        jnp.dot(params['W_xh'], x) +
        params['b']
    )
    # For a simple RNN, the output is the new hidden state
    return h_new, h_new
Notice how this function is pure: it takes state (`carry`) and parameters (`params`) as explicit inputs and returns the new state. It doesn't modify anything in place.
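Because the function is pure, a single step is easy to test in isolation: the same inputs always produce the same outputs. A small sketch with arbitrary toy dimensions (4 hidden units, 3 input features):

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

toy_params = {
    'W_hh': jax.random.normal(k1, (4, 4)),
    'W_xh': jax.random.normal(k2, (4, 3)),
    'b': jnp.zeros(4),
}
h0 = jnp.zeros(4)                    # carry: hidden state with 4 units
x0 = jax.random.normal(k3, (3,))     # a single input with 3 features

new_h, y = rnn_cell_step(toy_params, h0, x0)
print(new_h.shape, y.shape)          # (4,) (4,)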
Scanning an Entire RNN Layer with `lax.scan`
Now we can combine our cell function with `lax.scan` to process a full sequence. We'll need to initialize parameters, a dummy input sequence, and an initial hidden state.
from functools import partial

key = jax.random.PRNGKey(0)

# Define dimensions
input_features = 16
hidden_features = 32
sequence_length = 10

# 1. Initialize parameters
key, w_hh_key, w_xh_key = jax.random.split(key, 3)
params = {
    'W_hh': jax.random.normal(w_hh_key, (hidden_features, hidden_features)),
    'W_xh': jax.random.normal(w_xh_key, (hidden_features, input_features)),
    'b': jnp.zeros(hidden_features)
}

# 2. Create a dummy input sequence and initial hidden state
key, input_key = jax.random.split(key)
input_sequence = jax.random.normal(input_key, (sequence_length, input_features))
initial_hidden_state = jnp.zeros(hidden_features)

# 3. Bind the parameters to the step function
# The scan function `f` must only take `(carry, x)`
scan_fn = partial(rnn_cell_step, params)

# 4. Run the scan!
final_hidden_state, all_outputs = jax.lax.scan(
    scan_fn,
    initial_hidden_state,
    input_sequence
)

print("Input sequence shape:", input_sequence.shape)
print("Final hidden state shape:", final_hidden_state.shape)
print("All outputs shape:", all_outputs.shape)
Expected Output:
Input sequence shape: (10, 16)
Final hidden state shape: (32,)
All outputs shape: (10, 32)
As you can see, `lax.scan` correctly iterated over the 10 time steps, threaded the hidden state through each step, and collected the 10 outputs into a single array.
Understanding Input Slicing in `lax.scan`
A crucial and elegant feature of `lax.scan` is how it handles the input array `xs`. It implicitly slices `xs` along its leading axis (axis 0). In our example, `input_sequence` had a shape of `(10, 16)`. For each of the 10 iterations, `lax.scan` took one slice of shape `(16,)` and passed it as the `x` argument to our `rnn_cell_step` function.
This is incredibly convenient. You don't need to manually index your input array inside the loop. Just provide the full sequence, and `scan` handles the slicing. This behavior is fundamental to processing sequences, whether they are time steps in an RNN, layers in a deep network, or diffusion steps in a generative model.
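You can check this equivalence directly with a quick sketch: a hand-written loop that slices `input_sequence` along axis 0 reproduces the scan's outputs exactly (this reuses `params`, `input_sequence`, `initial_hidden_state`, and `all_outputs` from the previous section):

h = initial_hidden_state
manual_outputs = []
for t in range(input_sequence.shape[0]):
    x_t = input_sequence[t]                # the slice that scan passes as `x`
    h, y_t = rnn_cell_step(params, h, x_t)
    manual_outputs.append(y_t)
manual_outputs = jnp.stack(manual_outputs)

print(jnp.allclose(manual_outputs, all_outputs))  # True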
Comparison Table: Looping Constructs in JAX
| Feature | Python `for` loop | `jax.lax.scan` | `jax.vmap` |
|---|---|---|---|
| Purpose | General-purpose iteration | Sequential computation with state carry-over | Parallel computation (vectorization) over a batch axis |
| State Management | Mutable, in-place updates | Functional, explicit state passing (`carry`) | Stateless; each operation is independent |
| JIT-Compatibility | Poor (unrolls the loop, slow compilation) | Excellent (compiles as a single, efficient operation) | Excellent (compiles the function once for batched execution) |
| Parallelism | Inherently sequential | Sequential by definition | Inherently parallel |
| Typical Use Case | Debugging, simple scripts outside `jit` | RNNs, LSTMs, state-space models, optimizers | Applying a function to a batch of data, data parallelism |
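The last two columns are complementary rather than competing: a common pattern is to `vmap` a scanned function, so `scan` handles time while `vmap` handles the batch. A minimal sketch, reusing `scan_fn` and `initial_hidden_state` from above and assuming a hypothetical batch of 8 sequences:

batch_size = 8
key, batch_key = jax.random.split(key)
batch_inputs = jax.random.normal(
    batch_key, (batch_size, sequence_length, input_features))

def run_rnn(h0, xs):
    # Scan over time for a single sequence
    return jax.lax.scan(scan_fn, h0, xs)

# vmap over the leading batch axis; the initial state is shared (in_axes=None)
batched_final, batched_outputs = jax.vmap(run_rnn, in_axes=(None, 0))(
    initial_hidden_state, batch_inputs)

print(batched_final.shape)     # (8, 32)
print(batched_outputs.shape)   # (8, 10, 32)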
A Production-Ready Approach: Building an RNN with Flax
While writing raw JAX is powerful, for larger models, a library like Flax provides helpful abstractions for managing parameters and defining layers. Flax has its own `nn.scan` wrapper that simplifies this process even further.
Defining a Flax RNN Module
Here's how you can encapsulate our RNN logic into Flax `Module`s: a cell that performs one `(carry, x) -> (new_carry, y)` step, and a wrapper that scans it over the sequence.
import flax.linen as nn

class SimpleRNNCell(nn.Module):
    """A single RNN step with the (carry, x) -> (new_carry, y) signature."""
    hidden_features: int

    @nn.compact
    def __call__(self, carry, x):
        h_prev = carry
        # nn.Dense is a more robust way to define the linear maps;
        # the bias `b` lives inside the 'W_xh' layer.
        h_new = nn.tanh(
            nn.Dense(features=self.hidden_features, use_bias=False, name='W_hh')(h_prev) +
            nn.Dense(features=self.hidden_features, name='W_xh')(x)
        )
        # For a simple RNN, the output is the new hidden state
        return h_new, h_new

class SimpleRNN(nn.Module):
    hidden_features: int

    @nn.compact
    def __call__(self, xs):
        cell = SimpleRNNCell(hidden_features=self.hidden_features)

        # nn.scan expects a function whose first argument is a Module;
        # Flax then lifts ('broadcasts') the cell's parameters across time steps.
        def body_fn(cell, carry, x):
            return cell(carry, x)

        scan = nn.scan(
            body_fn,
            variable_broadcast='params',
            split_rngs={'params': False}
        )

        # 1. Initialize the carry (here, a learnable initial hidden state)
        initial_carry = self.param(
            'initial_carry', nn.initializers.zeros, (self.hidden_features,))

        # 2. Run the scan
        final_carry, all_outputs = scan(cell, initial_carry, xs)
        return all_outputs
Initializing and Running the Flax Layer
Using this Flax module is clean and separates parameter management from the forward pass logic.
# Define model and dummy data
model = SimpleRNN(hidden_features=32)
dummy_inputs = jnp.ones((10, 16)) # (sequence_length, features)
key = jax.random.PRNGKey(42)
# Initialize parameters
params = model.init(key, dummy_inputs)['params']
# Run the forward pass (the 'apply' method)
outputs = model.apply({'params': params}, dummy_inputs)
print("Flax model output shape:", outputs.shape)
Expected Output:
Flax model output shape: (10, 32)
The result is the same, but the code is now more modular, reusable, and aligned with best practices for building complex models in JAX.
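From here, the usual JAX transformations compose cleanly with the Flax module. As a hedged sketch (the mean-squared objective below is just a placeholder loss), you can jit-compile the forward pass and take gradients with respect to the parameters:

@jax.jit
def loss_fn(params, xs):
    outputs = model.apply({'params': params}, xs)
    return jnp.mean(outputs ** 2)

grads = jax.grad(loss_fn)(params, dummy_inputs)
print(jax.tree_util.tree_map(jnp.shape, grads))  # same pytree structure as params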
Conclusion: Your Next Steps with JAX RNNs
You've now seen how to tame stateful computation in the functional world of JAX. `lax.scan` is not just a tool; it's a fundamental concept for implementing any sequential process efficiently. By understanding its `(carry, x) -> (new_carry, y)` pattern and how it handles input slicing, you've unlocked the ability to build high-performance RNNs, LSTMs, GRUs, and more.
The journey doesn't end here. Try implementing a more complex cell like an LSTM, or explore how `scan` can be nested or combined with `vmap` to process batches of sequences. With the foundation you've built today, you are well-equipped to tackle advanced sequence modeling challenges in JAX.