A friendly introduction to scan with JAX

Nelson Tang

Learning scan with jax and numpyro

I'm learning scan because it's all over the numpyro time series examples, so if I want to implement time series models in numpyro from scratch, I'll need to understand it better.

In the context of time series models, scan is helpful because it allows us to iterate over sequences (without a cumbersome for loop) while maintaining state across each iteration.

Elements of a scan

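First, let's break down an example scan call into its components: the function to apply at each step, the initial value of the carried state, and the sequence xs to iterate over.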


Trivial Example - Cumulative Sum

We'll start with the cumsum example from The Differential Programming Tutorial - Loopy Carry, by Eric Ma. We can implement a cumsum with a for loop like:

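from jax import random
import jax.numpy as jnp

# xs: the sequence to accumulate over
xs = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])

result = []
cumsum = 0
for el in xs:
    cumsum += el
    result.append(cumsum)  # record the running total at each step
jnp.array(result)
Array([ 1,  3,  6, 11, 18, 29, 42, 59], dtype=int32)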

Now, instead of a loop, we rewrite it with jax.lax.scan:

from jax import lax

def cumsum(cumsum, el):
    """
    cumsum: The result from the previous loop, aka the 'carried over' result
    el: The current element in the sequence in `xs`
    """
    cumsum += el
    return cumsum, cumsum # ('carryover', 'accumulated')

init_sum = 0
final, result = lax.scan(cumsum, init_sum, xs)
print(result)
[ 1  3  6 11 18 29 42 59]

We see that lax.scan returned two things:

  • The first item is the final value of the carried value cumsum.
  • The second item is an array holding the value of cumsum from every iteration, in order.
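As a quick sanity check on those two return values, here is a minimal sketch that compares them against jnp.cumsum. The name `step` is just an illustrative rename of the scanned function above, and the initial carry is built with the same dtype as `xs` so the carry type stays the same on every iteration.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])

def step(carry, el):
    # Same logic as the cumsum function above, with a neutral name
    carry = carry + el
    return carry, carry

final, result = lax.scan(step, jnp.zeros((), xs.dtype), xs)

# The final carry should be the last entry of the stacked outputs
print(final == result[-1])
# The stacked outputs should match the built-in cumulative sum
print(jnp.array_equal(result, jnp.cumsum(xs)))
```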

To make it a little more concrete, here's a GIF that steps through each iteration:


See that the scanned function returns two items: the first goes back into the loop while the second is appended to an array that acts as a 'running tally' of values from each iteration. If you don't need that running tally, the function can return None as its second item.
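Here is a small sketch of the None case, assuming we only care about the final total. The function name `total_only` is made up for this illustration.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1, 2, 3, 5, 7, 11, 13, 17])

def total_only(carry, el):
    # Returning None as the second item means nothing is collected per step
    return carry + el, None

final, ys = lax.scan(total_only, jnp.zeros((), xs.dtype), xs)
print(final)  # the final carry, i.e. the sum of xs
print(ys)     # None, since there were no per-iteration outputs
```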

Trivial Example 2: A little more of the scan function with compounding interest

The first two parameters in the input function can have arbitrary names but their order matters. The first item is the 'carried over value', aka the carry. The second item is the current element in the xs input array.

Similarly the input function must have two outputs. The first output becomes the new 'carried over value' that will feed back into the input function, while the second output gets appended to an output jax array.
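To make that contract concrete, here is a rough pure-Python sketch of what scan computes. It is only a mental model: the real lax.scan stacks the outputs into a jax array and compiles the loop rather than running it step by step in Python. The names `scan_reference` and `running_sum` are invented for this sketch.

```python
def scan_reference(f, init, xs):
    """Pure-Python mental model of scan (no JAX, no compilation)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # first output feeds the next iteration
        ys.append(y)            # second output is collected
    return carry, ys

def running_sum(carry, el):
    # Argument order matters: carry first, current element second
    carry = carry + el
    return carry, carry

final, ys = scan_reference(running_sum, 0, [1, 2, 3, 5, 7, 11, 13, 17])
print(final)  # 59
print(ys)     # [1, 3, 6, 11, 18, 29, 42, 59]
```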

Let's have a concrete example in which we calculate compounded interest for a $100 investment with a 4% interest rate:

# Define the input function to apply at each time step
def interest_growth(value, x):
    # `x` (the current time step) is unused; only the carried value matters
    value = value * 1.04
    return value, value # ('carried over value', 'accumulated array')

# Define the initial state and input sequence
init_state = 100.0
T = jnp.arange(1, 30)  # 29 time steps

# Apply the scan function
final_state, results = lax.scan(f=interest_growth, init=init_state, xs=T)
print(results)

[104.       108.159996 112.48639  116.98584  121.66527  126.531876
 131.59314  136.85686  142.33113  148.02437  153.94534  160.10315
 166.50726  173.16754  180.09424  187.298    194.78992  202.58151
 210.68477  219.11215  227.87663  236.99168  246.47134  256.3302
 266.5834   277.24673  288.33658  299.87003  311.8648  ]

And we see the effect of the interest compounding over time.
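Since the scanned function just multiplies by 1.04 each step, the result can be checked against the closed form $V_t = 100 \times 1.04^t$. The sketch below repeats the setup so it runs on its own; the loose tolerance is an assumption to allow for float32 rounding.

```python
import jax.numpy as jnp
from jax import lax

def interest_growth(value, x):
    value = value * 1.04
    return value, value

T = jnp.arange(1, 30)
final_state, results = lax.scan(interest_growth, 100.0, T)

# Closed form: the value after t periods is 100 * 1.04**t
closed_form = 100.0 * 1.04 ** T
print(jnp.allclose(results, closed_form, rtol=1e-4))
```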

Trivial Example 3: scan in numpyro to sample from a distribution

What I'm really interested in is learning scan in the context of doing time series in numpyro.

numpyro has its own (experimental) implementation of scan which allows us to iteratively sample from numpyro primitives like numpyro.distributions.

Here's a simple example of a random walk:

$$y_{t+1} = y_t + \epsilon, \quad \epsilon \sim \mathcal{N}(0,1)$$

Or, equivalently:

$$\begin{aligned} y_t &= \mu_t\\ \mu_{t+1} &= \mu_t + \epsilon_\mu, \quad \epsilon_\mu \sim \mathcal{N}(0, 1) \end{aligned}$$

We need to use scan because each subsequent value $\mu_{t+1}$ depends on its previous value $\mu_t$.

To do this within a scan we'll also need to introduce an effect handler: numpyro.handlers.seed and run the scan function within that context.

# Import libraries
import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
import matplotlib.pyplot as plt

# Define our random walk function as an input into `scan`
def random_walk(value, x):
    value = numpyro.sample('x', dist.Normal(value, 1.))
    # Which is the same as if we were very explicit:
    # value = value + numpyro.sample('x', dist.Normal(0, 1.))
    return value, value

init_0 = 0
T = jnp.arange(100)

# Need the `seed` effect handler in a context manager because of
# the need for a PRNG key within `numpyro.sample`
with numpyro.handlers.seed(rng_seed=42):
    final_location, locations = scan(random_walk, init=init_0, xs=T)

plt.plot(T, locations)
plt.title('Simulated Gaussian Random Walk');

Aside: What if I didn't use numpyro.handlers.seed there?

Note the addition of numpyro.handlers.seed. It is needed because numpyro requires an explicit PRNG key whenever you sample from a distribution; the seed effect handler automatically splits the key on each iteration of the scan function.

Otherwise, if you pass one fixed, explicit key inside the scanned function, this is what you'd get:

# What happens if you don't use `numpyro.handlers.seed`?
prng_key = random.PRNGKey(42)

def random_walk_nokey(value, x):
    # Being really explicit here
    value = value + numpyro.sample('x', dist.Normal(0, 1.), rng_key=prng_key)
    return value, value

init_0 = 0
n_steps = 100

state, final_result = scan(random_walk_nokey, init=init_0, xs=jnp.arange(n_steps))

plt.plot(jnp.arange(n_steps), final_result);

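Because the same key is reused on every iteration, every step draws the identical value of ϵ, so the 'random walk' comes out as a straight line. So that's why you need to use the handler!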
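For comparison, here is a sketch of doing the key splitting by hand in plain jax.lax.scan, without numpyro. This is not how numpyro implements its handler, just an illustration of the idea: the key travels in the carry and is split at every step so each iteration gets a fresh subkey. The name `random_walk_manual` is made up for this example.

```python
import jax.numpy as jnp
from jax import lax, random

def random_walk_manual(carry, x):
    value, key = carry
    key, subkey = random.split(key)        # fresh subkey for this step
    value = value + random.normal(subkey)  # epsilon ~ N(0, 1)
    return (value, key), value

# The carry is a tuple: (current location, PRNG key)
init = (jnp.zeros(()), random.PRNGKey(42))
(final_value, _), locations = lax.scan(random_walk_manual, init, jnp.arange(100))

# The step sizes now differ from one iteration to the next
steps = jnp.diff(locations)
print(jnp.std(steps) > 0)
```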