Machine Learning

How to Scan RNN Layers & Slices in JAX: 2025 Tutorial

Master `lax.scan` in JAX for building efficient Recurrent Neural Networks (RNNs). This 2025 tutorial covers scanning layers, handling slices, and a Flax example.


Dr. Alistair Finch

A computational scientist specializing in high-performance machine learning frameworks like JAX and PyTorch.


Introduction: Why JAX and RNNs Need a Special Approach

Welcome to 2025, where JAX has solidified its place as a cornerstone of high-performance machine learning research and production. Its power lies in its functional programming paradigm, enabling features like Just-In-Time (JIT) compilation, automatic differentiation, and vectorization (`jit`, `grad`, `vmap`). However, this functional, stateless approach presents a unique challenge for inherently stateful architectures like Recurrent Neural Networks (RNNs).

Unlike a simple feed-forward network, an RNN maintains a hidden state that evolves over a sequence. In traditional frameworks, this is often handled with a `for` loop and in-place state updates. In JAX, in-place updates clash with functional purity, and a plain Python loop is simply unrolled at trace time rather than compiled as a loop. So, how do we process sequences efficiently? The answer is `jax.lax.scan`.

This tutorial will guide you through mastering `lax.scan` to build efficient, compilable, and powerful RNNs in JAX. We'll cover everything from the basic cell to scanning entire layers and handling input slices, culminating in a practical example using the popular Flax library.

What is `lax.scan` and Why Do We Need It?

At its core, `lax.scan` is JAX's canonical replacement for loops that carry state from one iteration to the next. It's a functional primitive that expresses a sequential computation in a way that JAX's compiler can understand, optimize, and accelerate.

The Problem with Standard Loops in JAX

Consider a standard Python `for` loop. When JAX's JIT compiler encounters it, it doesn't see a high-level "loop" operation. Instead, it unrolls the loop, creating a separate computational graph for each iteration. For short loops, this is fine. For an RNN processing a sequence of 1000 tokens, this is a disaster, leading to massive compilation times and memory usage.

`lax.scan` solves this by defining the entire looping computation as a single, optimizable operation.
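You can see the difference directly with `jax.make_jaxpr`, which prints the program JAX traces. In the sketch below (the helper names `unrolled_sum` and `scanned_sum` are just illustrative), the Python-loop version produces a jaxpr whose body repeats once per element, while the `lax.scan` version contains a single `scan` primitive no matter how long the sequence is.

import jax
import jax.numpy as jnp

def unrolled_sum(xs):
  # Plain Python loop: tracing repeats the body once per element,
  # so the jaxpr grows with the sequence length.
  total = jnp.array(0.0)
  for i in range(xs.shape[0]):
    total = total + xs[i]
  return total

def scanned_sum(xs):
  # The same computation with lax.scan: the jaxpr contains a single
  # scan primitive, regardless of sequence length.
  def step(carry, x):
    return carry + x, None
  total, _ = jax.lax.scan(step, jnp.array(0.0), xs)
  return total

xs = jnp.arange(5.0)
print(jax.make_jaxpr(unrolled_sum)(xs))  # body repeated 5 times
print(jax.make_jaxpr(scanned_sum)(xs))   # one scan op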

The `lax.scan` Signature Explained

The simplified signature looks like this: `jax.lax.scan(f, init, xs)` (the full signature also accepts optional arguments such as `length`, `reverse`, and `unroll`). Let's break it down in the context of an RNN:

  • `f` (the scan function): This is the workhorse. It's a function that defines the logic for a single step of the loop. For an RNN, this is the cell update function. It must have the signature `(carry, x) -> (new_carry, y)`.
  • `init` (the initial carry): This is the initial state that gets passed to the first iteration of `f`. For an RNN, this is the initial hidden state, `h_0`.
  • `xs` (the inputs): This is a JAX array containing the sequence of inputs to be iterated over. For an RNN, this is your input sequence, typically with shape `(time_steps, features)`. `lax.scan` automatically slices this array along the first axis for each step.

The function returns a tuple `(final_carry, ys)`, where `final_carry` is the final hidden state after the last time step and `ys` is a stacked array of all the per-step outputs `y`.
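Before wiring this into an RNN, a tiny self-contained example helps make the contract concrete. This is a minimal sketch (the name `ema_step` and the decay factor 0.9 are arbitrary) in which the carry is a running exponential moving average and each per-step output `y` is stacked into `ys`.

import jax
import jax.numpy as jnp

def ema_step(carry, x):
  """One scan step: carry is the running average, y is the emitted output."""
  new_carry = 0.9 * carry + 0.1 * x   # exponential moving average
  y = new_carry                       # collected into the stacked `ys`
  return new_carry, y

xs = jnp.arange(1.0, 6.0)             # [1. 2. 3. 4. 5.]
final_carry, ys = jax.lax.scan(ema_step, jnp.array(0.0), xs)

print(final_carry)   # the carry after the last step (a scalar)
print(ys.shape)      # (5,) -- one output per element of xs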

Anatomy of a JAX RNN Cell

To use `lax.scan`, we first need to define our single-step function `f`.

The RNN Update Rule: A Quick Refresher

A simple RNN cell updates its hidden state `h` at each time step `t` based on the previous hidden state `h_{t-1}` and the current input `x_t`. The formula is:

h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b)

Where `W_hh` and `W_xh` are weight matrices and `b` is a bias vector. The output at this step is often just the new hidden state, so `y_t = h_t`.

Implementing the Cell Function in JAX

Let's translate this into a JAX function that matches the required `(carry, x) -> (new_carry, y)` signature. Here, `carry` is the hidden state `h_{t-1}` and `x` is the input `x_t`.

import jax
import jax.numpy as jnp

def rnn_cell_step(params, carry, x):
  """A single step of the RNN cell."""
  h_prev = carry
  h_new = jnp.tanh(
      jnp.dot(params['W_hh'], h_prev) + 
      jnp.dot(params['W_xh'], x) + 
      params['b']
  )
  # For a simple RNN, the output is the new hidden state
  return h_new, h_new

Notice how this function is pure: it takes state (`carry`) and parameters (`params`) as explicit inputs and returns the new state. It doesn't modify anything in place.
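As a quick sanity check (and to see that purity in action), you can call the cell once by hand before scanning it. The toy sizes below (4 hidden units, 3 input features) and the name `toy_params` are arbitrary.

# Sanity-check a single step with toy sizes.
toy_key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(toy_key, 3)

toy_params = {
    'W_hh': jax.random.normal(k1, (4, 4)),   # hidden -> hidden
    'W_xh': jax.random.normal(k2, (4, 3)),   # input -> hidden
    'b': jnp.zeros(4),
}
h0 = jnp.zeros(4)
x0 = jax.random.normal(k3, (3,))

h1, y1 = rnn_cell_step(toy_params, h0, x0)
print(h1.shape)   # (4,)

# Purity in action: the same inputs always produce exactly the same output.
h1_again, _ = rnn_cell_step(toy_params, h0, x0)
print(jnp.array_equal(h1, h1_again))   # True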

Scanning an Entire RNN Layer with `lax.scan`

Now we can combine our cell function with `lax.scan` to process a full sequence. We'll need to initialize parameters, a dummy input sequence, and an initial hidden state.

from functools import partial

key = jax.random.PRNGKey(0)

# Define dimensions
input_features = 16
hidden_features = 32
sequence_length = 10

# 1. Initialize parameters
key, w_hh_key, w_xh_key = jax.random.split(key, 3)
params = {
    'W_hh': jax.random.normal(w_hh_key, (hidden_features, hidden_features)),
    'W_xh': jax.random.normal(w_xh_key, (hidden_features, input_features)),
    'b': jnp.zeros(hidden_features)
}

# 2. Create a dummy input sequence and initial hidden state
key, input_key = jax.random.split(key)
input_sequence = jax.random.normal(input_key, (sequence_length, input_features))
initial_hidden_state = jnp.zeros(hidden_features)

# 3. Bind the parameters to the step function
# The scan function `f` must only take `(carry, x)`
scan_fn = partial(rnn_cell_step, params)

# 4. Run the scan!
final_hidden_state, all_outputs = jax.lax.scan(
    scan_fn,
    initial_hidden_state,
    input_sequence
)

print("Input sequence shape:", input_sequence.shape)
print("Final hidden state shape:", final_hidden_state.shape)
print("All outputs shape:", all_outputs.shape)

Expected Output:

Input sequence shape: (10, 16)
Final hidden state shape: (32,)
All outputs shape: (10, 32)

As you can see, `lax.scan` correctly iterated over the 10 time steps, threaded the hidden state through each step, and collected the 10 outputs into a single array.

Understanding Input Slicing in `lax.scan`

A crucial and elegant feature of `lax.scan` is how it handles the input array `xs`. It implicitly slices `xs` along its leading axis (axis 0). In our example, `input_sequence` had a shape of (10, 16). For each of the 10 iterations, `lax.scan` took one slice of shape (16,) and passed it as the `x` argument to our `rnn_cell_step` function.

This is incredibly convenient. You don't need to manually index your input array inside the loop. Just provide the full sequence, and `scan` handles the slicing. This behavior is fundamental to processing sequences, whether they are time steps in an RNN, layers in a deep network, or diffusion steps in a generative model.
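If you want to convince yourself of this, you can reproduce the scan with an explicit Python loop over the leading axis and compare the results. The sketch below reuses `params`, `input_sequence`, `initial_hidden_state`, and the scan outputs from the example above.

# Manual loop mirroring what lax.scan does: slice xs along axis 0,
# thread the carry through, and stack the per-step outputs.
h = initial_hidden_state
manual_outputs = []
for t in range(input_sequence.shape[0]):
  x_t = input_sequence[t]              # the slice that scan passes as `x`
  h, y = rnn_cell_step(params, h, x_t)
  manual_outputs.append(y)
manual_outputs = jnp.stack(manual_outputs)

print(jnp.allclose(manual_outputs, all_outputs))   # True
print(jnp.allclose(h, final_hidden_state))         # True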

Comparison Table: Looping Constructs in JAX

Python `for` loop vs. `lax.scan` vs. `vmap`
| Feature | Python `for` loop | `jax.lax.scan` | `jax.vmap` |
| --- | --- | --- | --- |
| Purpose | General-purpose iteration | Sequential computation with state carry-over | Parallel computation (vectorization) over a batch axis |
| State management | Mutable, in-place updates | Functional, explicit state passing (`carry`) | Stateless; each operation is independent |
| JIT compatibility | Poor (unrolls the loop, slow compilation) | Excellent (compiles as a single, efficient operation) | Excellent (compiles the function once for batched execution) |
| Parallelism | Inherently sequential | Sequential by definition | Inherently parallel |
| Typical use case | Debugging, simple scripts outside `jit` | RNNs, LSTMs, state-space models, optimizers | Applying a function to a batch of data, data parallelism |
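One practical consequence of this comparison: `scan` and `vmap` compose cleanly. `scan` walks the time axis, and `vmap` can then map the entire scanned RNN over a batch axis. Below is a small sketch that reuses `scan_fn` and the dimensions from the raw JAX example; the batch size of 8 is arbitrary.

# Batch the scanned RNN with vmap: scan handles time, vmap handles the batch.
def run_rnn(initial_h, sequence):
  final_h, outputs = jax.lax.scan(scan_fn, initial_h, sequence)
  return outputs

batch_size = 8
key, batch_key = jax.random.split(key)
batched_inputs = jax.random.normal(
    batch_key, (batch_size, sequence_length, input_features))
batched_init = jnp.zeros((batch_size, hidden_features))

# vmap maps run_rnn over the leading (batch) axis of both arguments.
batched_outputs = jax.vmap(run_rnn)(batched_init, batched_inputs)
print("Batched outputs shape:", batched_outputs.shape)  # (8, 10, 32)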

A Production-Ready Approach: Building an RNN with Flax

While writing raw JAX is powerful, for larger models, a library like Flax provides helpful abstractions for managing parameters and defining layers. Flax has its own `nn.scan` wrapper that simplifies this process even further.

Defining a Flax RNN Module

Here's how you can encapsulate our RNN logic into Flax `Module`s: a `SimpleRNNCell` that implements a single step, and a `SimpleRNN` wrapper that uses `nn.scan` to run that cell over the sequence.

import flax.linen as nn

class SimpleRNNCell(nn.Module):
  hidden_features: int

  @nn.compact
  def __call__(self, carry, x):
    """A single step, matching the (carry, x) -> (new_carry, y) contract."""
    h_prev = carry
    # nn.Dense replaces the raw weight matrices; Flax manages the parameters
    h_new = nn.tanh(
        nn.Dense(features=self.hidden_features, use_bias=False, name='W_hh')(h_prev) +
        nn.Dense(features=self.hidden_features, name='W_xh')(x)
    )
    return h_new, h_new


class SimpleRNN(nn.Module):
  hidden_features: int

  @nn.compact
  def __call__(self, xs):
    # nn.scan lifts the cell over the time axis. `variable_broadcast='params'`
    # shares the same weights across every step, and `split_rngs={'params': False}`
    # reuses a single params RNG at initialization time.
    ScanCell = nn.scan(
        SimpleRNNCell,
        variable_broadcast='params',
        split_rngs={'params': False},
        in_axes=0,
        out_axes=0,
    )

    # 1. Initial hidden state (zeros, as in the raw JAX example)
    initial_carry = jnp.zeros((self.hidden_features,))

    # 2. Run the scan
    final_carry, all_outputs = ScanCell(hidden_features=self.hidden_features)(initial_carry, xs)
    return all_outputs

Initializing and Running the Flax Layer

Using this Flax module is clean and separates parameter management from the forward pass logic.

# Define model and dummy data
model = SimpleRNN(hidden_features=32)
dummy_inputs = jnp.ones((10, 16)) # (sequence_length, features)
key = jax.random.PRNGKey(42)

# Initialize parameters
params = model.init(key, dummy_inputs)['params']

# Run the forward pass (the 'apply' method)
outputs = model.apply({'params': params}, dummy_inputs)

print("Flax model output shape:", outputs.shape)

Expected Output:

Flax model output shape: (10, 32)

The result is the same, but the code is now more modular, reusable, and aligned with best practices for building complex models in JAX.
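In practice, you would also wrap the forward pass in `jax.jit` so that the scan inside the module compiles into a single optimized program. Here is a minimal sketch using the `model`, `params`, and `dummy_inputs` from above; the function name `forward` is just illustrative.

# Compile the whole forward pass; the scanned RNN becomes part of one XLA program.
@jax.jit
def forward(params, xs):
  return model.apply({'params': params}, xs)

jit_outputs = forward(params, dummy_inputs)
print("JIT output shape:", jit_outputs.shape)  # (10, 32)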

Conclusion: Your Next Steps with JAX RNNs

You've now seen how to tame stateful computation in the functional world of JAX. `lax.scan` is not just a tool; it's a fundamental concept for implementing any sequential process efficiently. By understanding its `(carry, x) -> (new_carry, y)` pattern and how it handles input slicing, you've unlocked the ability to build high-performance RNNs, LSTMs, GRUs, and more.

The journey doesn't end here. Try implementing a more complex cell like an LSTM, or explore how `scan` can be nested or combined with `vmap` to process batches of sequences. With the foundation you've built today, you are well-equipped to tackle advanced sequence modeling challenges in JAX.