
JAX lax.scan RNN Fix: No Dynamic Indexing Guide 2025

Struggling with dynamic indexing errors in JAX RNNs? Our 2025 guide provides a definitive fix, explaining how to properly use lax.scan for high-performance models.

Dr. Adrian Petrov

Senior ML Research Scientist specializing in high-performance computing and JAX model optimization.

Introduction: The JAX & RNN Dilemma

JAX has taken the machine learning world by storm, offering excellent performance through Just-In-Time (JIT) compilation and automatic differentiation. For many tasks, it feels like writing NumPy code that runs at GPU/TPU speeds. However, when you venture into Recurrent Neural Networks (RNNs), you often hit a notorious roadblock: tracer errors such as `TracerArrayConversionError`, `TracerIntegerConversionError`, or `ConcretizationTypeError`, or a Python loop that compiles but takes longer and longer to compile as your sequences grow. These problems stem from a fundamental JAX design principle: inside a JIT-compiled function, array shapes and Python control flow must be fixed at trace time, even though array values are not known until the compiled code runs.

If you've tried to implement an RNN in JAX using a standard Python for loop over your sequence, you've likely encountered this issue. The solution lies in a powerful but initially unintuitive tool: lax.scan. This guide for 2025 will demystify lax.scan, show you exactly how to structure your RNN code to make it JAX-friendly, and fix the "no dynamic indexing" problem for good.

The Root of the Problem: JIT Compilation vs. Dynamic Shapes

To understand why your typical RNN implementation breaks, you need to grasp how @jit works. When JAX's JIT compiler sees your function, it traces its execution with abstract values (tracers) that represent the shapes and dtypes of your arrays, not their actual values. It uses this trace to build a highly optimized, static computation graph for the specific shapes it saw.
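A small sketch makes this tracing behaviour visible. The function scale and the counter trace_count are illustrative names, not part of any API; the counter is a Python side effect, so it only increments while JAX is tracing, not when the compiled program runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scale(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    print("inside jit, x is:", x)  # a tracer carrying shape/dtype, not the numbers
    return 2.0 * x

scale(jnp.ones(3))   # first call: trace + compile for float32[3]
scale(jnp.zeros(3))  # same shape and dtype: cached program is reused, no retrace
scale(jnp.ones(4))   # new shape float32[4]: traced and compiled again
print("traces:", trace_count)  # traces: 2
```

A new input shape forces a fresh trace. That also holds for lax.scan: a new sequence length still means a new compilation, but each compilation stays small.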

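The key word here is static: every array shape in the graph must be known when the function is traced. Indexing with a traced integer is not, by itself, the problem. data[i] with a traced scalar i is allowed, because the result always has the same shape (one row of data), and JAX lowers it to a dynamic slice. What JAX rejects is anything whose shape or Python control flow depends on a traced value: a slice such as data[:i], range(n) with a traced n, an if statement on a traced boolean, or jnp.zeros(n) with a traced size. In those cases the compiler cannot build a fixed graph, so JAX raises an error instead of guessing.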
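The following sketch shows where that line falls. The helper names pick_row, window, and prefix_sum are hypothetical; jax.lax.dynamic_slice is the standard JAX primitive for taking a fixed-size slice at a traced offset.

```python
import jax
import jax.numpy as jnp

data = jnp.arange(12.0).reshape(4, 3)

@jax.jit
def pick_row(data, i):
    # i is traced, but data[i] always has shape (3,), so this is allowed;
    # JAX lowers it to a dynamic slice.
    return data[i]

@jax.jit
def window(data, i):
    # Fixed size (2 rows), traced start position: use lax.dynamic_slice.
    return jax.lax.dynamic_slice(data, (i, 0), (2, 3))

@jax.jit
def prefix_sum(data, i):
    # data[:i] would have a shape that depends on the traced value of i.
    return data[:i].sum()

print(pick_row(data, 2))  # [6. 7. 8.]
print(window(data, 1))    # rows 1 and 2: [[3. 4. 5.] [6. 7. 8.]]

try:
    prefix_sum(data, 2)
    prefix_failed = False
except Exception as err:  # JAX rejects the data-dependent slice at trace time
    prefix_failed = True
    print("prefix_sum failed with", type(err).__name__)
```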

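lax.scan is JAX's prescribed way to express loops with carried state inside JIT-compiled code. A Python for loop is unrolled during tracing, with one copy of the loop body per iteration. lax.scan instead traces the body once and emits a single compiled loop (an XLA while loop), so the program stays the same size no matter how long the sequence is, and all shapes remain static.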
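To see the difference, this sketch counts how many times the loop body is traced under each approach. The step function and the body_traces dictionary are illustrative; the counters increment only during tracing.

```python
import jax
import jax.numpy as jnp

body_traces = {"python_loop": 0, "scan": 0}

def step(h, x):
    h_new = jnp.tanh(h + x)
    return h_new, h_new

def unrolled(xs):
    h = jnp.zeros(xs.shape[1])
    for t in range(xs.shape[0]):  # static trip count: unrolled during tracing
        body_traces["python_loop"] += 1
        h, _ = step(h, xs[t])
    return h

def scanned(xs):
    def body(h, x):
        body_traces["scan"] += 1  # counts how often the body is traced
        return step(h, x)
    h, _ = jax.lax.scan(body, jnp.zeros(xs.shape[1]), xs)
    return h

xs = jnp.ones((50, 4))
h_unrolled = jax.jit(unrolled)(xs)
h_scanned = jax.jit(scanned)(xs)
print(body_traces)  # {'python_loop': 50, 'scan': 1}
```

With 50 timesteps, the Python loop places 50 copies of step into the compiled program, while scan traces it once; both produce the same final state.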

The Common Mistake: Attempting Dynamic Indexing

Let's look at a piece of code that seems logical from a standard Python perspective but is destined to fail in JAX.


import jax
import jax.numpy as jnp

# A naive attempt to implement an RNN loop with plain Python control flow
def naive_rnn_scan(params, inputs):
  # Under jax.jit, params['h_dim'] is a tracer, so it cannot be used as a shape
  h = jnp.zeros(params['h_dim'])
  outputs = []
  # This Python for loop is unrolled at trace time: one copy of the body per step
  for i in range(inputs.shape[0]):
    # i is a concrete Python int here, because inputs.shape[0] is static
    x_t = inputs[i]
    h = jnp.tanh(jnp.dot(x_t, params['W_x']) + jnp.dot(h, params['W_h']))
    outputs.append(h)
  return jnp.stack(outputs)

# This fails if you try to jit it:
# jax.jit(naive_rnn_scan)(params, inputs)
# >>> TypeError: the traced params['h_dim'] cannot be used as an array shape

Why This Code Fails

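When jax.jit traces this function, the first failure happens before the loop even starts: params is a jit argument, so params['h_dim'] is a tracer rather than a Python int, and jnp.zeros cannot create an array whose shape is unknown at trace time. The loop itself is more subtle. Because inputs.shape[0] is static, range(inputs.shape[0]) yields ordinary Python integers, so inputs[i] is a static slice, and appending tracers to a Python list is perfectly legal. JAX simply unrolls the loop, inserting one copy of the RNN step into the graph for every timestep. If you fix the h_dim problem (for example, by passing the hidden size as a static argument), the function compiles, but compile time and program size grow linearly with sequence length, and the loop breaks outright as soon as its trip count depends on a traced value, such as range(length) with a traced length. The core issue is using Python control flow, which runs only once at trace time, to express a loop that should live inside the compiled program.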
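The trip-count failure is easy to reproduce. In this sketch, sum_first is a hypothetical helper; under jax.jit its length argument becomes a tracer, and range() cannot turn a tracer into a Python integer. Marking the argument static makes the loop legal again, but JAX then compiles a separate unrolled program for each distinct length.

```python
import jax
import jax.numpy as jnp

def sum_first(xs, length):
    total = jnp.zeros(xs.shape[1])
    for t in range(length):  # range() needs a concrete Python int
        total = total + xs[t]
    return total

xs = jnp.ones((6, 3))
print(sum_first(xs, 4))  # eager: [4. 4. 4.]

try:
    jax.jit(sum_first)(xs, 4)  # under jit, length is a tracer
    traced_failed = False
except TypeError as err:  # JAX's tracer-conversion errors subclass TypeError
    traced_failed = True
    print("jit failed with", type(err).__name__)

# Marking length as static works, but compiles one unrolled program per length
print(jax.jit(sum_first, static_argnums=1)(xs, 4))  # [4. 4. 4.]
```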

The Solution: Embracing the lax.scan Pattern

The correct way to implement loops in JAX is to use lax.scan. It forces you into a more functional programming style that is perfectly compatible with JIT compilation.

Understanding the Scan Function

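The core signature of lax.scan is: jax.lax.scan(f, init, xs). It also accepts optional arguments such as length, reverse, and unroll, which this guide does not need.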

  • f: The function to be applied at each step. It must have the signature (carry, x) -> (new_carry, y), and new_carry must have the same structure, shapes, and dtypes as init.
  • init: The initial value of the carry. For an RNN, this is your initial hidden state, h_0.
  • xs: The sequence to scan over, with time (the sequence dimension) as the first axis. JAX slices xs along that axis and passes each slice as x to your function f. xs can also be a tuple of arrays that share the same leading dimension, in which case x is the matching tuple of slices.

The carry is what threads state from one iteration to the next. lax.scan returns a tuple (final_carry, ys), where final_carry is the carry after the last step (for an RNN, the last hidden state) and ys stacks every per-step output y along a new leading axis.
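One way to pin down these semantics is a plain-Python reference loop, compared against the real lax.scan on a running sum. scan_reference is a hypothetical helper that handles only a single array xs and none of the optional arguments; it is meant for reading, not for use inside jit.

```python
import jax
import jax.numpy as jnp
import numpy as np

def scan_reference(f, init, xs):
    # Plain-Python semantics of lax.scan for a single array xs
    # (ignores pytrees and the length/reverse/unroll options).
    carry = init
    ys = []
    for x in xs:  # iterate over the leading axis
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def cumsum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (new_carry, y): y is the running total

xs = jnp.arange(1.0, 6.0)  # [1, 2, 3, 4, 5]
final, ys = jax.lax.scan(cumsum_step, jnp.float32(0.0), xs)
ref_final, ref_ys = scan_reference(cumsum_step, jnp.float32(0.0), xs)
print(final, ys)  # final = 15.0, ys = [1, 3, 6, 10, 15]
```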

A Correct RNN Cell with `lax.scan`

Let's rewrite our RNN using this pattern. Notice how we define a `scan_fn` that perfectly matches the required (carry, x) signature.


import jax
import jax.numpy as jnp
from functools import partial

# Define the function for a single step of the RNN.
# Once params is bound, it has the scan signature: (carry, x) -> (new_carry, y)
def rnn_step(params, carry, x_t):
  # carry is the hidden state from the previous step, h_{t-1}
  h_prev = carry
  # x_t is the input for the current step, shape (input_dim,)

  # The core RNN computation: h_t = tanh(x_t W_x + h_{t-1} W_h + b)
  h_t = jnp.tanh(jnp.dot(x_t, params['W_x']) + jnp.dot(h_prev, params['W_h']) + params['b'])

  # The new carry is the new hidden state, h_t
  # The output y for this step is also h_t
  return h_t, h_t

# The main function that uses lax.scan.
# hidden_dim is static because it determines an array shape.
@partial(jax.jit, static_argnames=['hidden_dim'])
def jax_rnn_scan(params, inputs, hidden_dim):
  # 1. Define the initial hidden state (the initial carry)
  h0 = jnp.zeros(hidden_dim)

  # 2. Bind params to the step function. params is still a traced jit argument;
  #    partial just closes over it so that scan sees a (carry, x) function.
  scan_fn = partial(rnn_step, params)

  # 3. Run lax.scan over the leading (time) axis of inputs
  # init = h0 (initial carry)
  # xs = inputs (the sequence to iterate over)
  final_hidden_state, all_hidden_states = jax.lax.scan(scan_fn, h0, inputs)

  return final_hidden_state, all_hidden_states

# Example Usage:
key = jax.random.PRNGKey(0)
key_x, key_h = jax.random.split(key)  # independent keys for each weight matrix
seq_len, input_dim, hidden_dim = 10, 5, 8

# Initialize parameters
params = {
    'W_x': jax.random.normal(key_x, (input_dim, hidden_dim)),
    'W_h': jax.random.normal(key_h, (hidden_dim, hidden_dim)),
    'b': jnp.zeros(hidden_dim)
}
# Dummy input data, shape (seq_len, input_dim)
dummy_inputs = jnp.ones((seq_len, input_dim))

# Run the JIT-compiled function
final_h, all_h = jax_rnn_scan(params, dummy_inputs, hidden_dim)

print("Shape of all hidden states:", all_h.shape)
# >>> Shape of all hidden states: (10, 8)

This code is fully JIT-compatible. lax.scan traces rnn_step once and compiles the recurrence into a single loop, so compile time does not grow with sequence length. No Python-side loop, no Python lists, and no shapes that depend on traced values.
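A quick way to gain confidence in the scan version is to compare it against an eager Python loop, which is perfectly fine outside jit for small checks. This sketch restates the step function so it runs standalone; the 0.1 weight scaling and the helper names rnn_python_loop and rnn_scan are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def rnn_step(params, h_prev, x_t):
    h_t = jnp.tanh(x_t @ params["W_x"] + h_prev @ params["W_h"] + params["b"])
    return h_t, h_t

def rnn_python_loop(params, inputs, h0):
    # Eager reference: a plain Python loop, run outside jit
    h, hs = h0, []
    for x_t in inputs:
        h, y = rnn_step(params, h, x_t)
        hs.append(y)
    return h, jnp.stack(hs)

@jax.jit
def rnn_scan(params, inputs, h0):
    return jax.lax.scan(lambda h, x: rnn_step(params, h, x), h0, inputs)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
input_dim, hidden_dim, seq_len = 5, 8, 10
params = {
    "W_x": 0.1 * jax.random.normal(k1, (input_dim, hidden_dim)),
    "W_h": 0.1 * jax.random.normal(k2, (hidden_dim, hidden_dim)),
    "b": jnp.zeros(hidden_dim),
}
inputs = jax.random.normal(k3, (seq_len, input_dim))
h0 = jnp.zeros(hidden_dim)

ref_h, ref_hs = rnn_python_loop(params, inputs, h0)
scan_h, scan_hs = rnn_scan(params, inputs, h0)
print("scan matches loop:", bool(jnp.allclose(ref_hs, scan_hs, atol=1e-5)))
```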

Approach Comparison: Python Loop vs. `lax.scan`

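Comparison of RNN Implementation Approaches

| Feature | Naive Python for loop | JAX lax.scan |
| --- | --- | --- |
| JIT compatibility | Only when the trip count and all shapes are static, in which case the loop is unrolled during tracing. Fails (e.g., with TracerIntegerConversionError) when the loop bound is a traced value. | Yes. Designed for use inside jit; the step function is traced once. |
| Performance | Without jit, slow: each JAX operation is dispatched separately from the Python interpreter. Under jit, compile time and program size grow linearly with sequence length. | Fast. The whole loop compiles into a single XLA while loop, so compile time does not depend on sequence length and no Python runs between steps. |
| Code structure | Imperative and stateful (e.g., h = new_h). Familiar to Python programmers. | Functional. Requires a pure step function (carry, x) -> (new_carry, y). |
| Parallelism | Timesteps run sequentially; each operation can still run on the accelerator. | Timesteps also run sequentially, since each step depends on the previous carry; XLA parallelizes the work within each step (e.g., across batch and hidden dimensions). |
| Primary use case | Debugging, or simple scripts outside jit. | Efficient, compilable implementations of sequential processes (RNNs, optimizer loops, etc.). |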

Advanced Technique: Handling Variable-Length Sequences with Masking

A common challenge in NLP and time-series analysis is that sequences have different lengths. However, lax.scan, like most JAX operations, requires inputs with static, fixed shapes. How do we reconcile this?

The standard solution is padding and masking.

  1. Pad: Pad all sequences in a batch to the length of the longest sequence. The padded values are typically zeros.
  2. Mask: Create a boolean mask for each sequence, with True (or 1) for real data points and False (or 0) for padded points.
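  3. Apply Mask: Pass the mask into your lax.scan loop as part of the `xs`, alongside the padded inputs. Both must be time-major, with shapes (max_len, batch, ...) and (max_len, batch), so that scan iterates over time rather than over the batch. Inside your step function, use the mask to conditionally update the hidden state and the output. This ensures that padded steps do not affect the final result.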
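Steps 1 and 2 can be sketched as a small host-side helper. pad_and_mask is a hypothetical function, not a JAX API; it runs in NumPy before any jit-compiled code and returns time-major arrays, which is the layout lax.scan needs because it iterates over the leading axis.

```python
import numpy as np
import jax.numpy as jnp

def pad_and_mask(sequences):
    """Pad a list of (length_i, feature_dim) arrays into time-major form.

    Returns padded inputs of shape (max_len, batch, feature_dim) and a
    boolean mask of shape (max_len, batch).
    """
    lengths = np.array([len(s) for s in sequences])
    max_len, feature_dim = int(lengths.max()), sequences[0].shape[1]
    padded = np.zeros((max_len, len(sequences), feature_dim), dtype=np.float32)
    for b, seq in enumerate(sequences):
        padded[: len(seq), b] = seq  # zeros remain in the padded tail
    # mask[t, b] is True when step t is real data for sequence b
    mask = np.arange(max_len)[:, None] < lengths[None, :]
    return jnp.asarray(padded), jnp.asarray(mask)

seqs = [np.ones((3, 2)), np.ones((5, 2))]
padded, mask = pad_and_mask(seqs)
print(padded.shape, mask.shape)  # (5, 2, 2) (5, 2)
print(mask[:, 0])  # [ True  True  True False False]
```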

Here's how you can modify the step function to use a mask:


import jax
import jax.numpy as jnp
from functools import partial

# Assumes a batched, time-major layout:
#   carry (h_prev): (batch, hidden_dim)
#   x_t:            (batch, input_dim)
#   mask:           (batch,) booleans for this timestep
def rnn_step_masked(params, carry, x):
  # Unpack the input and the mask for this timestep
  x_t, mask = x
  h_prev = carry

  # Calculate the candidate new hidden state
  h_t_candidate = jnp.tanh(jnp.dot(x_t, params['W_x']) + jnp.dot(h_prev, params['W_h']) + params['b'])

  # Use jnp.where to conditionally update the state.
  # Where mask is True, take the new state; where it is False (padding), keep the old state.
  # mask[:, None] has shape (batch, 1), so it broadcasts across the hidden dimension.
  h_t = jnp.where(mask[:, None], h_t_candidate, h_prev)

  # Zero the output at padded steps so they do not contribute to the loss
  y_t = h_t * mask[:, None]

  return h_t, y_t

# In your main function, pack time-major inputs and masks together:
# padded_inputs: (max_len, batch, input_dim); true_lengths: (batch,)
# masks = jnp.arange(max_len)[:, None] < true_lengths[None, :]   # (max_len, batch)
# packed_xs = (padded_inputs, masks)
# scan_fn = partial(rnn_step_masked, params)
# h0 = jnp.zeros((batch, hidden_dim))
# final_h, all_h = jax.lax.scan(scan_fn, h0, packed_xs)

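This is the idiomatic way to handle variable-length data in JAX: every batch keeps a fixed shape, so lax.scan compiles once per padded length rather than once per sequence. The trade-off is that padded steps are still computed and then discarded.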
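To check that masking really freezes the state of a shorter sequence, this sketch pads a length-3 sequence to length 5 inside a batch of two and compares its final hidden state with a run on the unpadded sequence. The sizes, the 0.5 weight scaling, and the helper name masked_rnn are assumptions for the example.

```python
from functools import partial

import jax
import jax.numpy as jnp

def rnn_step_masked(params, h_prev, x):
    # h_prev: (batch, hidden); x_t: (batch, input); mask: (batch,) booleans
    x_t, mask = x
    h_cand = jnp.tanh(x_t @ params["W_x"] + h_prev @ params["W_h"] + params["b"])
    h_t = jnp.where(mask[:, None], h_cand, h_prev)  # freeze state on padding
    return h_t, h_t * mask[:, None]                 # zero outputs on padding

@jax.jit
def masked_rnn(params, padded_inputs, masks, h0):
    # padded_inputs: (max_len, batch, input_dim), masks: (max_len, batch)
    return jax.lax.scan(partial(rnn_step_masked, params), h0, (padded_inputs, masks))

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
input_dim, hidden_dim, max_len = 3, 4, 5
params = {
    "W_x": 0.5 * jax.random.normal(k1, (input_dim, hidden_dim)),
    "W_h": 0.5 * jax.random.normal(k2, (hidden_dim, hidden_dim)),
    "b": jnp.zeros(hidden_dim),
}
seq_a = jax.random.normal(k3, (3, input_dim))        # length 3
seq_b = jax.random.normal(k4, (max_len, input_dim))  # length 5

# Time-major batch of 2: seq_a is zero-padded from step 3 onward
padded = jnp.zeros((max_len, 2, input_dim)).at[:3, 0].set(seq_a).at[:, 1].set(seq_b)
lengths = jnp.array([3, max_len])
masks = jnp.arange(max_len)[:, None] < lengths[None, :]  # (max_len, batch)

final_h, all_y = masked_rnn(params, padded, masks, jnp.zeros((2, hidden_dim)))

# Reference: run seq_a alone (batch of 1, no padding) with an all-True mask
ref_h, _ = masked_rnn(params, seq_a[:, None, :], jnp.ones((3, 1), dtype=bool),
                      jnp.zeros((1, hidden_dim)))
print("padding left seq_a unchanged:", bool(jnp.allclose(final_h[0], ref_h[0], atol=1e-5)))
```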

Conclusion: Thinking Functionally with JAX

The "no dynamic indexing" rule in JAX is really a static-shape rule, and it isn't a bug; it's a design choice that enables high performance. By steering you toward constructs like lax.scan, JAX encourages you to write pure, functional code that can be easily analyzed, transformed, and optimized. While the initial learning curve for the (carry, x) pattern can be steep, mastering it unlocks the full potential of JAX for sequential computations such as RNNs, LSTMs, and autoregressive decoding loops. The key is to shift your thinking from imperative loops to functional scans over data.