Extending the State Space Approach with Mixed Effects

Thomas P. Vladeck

2020-04-16

See this post first. It outlines the bulk of the derivation of this approach, which is extended here to model many geographies simultaneously.

If you follow Kevin Systrom’s code (this points to a specific commit), it appears he’s using a pretty cool Gaussian process approach to model a changing \(R_t\) 👏

I heard you liked mixed effects models so we put a mixed effects model in your state space model

We are going to leverage the fact that the Kalman Filter was built to handle multidimensional observations. So, instead of modeling just one time series, we’re going to model all the time series at once.

The so-called “observation equation” of the Kalman Filter is

\[\theta_t = Z_t a_t\]

In the previous post, we were modeling \(\theta_t\) as a single value. Here, we’re going to model it as a separate value for every state (along with a nationwide average).

The \(a_t\) are going to be our overall and state-level effects. The trick is going to be in how we set up \(Z_t\) so that we can identify the model.

Let’s jump into the code. First we get our environment set up:

We need snakecase to make the state names easier to work with, and tictoc to time some of our code.

library(tidyverse)
library(KFAS)
library(zoo)
library(snakecase)
library(tictoc)
# NYT cumulative case counts by state
url = 'https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-states.csv'
dat = read_csv(url)

WINDOW = 20             # trailing window (days) over which cases are summed
SERIAL_INTERVAL = 4     # assumed serial interval, in days
GAMMA = 1 / SERIAL_INTERVAL
STATES = dat$state %>% unique
DIM = length(STATES)    # number of geographies = dimension of the state vector

All this does is reconstruct, for each state, the number of cases accumulated over the trailing WINDOW of days. The code is concise to the point of being inscrutable, but it works.

dat_multivar = 
  dat %>% 
  filter(state %in% STATES) %>% 
  select(date, state, cases) %>% 
  spread(state, cases) %>%                                # one column per state
  setNames(to_snake_case(colnames(.))) %>% 
  filter(date > lubridate::ymd("2020-03-01")) %>% 
  mutate_at(vars(-date), ~ ifelse(is.na(.x), 0, .x)) %>%  # treat missing as zero cases
  mutate_at(vars(-date), function(x) {
    diff(x) %>%                   # cumulative counts -> daily new cases
      {. + 1} %>%                 # add one per day to keep counts positive
      {c(rep(0, WINDOW), .)} %>%  # left-pad so the rolling sum keeps its length
      rollsum(., WINDOW)          # sums over WINDOW consecutive days
  }) %>% 
  .[-1, ]                         # drop the first row, which is all padding
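To see what that transform is doing, here’s a minimal sketch on a made-up vector of cumulative counts (x and W here are purely illustrative; zoo is already loaded):

x = c(0, 0, 3, 7, 7, 12, 20)     # toy cumulative case counts
W = 3                            # a small window, for illustration
new_cases = diff(x) + 1          # c(1, 4, 5, 1, 6, 9)
padded = c(rep(0, W), new_cases) # length(x) - 1 + W elements
rollsum(padded, W)               # c(0, 1, 5, 10, 10, 12, 16): same length as x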

As before, itp1 is \(I_{t + 1}\) and it is \(I_t\). The only difference is that these are now matrices, not vectors.

itp1 = as.matrix(dat_multivar[-1, 2:ncol(dat_multivar)])
it = as.matrix(dat_multivar[-nrow(dat_multivar), 2:ncol(dat_multivar)])

And here is where we construct Z:

# sum-to-zero ("contr.sum") contrasts: the intercept is the overall
# average and the state effects are constrained to sum to zero
observation_matrix = model.matrix(
  ~ 1 + f,
  data = data.frame(f = factor(1:DIM)),
  contrasts.arg = list(f = "contr.sum")
)

Zooming in, this looks like:

observation_matrix[c(1, DIM-1, DIM), c(1,2, DIM)]
##    (Intercept) f1 f55
## 1            1  1   0
## 55           1  0   1
## 56           1 -1  -1

This would correspond to the following:

\[\begin{pmatrix}\theta_1 \\ \theta_2 \\ \theta_3 \end{pmatrix} = \begin{pmatrix}1 & 1 & 0 \\ 1 & 0 & 1 \\ 1 & -1 & -1 \end{pmatrix} * \begin{pmatrix}a_1 \\ a_2 \\ a_3 \end{pmatrix}\]

But it’s easier to think of this as

\[\begin{pmatrix}New\ York \\ Washington \\ Louisiana \end{pmatrix} = \begin{pmatrix}1 & 1 & 0 \\ 1 & 0 & 1 \\ 1 & -1 & -1 \end{pmatrix} * \begin{pmatrix}overall\ average \\ state\ effect\ 1 \\ state\ effect\ 2 \end{pmatrix}\]
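Why does this identify the model? The contrasts force the state effects to sum to zero, so adding the three rows gives \(\theta_1 + \theta_2 + \theta_3 = 3 a_1\): the intercept is pinned down as the plain average of the state-level \(\theta\)’s,

\[a_1 = \tfrac{1}{3}\,(\theta_1 + \theta_2 + \theta_3),\]

and each state effect is that state’s deviation from it.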

Now, we’re ready to build the state space model.

This seems really complex, and… it is. These models have a lot of terms, and usually most of them are not used. A few bear noting here. Z is our observation matrix, which we defined above. P1 and P1inf have to do with your initial uncertainty; I followed the procedure in section 6.4 of this vignette. The “fixed effects” (just the intercept) get a diffuse initialization, whereas the “random effects” (the state-level effects) get an exact initialization. T, R, and n you can safely ignore. Q corresponds to the “process variance” of our intercept and state-level effects; we model it explicitly below.

mod_multivar = SSModel(
  itp1 ~ -1 + SSMcustom(
    Z = observation_matrix,            # maps (intercept, state effects) to each state's theta
    T = diag(DIM),                     # random-walk transitions
    R = diag(DIM),                     # disturbance selection
    a1 = rep(0, DIM),
    P1 = diag(c(0, rep(1, DIM-1))),    # exact initialization for the state effects
    P1inf = diag(c(1, rep(0, DIM-1))), # diffuse initialization for the intercept
    Q = diag(DIM),                     # process variances, estimated below
    n = nrow(itp1)
  ),
  u = it,                              # Poisson exposure: the current count I_t
  distribution = "poisson"
)
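One thing worth spelling out: for the Poisson family, KFAS treats u as an exposure, so with u = it the observation model we’ve written down is

\[I_{t+1} \sim \mathrm{Poisson}\left(I_t \, e^{\theta_t}\right),\]

which is the growth model from the previous post with \(\theta_t = \gamma\,(R_t - 1)\).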


More complex models in KFAS require you to specify an “update function”. Basically, the fitSSM call below optimizes pars; on each optimization step, it uses update_fn to update the model and then computes the log likelihood. Here, we are just replacing the diagonal of Q, which holds our process variances for the overall and state-level effects.

update_fn = function(pars, mod) {
  # pars are log process variances; exp() keeps the diagonal of Q positive
  QQ = diag(exp(pars[1:DIM]))
  mod$Q[, , 1] = QQ
  
  mod
}
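To make that concrete, here is roughly what fitSSM is doing under the hood. This is a conceptual sketch, not KFAS’s actual internals, and inits stands in for the vector of starting values:

# what fitSSM(mod_multivar, inits, update_fn, method = "BFGS") roughly does
fit = optim(
  par = inits,                                                # starting values
  fn = function(pars) -logLik(update_fn(pars, mod_multivar)), # negative log likelihood
  method = "BFGS"
)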


We fit the model here. It takes about 13 minutes to fit, so this is just me saving the run and reloading it.

# update_fn only reads DIM parameters: one log-variance per diagonal element of Q
# mod_multivar_fit = fitSSM(mod_multivar, rep(-5, DIM), update_fn, method = "BFGS")
# saveRDS(mod_multivar_fit, file = "mod_multivar_fit.rds")
mod_multivar_fit = readRDS("mod_multivar_fit.rds")
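We loaded tictoc earlier precisely for runs like this; if you want to reproduce the timing yourself, one way is to wrap the fit (left commented out, like the fit itself):

# tic("fitting the multivariate model")
# mod_multivar_fit = fitSSM(mod_multivar, rep(-5, DIM), update_fn, method = "BFGS")
# toc()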


Once the process variances have been estimated, we can run the filter and smoother to estimate the hidden states.

mod_multivar_filtered = KFS(mod_multivar_fit$model, filtering = c("state", "mean"), smoothing = c("state", "mean"))
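For reference, the two pieces of the KFS output used below are the smoothed state means (alphahat) and the state error covariances (P); under KFAS’s conventions their shapes should be:

dim(mod_multivar_filtered$alphahat)  # n x DIM: smoothed state means
dim(mod_multivar_filtered$P)         # DIM x DIM x (n + 1): one covariance slice per timestep, plus one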


Basically what’s happening here is that we’re looping through the timesteps to estimate \(\theta\) for each state. First, though, we add an additional row to our observation matrix (Z_augment) that picks out the intercept alone, which, thanks to the sum-to-zero constraint, is the overall average.

Z_augment = mod_multivar_filtered$model$Z[, ,1] %>% 
  rbind("average" = c(1, rep(0, DIM-1)))

theta = map(
  1:nrow(it), 
  ~ t(Z_augment %*% mod_multivar_filtered$alphahat[.x, ])
) %>% 
  reduce(rbind) %>% 
  as.data.frame 
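As an aside, since alphahat is just an n-by-DIM matrix, the whole map/reduce collapses into a single matrix multiply; this should be numerically identical to the loop above (theta_vectorized is just a hypothetical name for the same object):

theta_vectorized = as.data.frame(mod_multivar_filtered$alphahat %*% t(Z_augment))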


If you thought computing the averages was nasty, let’s talk about the standard errors. For each row \(z\) of Z_augment, the variance of \(z^\top a_t\) is the quadratic form \(z^\top P_t z\), so this function loops through the rows and combines the relevant elements of the error covariance matrix into that series’ variance:

compute_se_from_indices = function(var_index) {
  zz = Z_augment[var_index, ]
  
  # Var(z' a_t) = z' P_t z for each timestep; the full quadratic form
  # (rather than a plain sum over the nonzero block) keeps the signs
  # right for the sum-to-zero row, whose z contains -1s
  sqrt(apply(mod_multivar_filtered$P, 3, function(P_t) drop(t(zz) %*% P_t %*% zz)))
}

ses = 1:(DIM+1) %>% 
  map(~ compute_se_from_indices(.x)) %>% 
  cbind.data.frame %>% 
  setNames(colnames(theta)) %>% 
  .[-1, ]  # P has n + 1 slices; drop the initial one to line up with theta

Finally, we can compute the upper and lower bounds of \(\theta\), solve for \(R_t\), and then map/reduce this into a nice, tidy dataframe of \(R_t\) by state and timestep, with upper and lower bounds.
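This is just the relationship from the previous post, \(\theta_t = \gamma\,(R_t - 1)\), solved for \(R_t\):

\[R_t = \frac{\theta_t}{\gamma} + 1\]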

theta_upper = theta + 1.96 * ses
theta_lower = theta - 1.96 * ses

rt = theta/GAMMA + 1
rt_lower = theta_lower/GAMMA + 1
rt_upper = theta_upper/GAMMA + 1


rts_by_state = 
  list(rt, rt_lower, rt_upper) %>% 
  map2(
    c("mean", "lower", "upper"),
    ~ mutate(.x, date = dat_multivar$date[-1]) %>% 
        gather(-date, key = state, value = !!.y)
  ) %>% 
  reduce(~ left_join(.x, .y, by = c("date", "state")))

And voilà, we can now plot our estimate of \(R_t\), along with the associated uncertainty, for every state, over time:

rts_by_state %>% 
  ggplot() + 
  aes(x = date, y = mean, ymax = upper, ymin = lower) + 
  geom_line(color = "grey") + 
  geom_ribbon(alpha = 0.5) + 
  facet_wrap(~ state) + 
  geom_hline(yintercept = 1) +
  coord_cartesian(ylim = c(-1, 5)) + 
  scale_y_continuous("", labels = NULL)  + 
  scale_x_date("", labels = NULL)

Thoughts, questions?