
Master JAX lax.scan: 5 Tips for Multi-Layer RNNs 2025

Unlock high-performance multi-layer RNNs in JAX. Master lax.scan with 5 expert tips for 2025, covering state management, vmap, PyTrees, and debugging.


Dr. Alexey Volkov

Principal ML Scientist specializing in high-performance computing and JAX/Flax model optimization.


Introduction: Why JAX for RNNs?

Welcome to 2025, where high-performance machine learning is no longer a luxury but a necessity. JAX, with its powerful trifecta of `jit` (just-in-time compilation), `grad` (automatic differentiation), and `vmap` (vectorization), has firmly established itself as a go-to library for performance-critical ML research and production. While JAX excels at parallelizable architectures like Transformers, implementing stateful models like Recurrent Neural Networks (RNNs) can seem daunting. Standard Python `for` loops inside JIT-compiled functions get unrolled at trace time, bloating compilation and working against the very paradigm that makes JAX fast. So how do we efficiently manage sequences and state?

The answer lies in `jax.lax.scan`, a functional programming primitive that is to JAX what `for` loops are to Python. Mastering `lax.scan` is the key to unlocking the full potential of JAX for any sequential processing task, especially for complex architectures like multi-layer RNNs.

The Power of `lax.scan` for Sequential Models

At its core, `lax.scan` is a higher-order function designed to loop over a sequence of data while carrying a state from one step to the next. Its body has the signature `(carry, x) -> (carry, y)`, where `carry` is the state (like an RNN's hidden state), `x` is an element from the input sequence (like a token embedding), and `y` is the output for that step. JAX lowers the entire loop into a single compiled XLA loop rather than unrolling it, making it dramatically faster to compile and run than a native Python loop.
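
To make the signature concrete, here is a minimal sketch: a scan that computes a running sum, where the carry plays the role the hidden state will play in an RNN.

import jax
import jax.numpy as jnp

def cumsum_step(carry, x):
  """One scan step: carry is the running total, x is the current element."""
  new_carry = carry + x
  return new_carry, new_carry  # (new carry, per-step output y)

xs = jnp.arange(1.0, 6.0)                        # [1. 2. 3. 4. 5.]
final_carry, ys = jax.lax.scan(cumsum_step, 0.0, xs)
# final_carry == 15.0, ys == [ 1.  3.  6. 10. 15.]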

In this guide, we'll go beyond the basics and dive into five expert-level tips that will help you build and manage sophisticated, multi-layer RNNs using `lax.scan` like a pro.

Tip 1: Structure Your Carry State with PyTrees

The `scan` Carry Pattern

The `carry` is the heart of your RNN's state. For a simple, single-layer RNN, the carry is straightforward: it's just the hidden state vector `h_t`. However, for a multi-layer RNN, you need to manage the hidden states for all layers. A common mistake is to use a Python list of arrays.

PyTrees for Multi-Layer State

JAX's transformations (`jit`, `vmap`, `grad`) are designed to work seamlessly with PyTrees—nested structures of tuples, lists, and dicts. For a multi-layer RNN, the most idiomatic and efficient way to structure your carry state is as a tuple of hidden state arrays. Each element in the tuple corresponds to a layer's hidden state.

Why is a tuple better than a list? Both are PyTrees, and PyTrees let JAX see the structure of your data and apply operations to each 'leaf' (the arrays) transparently, which is crucial for vectorization and differentiation. But tuples are immutable, which matches JAX's functional style, and `lax.scan` requires the carry you return to have exactly the same PyTree structure as the one you pass in, so mixing lists and tuples between the two is an easy way to trigger structure-mismatch errors.

Consider a 3-layer RNN. Your carry state should be a tuple of three arrays, not a list:

# Good: A PyTree (tuple) of states
carry_state = (h_layer1, h_layer2, h_layer3)

# Riskier: a Python list is also a PyTree, but mixing list and tuple
# structures between the initial and returned carry breaks `scan`
carry_state_list = [h_layer1, h_layer2, h_layer3]

By using a tuple, you make your state compatible with powerful functions like `vmap`, which we'll explore next.
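
As a quick illustration (the shapes here are arbitrary), `jax.tree_util.tree_map` applies a function to every leaf of the carry in one call, which mirrors how `jit`, `grad`, and `vmap` handle structured data:

# A 3-layer carry with a batch of 32 and hidden size 256 (illustrative numbers)
carry_state = (
  jnp.zeros((32, 256)),  # h_layer1
  jnp.zeros((32, 256)),  # h_layer2
  jnp.zeros((32, 256)),  # h_layer3
)

# One call touches every leaf: inspect shapes...
print(jax.tree_util.tree_map(lambda h: h.shape, carry_state))
# ...or reset the whole state in one go.
fresh_carry = jax.tree_util.tree_map(jnp.zeros_like, carry_state)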

Tip 2: Vectorize Your Layers with `vmap`

How do you process an input through multiple RNN layers within a single `lax.scan` step? The naive approach is another Python `for` loop inside your scan function. It does compile, because the loop is unrolled at trace time, but it duplicates the layer computation in the compiled graph and scales poorly as the layer count grows.

The Idiomatic 'Scan over Vmap' Approach

The high-performance solution is to combine `lax.scan` with `jax.vmap`. The pattern is: `scan` over the time steps, and `vmap` over the network layers, with one important caveat about stacked layers that we refine in Tip 3.

First, define a function that performs the computation for a single RNN layer. This function will take the input from the previous layer and the layer's own hidden state to produce the output and the new hidden state.

import jax
import jax.numpy as jnp

def single_layer_rnn_step(params, layer_input, h_prev):
  """Computes one step for a single RNN layer."""
  W_h, W_x, b = params
  h_new = jnp.tanh(jnp.dot(h_prev, W_h) + jnp.dot(layer_input, W_x) + b)
  output = h_new # In a simple RNN, output is the hidden state
  return h_new, output
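
As a quick sanity check (the dimensions below are illustrative), a single step maps a `(batch, input_dim)` input and a `(batch, hidden_dim)` state to a new `(batch, hidden_dim)` state:

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
batch, input_dim, hidden_dim = 32, 128, 256
params = (
  0.01 * jax.random.normal(k1, (hidden_dim, hidden_dim)),  # W_h
  0.01 * jax.random.normal(k2, (input_dim, hidden_dim)),   # W_x
  jnp.zeros((hidden_dim,)),                                 # b
)
h_prev = jnp.zeros((batch, hidden_dim))
x_t = jnp.ones((batch, input_dim))
h_new, output = single_layer_rnn_step(params, x_t, h_prev)
# h_new.shape == output.shape == (32, 256)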

Next, within your main `scan` body, `jax.vmap` lets you apply this `single_layer_rnn_step` function across a whole stack of layers in one call. Keep the data flow in mind, though: the input to the first layer is the token embedding, while each subsequent layer consumes the output of the layer below it, and that dependency is exactly what we untangle in Tip 3.

This vectorization is what gives JAX its edge: wherever computations are independent, for example across a batch of sequences or a set of independent RNNs, `vmap` lets hardware like GPUs and TPUs process them in parallel instead of iterating sequentially.

Comparison of RNN Implementation Approaches

RNN Layer Processing: `for` loop vs. `scan` vs. `scan` + `vmap`

| Approach | Performance | JIT-Compatibility | Readability | Idiomatic JAX |
| --- | --- | --- | --- | --- |
| Python `for` loop over time | Poor | Only via static unrolling | High | No |
| `lax.scan` over time with an inner Python `for` loop over layers | Sub-optimal | Yes (inner loop unrolled at trace time) | Medium | No |
| `lax.scan` over time with an inner `scan` over layers (`vmap` where independent) | Excellent | Yes | Medium-High | Yes |

Tip 3: Align PyTrees for Parameters and States

Aligning Parameters and States

To make the `vmap` strategy work, your data must be structured correctly. Just as you structured your multi-layer hidden states as a PyTree (a tuple of arrays), you must do the same for your layer parameters (weights and biases).

If you have `N` layers, your parameters should be a PyTree where each leaf contains a stack of `N` corresponding parameter arrays. For example, if each layer has a `W_h` matrix, your `params` PyTree should contain a single array `all_W_h` of shape `(N, hidden_dim, hidden_dim)`.

# Example parameter PyTree for a 3-layer RNN
# Each leaf is a stack of parameters for all layers.
params = {
  'W_h': jnp.stack([W_h1, W_h2, W_h3]), # Shape: (3, hidden_dim, hidden_dim)
  'W_x': jnp.stack([W_x1, W_x2, W_x3]), # Shape: (3, input_dim, hidden_dim)
  'b':   jnp.stack([b1, b2, b3]),       # Shape: (3, hidden_dim)
}
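
For completeness, here is one way to build such a PyTree from scratch. `init_stacked_params` is a hypothetical helper, not part of JAX, and note that stacking `W_x` across layers implicitly assumes every layer sees inputs of the same width (`input_dim == hidden_dim`, or an embedding projected to `hidden_dim` first):

def init_stacked_params(key, num_layers, input_dim, hidden_dim, scale=0.01):
  """Illustrative initializer: one stacked array per parameter kind."""
  k_h, k_x = jax.random.split(key)
  return {
    'W_h': scale * jax.random.normal(k_h, (num_layers, hidden_dim, hidden_dim)),
    'W_x': scale * jax.random.normal(k_x, (num_layers, input_dim, hidden_dim)),
    'b':   jnp.zeros((num_layers, hidden_dim)),
  }

params = init_stacked_params(jax.random.PRNGKey(42), num_layers=3,
                             input_dim=256, hidden_dim=256)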

Configuring `vmap`'s `in_axes`

The second ingredient is `vmap`'s `in_axes` argument, which tells `vmap` which axis of each input to map over. For our stacked parameters and hidden states, that is the leading layer axis. There is a catch, though: broadcasting the same input to every layer only makes sense when the layers are independent, and in a stacked RNN the output of each layer feeds the next.

With that dependency in mind, here is what the scan body looks like in practice; the reasoning behind its structure is spelled out right after the code:

def scan_fn(carry, token_embedding):
  # carry: tuple of per-layer hidden states, each of shape (batch, hidden_dim)
  # token_embedding: the input for this time step, shape (batch, input_dim)
  # `params` is the stacked-parameter dict from above, closed over for brevity.

  # Stack the tuple so the inner scan can walk the leading layer axis.
  h_stacked = jnp.stack(carry)  # (num_layers, batch, hidden_dim)

  def layer_step(layer_input, params_and_state):
    # One layer: consume the output of the layer below, feed the layer above.
    p, h_prev = params_and_state
    h_new, output = single_layer_rnn_step((p['W_h'], p['W_x'], p['b']),
                                          layer_input, h_prev)
    return output, h_new  # carry = input for the next layer, y = new hidden state

  # Inner scan over layers: every leaf of (params, h_stacked) has a leading
  # layer axis, so lax.scan slices them layer by layer.
  # Note: because the inner carry must keep a fixed shape, this assumes
  # input_dim == hidden_dim (or that the embedding is projected to hidden_dim).
  final_output, new_h_stacked = jax.lax.scan(
      layer_step, token_embedding, (params, h_stacked))

  # Restore the tuple structure so the outer carry matches what we received.
  return tuple(new_h_stacked), final_output

Correction & Refinement: The 'scan over vmap' pattern is powerful, but it applies when the mapped operations are independent. For stacked RNN layers, the output of layer `i` is the input to layer `i+1`, which is inherently sequential, so the truly idiomatic JAX pattern is a nested `lax.scan`: an outer scan over time and an inner scan over layers, exactly as in the code above. The `vmap` tip shines when, for example, you run multiple independent RNNs side by side, as the sketch below shows.
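
To ground that last point, here is a minimal sketch (all names and shapes are illustrative) of the case where `vmap` over the leading axis does apply: several independent single-layer RNNs stepped in parallel, with the shared input broadcast via `in_axes=None`.

num_rnns, batch, input_dim, hidden_dim = 4, 32, 256, 256
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
stacked_params = (
  0.01 * jax.random.normal(k1, (num_rnns, hidden_dim, hidden_dim)),  # W_h
  0.01 * jax.random.normal(k2, (num_rnns, input_dim, hidden_dim)),   # W_x
  jnp.zeros((num_rnns, hidden_dim)),                                 # b
)
h_prev_stack = jnp.zeros((num_rnns, batch, hidden_dim))
token_embedding = jnp.ones((batch, input_dim))

# Map over axis 0 of the params and states; broadcast the shared input.
batched_step = jax.vmap(single_layer_rnn_step, in_axes=(0, None, 0))
new_h_stack, outputs = batched_step(stacked_params, token_embedding, h_prev_stack)
# new_h_stack.shape == (4, 32, 256)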

Tip 4: Get Your Initial Carry Right

One of the most common sources of errors when using `lax.scan` is a mismatch between the structure, shapes, or dtypes of the initial carry and the carry returned by the scan body function. JAX checks these statically, and if they don't align perfectly, `scan` fails with a cryptic error about mismatched carry structures or shapes.

Shape is Everything

Your initial carry must be a PyTree with the exact same structure and `dtype` as the one your function returns. For our multi-layer RNN, this means you need to initialize a tuple of zero-filled arrays, one for each layer's hidden state. Remember to include the batch dimension!

def initialize_carry(batch_size, num_layers, hidden_dim):
  """Initializes the carry state for a multi-layer RNN."""
  # Create a list of zero arrays, one for each layer
  initial_states = [
    jnp.zeros((batch_size, hidden_dim)) for _ in range(num_layers)
  ]
  # Return as a tuple to form a proper PyTree
  return tuple(initial_states)

# Usage
batch_size = 32
num_layers = 4
hidden_dim = 256
initial_carry = initialize_carry(batch_size, num_layers, hidden_dim)

# This initial_carry can now be passed to lax.scan

Always double-check that the PyTree structure (`tuple` of arrays) and the shapes `(batch_size, hidden_dim)` are consistent between your initialization and your scan function's return value.
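
A cheap way to catch mismatches early is to run the scan body once by hand and compare PyTree structures and shapes. This sketch assumes your `scan_fn` and a sample time-step input `x0` are in scope:

new_carry, _ = scan_fn(initial_carry, x0)  # x0: one time-step input (placeholder)

same_structure = (jax.tree_util.tree_structure(new_carry)
                  == jax.tree_util.tree_structure(initial_carry))
same_shapes = all(
  a.shape == b.shape and a.dtype == b.dtype
  for a, b in zip(jax.tree_util.tree_leaves(new_carry),
                  jax.tree_util.tree_leaves(initial_carry)))
assert same_structure and same_shapes, "carry mismatch: scan will reject this"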

Tip 5: Debug `lax.scan` with `jax.debug.print`

The Black Box Problem

Because `lax.scan` is JIT-compiled, you can't just drop a standard Python `print()` statement inside it to see what's going on. The code is transformed into an XLA graph before it's executed, and the `print()` statement is often executed only once during tracing, not during the actual loop execution.
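
A tiny sketch of the effect: a plain `print` inside a jitted function fires during tracing, not on later calls.

@jax.jit
def double(x):
  print("tracing with:", x)  # shows an abstract Tracer, only while tracing
  return x * 2

double(jnp.arange(3.0))  # the print fires once, during tracing
double(jnp.arange(3.0))  # nothing is printed: the cached compiled version runs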

Shedding Light with `jax.debug.print`

The solution is `jax.debug.print`. This is a JAX-native printing function that is designed to work inside JIT-compiled code. It guarantees that the values will be printed at execution time, not trace time.

You can use it to inspect the shapes and values of your carry or outputs at each step of the scan, which is invaluable for debugging shape mismatches or exploding/vanishing gradients.

def scan_fn_with_debug(carry, x):
  # ... your RNN logic ...
  new_carry, y = your_rnn_step(carry, x)

  # Inspect a runtime value at every step, e.g. to spot exploding or vanishing states
  jax.debug.print("first-layer |h| mean: {m}", m=jnp.abs(new_carry[0]).mean())

  return new_carry, y

# When you run the scan, you'll see one printout per time step.
final_carry, outputs = jax.lax.scan(scan_fn_with_debug, initial_carry, input_sequence)

Using `jax.debug.print` can save you hours of frustration by making the behavior of your compiled loops transparent.

Conclusion: Scanning Your Way to SOTA Performance

Building efficient, multi-layer RNNs in JAX is a matter of embracing its functional programming paradigm. By moving away from Pythonic loops and adopting `jax.lax.scan`, you unlock performance that is simply unattainable otherwise. The key is to think in terms of PyTrees and transformations. By structuring your states and parameters as PyTrees, using nested scans for sequential dependencies, and leveraging `jax.debug.print` for sanity checks, you can build complex and powerful sequence models that are both elegant and blazingly fast. These five tips provide a robust foundation for anyone looking to master sequential modeling in the JAX ecosystem of 2025 and beyond.