Overview

This vignette demonstrates how to use the stemr package to simulate from a multi-country SEIR model and fit it to partially observed incidence data. The model incorporates within-country transmission as well as importation of cases from neighboring countries. It was described and used in Fintzi et al. (2020) to analyze national-level incidence counts from the 2014-2015 outbreak of Ebola in West Africa. This vignette instead focuses on fitting the model to simulated data, so that it can also demonstrate the simulation functionality of the package. The real data are included in the package and can be accessed by calling data("ebola") once the package is loaded. The models fit to the West Africa data and to the simulated data are largely the same, differing only slightly in the prior hyperparameters and in the length of the outbreak (70 weeks for the simulated outbreak vs. 73 weeks for the observed one). Hence, the code used to fit the model to simulated data can also be used to analyze the real-world outbreak data.

Model description

We will simulate from and fit a multi-country model for the spread of Ebola in Guinea, Liberia, and Sierra Leone under country-specific SEIR transmission dynamics, illustrated in the figure below. Cross-border transmission was incorporated via virtual migration of infectious individuals and was parameterized by extrinsic reproduction numbers, interpretable as the expected number of secondary cases in a country per index case in another country. Transmission was assumed to commence in Liberia on March 2\(^{nd}\), 2014, and in Sierra Leone on May 4\(^{th}\), 2014, corresponding to three weeks prior to the first cases in those countries. The observed incidence in each country was modeled as a negative binomial sample of the true incidence. The total incidence in each country was small relative to the population size, suggesting that only a fraction of the population was geographically or socially linked to ongoing transmission. Hence, we estimated the effective population size in each country, interpreted as the size of the sub-population within which the outbreak occurred.

Diagram of state transitions for a joint model for Ebola transmission in Guinea, Liberia, and Sierra Leone. Dotted boxes denote countries, and circled nodes denote the model compartments: susceptible but removed from infectious contact \((S^R)\), susceptible but exposed to infectious contact \((S^E)\), exposed \((E)\), infectious \((I)\), and recovered \((R)\). Compartments are subscripted with country indicators. Solid lines with arrows indicate stochastic transitions between model compartments, which occur continuously in time. Dashed lines indicate that infected individuals in one country contribute to the force of infection in another country. Rates at which individuals transition between compartments are denoted by \(\lambda\), subscripted by compartments and superscripted by countries; e.g., \(\lambda_{S^EE}^L\) is the rate at which susceptible individuals become exposed in Liberia. Transmission in Liberia and Sierra Leone was assumed to commence at 10 and 19 weeks, respectively.
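The observation model described above can be sketched in a few lines of base R. This is a minimal illustration rather than stemr code: it assumes the negative binomial is parameterized by its mean, \(\rho_A\) times the true incidence, and its size \(\phi_A\) (the parameterization used by R's rnbinom(size, mu)), so that the variance is \(\mu + \mu^2/\phi_A\). The helper simulate_observed is hypothetical.

```r
# Hypothetical helper (not part of stemr): draw observed weekly case counts
# given the true incidence, assuming
#   observed ~ NegBin(mean = rho * true incidence, size = phi)
simulate_observed <- function(true_incidence, rho, phi) {
  stopifnot(all(true_incidence >= 0), rho > 0, rho <= 1, phi > 0)
  rnbinom(length(true_incidence), size = phi, mu = rho * true_incidence)
}

# Example: 150 true cases per week, detection rate 100/150, overdispersion 100
set.seed(1)
obs <- simulate_observed(rep(150, 1e5), rho = 100 / 150, phi = 100)
mean(obs)  # close to rho * 150 = 100
var(obs)   # close to 100 + 100^2 / 100 = 200
```

Larger values of \(\phi_A\) shrink the extra-Poisson variance, so the observed counts approach a Poisson sample with mean \(\rho_A\) times the true incidence.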

Parameters and their interpretations. Subscripts \(A,B\) indicate countries.

Parameter | Interpretation | Transition
\(\beta_A(t)\) | Per-contact rate of transmission within country \(A\). | \(S^{E}_A \rightarrow E_A\)
\(\alpha_{AB}(t)\) | Per-contact rate of transmission from country \(A\) to \(B\). | \(S^{E}_B \rightarrow E_B\)
\(\omega_A(t)\) | Rate at which latent individuals become infectious. | \(E_A \rightarrow I_A\)
\(\mu_A(t)\) | Rate at which infectious individuals recover. | \(I_A \rightarrow R_A\)
\(P_{eff,A}\) | Effective population size. |
\(\rho_A\) | Mean case detection rate. |
\(\phi_A\) | Negative binomial overdispersion. |

Rates of state transition, where \(A\), \(B\), and \(C\) denote the three countries. Subscripts on rates indicate model compartments and superscripts indicate countries, while subscripts on compartments and parameters indicate countries. All rates of state transition for Liberia and Sierra Leone are zero until three weeks prior to the first detected case, when transmission was assumed to commence in each country.

Rate | Transition
\(\lambda^{A}_{S^EE}(t) = \beta_A(t)\left (I_A + \alpha_{BA}(t)I_B + \alpha_{CA}(t)I_C\right )S^E_A\) | \(S^E_A\rightarrow E_A\)
\(\lambda^{A}_{EI}(t) = \omega_A(t)E_A\) | \(E_A\rightarrow I_A\)
\(\lambda^A_{IR}(t) = \mu_A(t)I_A\) | \(I_A \rightarrow R_A\)
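The rate table can be read as a small function of the current state. The sketch below is a hypothetical helper, seir_rates, that computes the three transition rates for a country \(A\) as written in the table, taking the infectious counts of the other two countries as inputs. It is illustrative only and ignores the time-varying indicators that switch transmission on in Liberia and Sierra Leone.

```r
# Hypothetical helper: transition rates for country A, following the rate table.
# state: named vector c(S, E, I, R) for country A, where S plays the role of S^E_A
# I_B, I_C: infectious counts in the other two countries
# pars: named vector with beta, alpha_BA, alpha_CA, omega, mu for country A
seir_rates <- function(state, I_B, I_C, pars) {
  # force of infection: beta_A * (I_A + alpha_BA * I_B + alpha_CA * I_C)
  foi <- pars[["beta"]] * (state[["I"]] + pars[["alpha_BA"]] * I_B +
                             pars[["alpha_CA"]] * I_C)
  c(
    SE = foi * state[["S"]],              # S^E_A -> E_A
    EI = pars[["omega"]] * state[["E"]],  # E_A -> I_A
    IR = pars[["mu"]] * state[["I"]]      # I_A -> R_A
  )
}

state_A <- c(S = 1000, E = 15, I = 10, R = 5)
pars_A  <- c(beta = 1e-4, alpha_BA = 0.5, alpha_CA = 0.25, omega = 1, mu = 0.9)
rates_A <- seir_rates(state_A, I_B = 20, I_C = 40, pars = pars_A)
rates_A
```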

Simulating from and fitting the model

We initialize the parameters and instantiate the SEIR model in the code block below.

library(stemr)
set.seed(52787)

# total population sizes
popsize_guin <- 11.8e6
popsize_lib <- 4.4e6
popsize_sln <- 7.1e6

log_popsize_guin <- log(popsize_guin)
log_popsize_lib <- log(popsize_lib)
log_popsize_sln <- log(popsize_sln)

# effective population sizes (i.e., number of susceptibles)
ep_guin <- 2e4
ep_lib <- 3.5e4
ep_sln <- 2.5e4

# initialize model dynamics
# each country is a stratum with SEIR compartments
strata <- c("guin", "lib", "sln")

# compartments given in a list, ALL is a reserved keyword indicating that 
# the compartment is present in all strata (countries). alternately, a character
# vector of strata in which each compartment occurs could be supplied
compartments <- list(S = "ALL", E = "ALL", I = "ALL", R = "ALL")

# rates of state transition - see help('rate') for additional details on specifying 
# rates. note that transmission_lib and transmission_sln are time-varying covariates
# that will be declared below.
rates <- 
      list(
            rate(
                  rate = "beta_guin * (I_guin + transmission_lib * alpha_lib2guin * I_lib + transmission_sln * alpha_sln2guin * I_sln) * S_guin",
                  from = "S",
                  to = "E",
                  strata = "guin",
                  lumped = TRUE,
                  incidence = T
            ),
            rate(
                  "transmission_lib * beta_lib * (I_lib + alpha_guin2lib * I_guin + transmission_sln * alpha_sln2lib * I_sln) * S_lib",
                  from = "S",
                  to = "E",
                  strata = "lib",
                  lumped = TRUE,
                  incidence = T
            ),
            rate(
                  "transmission_sln * beta_sln * (I_sln + alpha_guin2sln * I_guin + transmission_lib * alpha_lib2sln * I_lib) * S_sln",
                  from = "S",
                  to = "E",
                  strata = "sln",
                  lumped = TRUE,
                  incidence = T
            ),
            rate(
                  "omega_guin",
                  from = "E",
                  to = "I",
                  strata = "guin",
                  incidence = T
            ),
            rate(
                  "transmission_lib * omega_lib",
                  from = "E",
                  to = "I",
                  strata = "lib",
                  incidence = T
            ),
            rate(
                  "transmission_sln * omega_sln",
                  from = "E",
                  to = "I",
                  strata = "sln",
                  incidence = T
            ),
            rate(
                  "mu_guin",
                  from = "I",
                  to = "R",
                  strata = "guin",
                  incidence = T
            ),
            rate(
                  "transmission_lib * mu_lib",
                  from = "I",
                  to = "R",
                  strata = "lib",
                  incidence = T
            ),
            rate(
                  "transmission_sln * mu_sln",
                  from = "I",
                  to = "R",
                  strata = "sln",
                  incidence = T
            )
      )

# function for initializing the compartment volumes at time 0
# this is a list of stem_initializer lists, one for each stratum
state_initializer <- 
      list(
            stem_initializer(
                  c(
                        S_guin = ep_guin - 30,
                        E_guin = 15,
                        I_guin = 10,
                        R_guin = 5
                  ),
                  fixed = TRUE,
                  strata = "guin",
                  prior = c(popsize_guin - 30, 15, 10, 5)
            ),
            stem_initializer(
                  c(
                        S_lib = ep_lib - 30,
                        E_lib = 15,
                        I_lib = 10,
                        R_lib = 5
                  ),
                  fixed = TRUE,
                  strata = "lib",
                  prior = c(popsize_lib - 30, 15, 10, 5)
            ),
            stem_initializer(
                  c(
                        S_sln = ep_sln - 30,
                        E_sln = 15,
                        I_sln = 10,
                        R_sln = 5
                  ),
                  fixed = TRUE,
                  strata = "sln",
                  prior = c(popsize_sln - 30, 15, 10, 5)
            )
      )

# we declare time 0 as a constant and declare tmax
constants <- c(t0 = 0)
t0 <- 0; tmax <- 70;

# assume there is no transmission in liberia and sierra leone before weeks 10 and 19.
# we encode this using time-varying covariates. 
tcovar <- cbind(time = 0:tmax,
                transmission_lib = c(rep(0, 9), rep(1, (tmax + 1) - 9)),
                transmission_sln = c(rep(0, 18), rep(1, (tmax + 1) - 18)))

# recovery rates
mu_guin <- 0.9
mu_lib <- 1.1
mu_sln <- 1

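# named vector of parameters on their natural scales (i.e., the scale on which they
# enter the rate functions). each beta is set as R0 / (effective population size) * mu,
# so, e.g., the within-country basic reproduction number in guinea is 1.2.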
parameters = c(beta_guin = 1.2 / ep_guin * mu_guin, 
               beta_lib = 1.35 / ep_lib * mu_lib,
               beta_sln = 1.45 / ep_sln * mu_sln,
               alpha_guin2lib = 0.02 / ep_lib * mu_guin,
               alpha_guin2sln = 0.02 / ep_sln * mu_guin,
               alpha_lib2guin = 0.02 / ep_guin * mu_lib,
               alpha_lib2sln  = 0.02 / ep_sln * mu_lib,
               alpha_sln2guin = 0.02 / ep_guin * mu_sln,
               alpha_sln2lib  = 0.02 / ep_lib * mu_sln,
               omega_guin = 1.2,
               omega_lib = 1,
               omega_sln = 0.8,
               mu_guin = mu_guin,
               mu_lib = mu_lib,
               mu_sln = mu_sln,
               rho_guin = 100/150,
               rho_lib = 100/175,
               rho_sln = 100/125,
               phi_guin = 100,
               phi_lib = 100,
               phi_sln = 100)

# save the true parameters for later
true_pars <- parameters

# initialize the model dynamics
dynamics <-
      stem_dynamics(
            rates = rates,
            tmax = tmax,
            timestep = NULL,
            parameters = parameters,
            state_initializer = state_initializer,
            compartments = compartments,
            constants = constants,
            strata = strata,
            tcovar = tcovar,
            messages = F,
            compile_ode = T,
            compile_rates = T,
            compile_lna = T,
            rtol = 1e-5,
            atol = 1e-5,
            step_size = 1e-7
      )

# the next step is to initialize the measurement process, see help("emission") for
# additional details on specifying the emission distributions.
# here we have different case detection rates and overdispersions for each country
emissions <-
      list(
            emission(
                  meas_var = "guin_cases",
                  distribution = "negbinomial",
                  emission_params = c("phi_guin", "E_guin2I_guin * rho_guin"),
                  incidence = TRUE,
                  obstimes = seq(1, tmax, by = 1)
            ),
            emission(
                  meas_var = "lib_cases",
                  distribution = "negbinomial",
                  emission_params = c("phi_lib", "E_lib2I_lib * rho_lib"),
                  incidence = TRUE,
                  obstimes = seq(10, tmax, by = 1)
            ),
            emission(
                  meas_var = "sln_cases",
                  distribution = "negbinomial",
                  emission_params = c("phi_sln", "E_sln2I_sln * rho_sln"),
                  incidence = TRUE,
                  obstimes = seq(19, tmax, by = 1)
            )
      )

# compile the measurement process
measurement_process <-
      stem_measure(emissions = emissions,
                   dynamics = dynamics,
                   messages = F)

# put it all together in a stem object - no data since we haven't simulated it yet
stem_object <- 
  make_stem(dynamics = dynamics, 
            measurement_process = measurement_process)
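The transmission rates in parameters are specified by rescaling a target within-country basic reproduction number. Assuming \(R_{0,A} = \beta_A P_{eff,A} / \mu_A\), which is the quantity the to_est_scale_* functions later in the vignette use when computing \(\log(R_0 - 1)\), setting beta_guin = 1.2 / ep_guin * mu_guin corresponds to \(R_0 = 1.2\) in Guinea. The check below recovers the implied values from copies of the settings above; the *_chk names are illustrative only.

```r
# Copies of the effective population sizes and recovery rates defined above
ep_chk   <- c(guin = 2e4, lib = 3.5e4, sln = 2.5e4)
mu_chk   <- c(guin = 0.9, lib = 1.1, sln = 1)
beta_chk <- c(guin = 1.2, lib = 1.35, sln = 1.45) / ep_chk * mu_chk

# Implied within-country basic reproduction numbers, R0 = beta * N_eff / mu
R0_chk <- beta_chk * ep_chk / mu_chk
R0_chk

# Mean infectious period, 1 / mu, in the model's time unit (weeks here)
1 / mu_chk
```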

After compiling the model, we simulate an outbreak from the Markov jump process (MJP) using Gillespie’s direct algorithm, along with the corresponding observed incidence.


# simulate the outbreak and dataset - defaults to simulation via the MJP
sim <- simulate_stem(stem_object)

# save the true path of the outbreak and the data
true_path   <- sim$paths[[1]]    # true incidence path
dat         <- sim$datasets[[1]] # data
total_cases <- colSums(dat[,-1], na.rm = TRUE) # for priors as in Web Appendix G of Fintzi et al. (2020)
library(ggplot2)
library(cowplot)

sim_dat = 
  data.frame(Country = rep(c("Guinea", "Liberia", "Sierra Leone"), each = tmax),
             Week = seq_len(tmax),
             Type = rep(c("True incidence", "Observed incidence"), each = 3*tmax),
             Count = 
               c(c(true_path[-1,c("E_guin2I_guin","E_lib2I_lib", "E_sln2I_sln")]),
               c(dat[,c("guin_cases", "lib_cases", "sln_cases")])))

sim_dat = subset(sim_dat,
                 !(Country == "Liberia" & Week < 10) & 
                   !(Country == "Sierra Leone" & Week < 19))

ggplot(sim_dat, 
       aes(x = Week, y = Count, shape = Type, colour = Country, fill = Type)) +
  geom_point(size = 2, alpha = 0.6) + 
  scale_shape_manual(breaks = c("True incidence", "Observed incidence"),
                     values = c(1, 16)) + 
  scale_colour_brewer(type = "qual", palette = 6) + 
  theme_minimal() + 
  # theme(legend.position = "bottom") + 
  facet_grid(.~Country)
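Gillespie’s direct algorithm, which simulate_stem used above, repeatedly draws the waiting time to the next event from an exponential distribution whose rate is the sum of all transition rates, then chooses which transition fires with probability proportional to its rate. The sketch below applies this to a single, closed SEIR population in base R. It is a hypothetical helper written for illustration: it ignores importation between countries and time-varying covariates, and it is not stemr’s internal implementation.

```r
# Hypothetical helper: Gillespie direct method for one closed SEIR population.
gillespie_seir <- function(init, beta, omega, mu, tmax) {
  x <- init                                 # named vector c(S, E, I, R)
  t <- 0
  times <- t
  path <- matrix(x, nrow = 1, dimnames = list(NULL, names(x)))
  # state changes (rows) for S -> E, E -> I, and I -> R
  changes <- rbind(c(-1, 1, 0, 0), c(0, -1, 1, 0), c(0, 0, -1, 1))
  repeat {
    rates <- c(beta * x[["I"]] * x[["S"]], omega * x[["E"]], mu * x[["I"]])
    total <- sum(rates)
    if (total == 0) break                   # no further events are possible
    t <- t + rexp(1, rate = total)          # waiting time to the next event
    if (t > tmax) break
    k <- sample.int(3, 1, prob = rates)     # which transition fires
    x <- x + changes[k, ]
    times <- c(times, t)
    path <- rbind(path, x, deparse.level = 0)
  }
  cbind(time = times, path)
}

set.seed(52787)
out <- gillespie_seir(c(S = 990, E = 5, I = 5, R = 0),
                      beta = 1.5 / 1000, omega = 1, mu = 1, tmax = 70)
tail(out, 1)
```

Growing the path with rbind keeps the sketch short but is quadratic in the number of events; a production simulator would preallocate storage or record the path only at observation times.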

Fitting the model in stemr via the LNA or ODE approximations

We’ll need to recompile the model dynamics and measurement process, both to modify the model so that it now estimates the effective population sizes and to let the package know that we’re going to fit the model to data. After that, we’ll specify the priors, along with functions for mapping parameters to and from the MCMC estimation scale. Details of the priors are given in Web Appendix G of Fintzi et al. (2020).
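Before the full code, a simplified sketch may help explain the estimation scale. It reproduces the core of the reparameterization for a single country: \(\beta\) and \(\mu\) are mapped to \(\log(R_0 - 1)\), with \(R_0 = \beta P_{eff}/\mu\), and to the log mean infectious period, and then mapped back. Because \(\log(R_0 - 1)\) is only defined for \(R_0 > 1\), this parameterization assumes a supercritical outbreak in each country. The sketch also checks the Jacobian term used in the priors on the log extrinsic reproduction numbers: a prior placed on \(R_{ext}\) needs \(+\log R_{ext}\) added when it is evaluated on \(\log R_{ext}\). The helpers to_est, from_est, and log_prior_log_rext are hypothetical simplifications, not the vignette’s full transforms, and treat \(P_{eff}\) as known.

```r
# Hypothetical single-country transforms (P_eff treated as fixed)
to_est <- function(beta, mu, N_eff) {
  log_infec_dur <- -log(mu)
  c(log_R0_m1 = log(expm1(log(beta) + log(N_eff) + log_infec_dur)),  # log(R0 - 1)
    log_infec_dur = log_infec_dur)
}

from_est <- function(est, N_eff) {
  log_infec_dur <- est[["log_infec_dur"]]
  c(beta = exp(log1p(exp(est[["log_R0_m1"]])) - log(N_eff) - log_infec_dur),
    mu = exp(-log_infec_dur))
}

# Log-density of x = log(R_ext) when R_ext ~ Exponential(rate = 40);
# the "+ x" term is the log-Jacobian of R_ext = exp(x)
log_prior_log_rext <- function(x) dexp(exp(x), rate = 40, log = TRUE) + x

est  <- to_est(beta = 1.2 / 2e4 * 0.9, mu = 0.9, N_eff = 2e4)
back <- from_est(est, N_eff = 2e4)
back  # recovers beta = 5.4e-05 and mu = 0.9

# The transformed density integrates to (approximately) one
integrate(function(x) exp(log_prior_log_rext(x)), -30, 5)$value
```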

set.seed(12511)

# tweaking the rates to accommodate an adjustment to the effective population 
# size. effpop_guin, effpop_lib, and effpop_sln are the sizes of the susceptible 
# populations in each country that are effectively removed from transmission,
# so, for example, S_guin - effpop_guin is the effective number of 
# susceptibles in Guinea.

rates <-
  list(
    rate(
      rate = "beta_guin * (I_guin + transmission_lib * alpha_lib2guin * I_lib + transmission_sln * alpha_sln2guin * I_sln) * (S_guin - effpop_guin)",
      from = "S",
      to = "E",
      strata = "guin",
      lumped = TRUE,
      incidence = T
    ),
    rate(
      rate = "transmission_lib * beta_lib * (I_lib + alpha_guin2lib * I_guin + transmission_sln * alpha_sln2lib * I_sln) * (S_lib - effpop_lib)",
      from = "S",
      to = "E",
      strata = "lib",
      lumped = TRUE,
      incidence = T
    ),
    rate(
      rate = "transmission_sln * beta_sln * (I_sln + alpha_guin2sln * I_guin + transmission_lib * alpha_lib2sln * I_lib) * (S_sln - effpop_sln)",
      from = "S",
      to = "E",
      strata = "sln",
      lumped = TRUE,
      incidence = T
    ),
    rate(
      rate = "omega_guin",
      from = "E",
      to = "I",
      strata = "guin",
      incidence = T
    ),
    rate(
      rate = "transmission_lib * omega_lib",
      from = "E",
      to = "I",
      strata = "lib",
      incidence = T
    ),
    rate(
      rate = "transmission_sln * omega_sln",
      from = "E",
      to = "I",
      strata = "sln",
      incidence = T
    ),
    rate(
      rate = "mu_guin",
      from = "I",
      to = "R",
      strata = "guin",
      incidence = T
    ),
    rate(
      rate = "transmission_lib * mu_lib",
      from = "I",
      to = "R",
      strata = "lib",
      incidence = T
    ),
    rate(
      rate = "transmission_sln * mu_sln",
      from = "I",
      to = "R",
      strata = "sln",
      incidence = T
    )
  )

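# initial states, now inferred as parameters in the model (fixed = FALSE). S is
# initialized at the total population size since effpop_* is subtracted in the rates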
state_initializer <-
  list(
    stem_initializer(
      c(
        S_guin = popsize_guin - 30,
        E_guin = 15,
        I_guin = 10,
        R_guin = 5
      ),
      fixed = FALSE,
      strata = "guin",
      prior = c(popsize_guin - 30, 15, 10, 5)
    ),
    stem_initializer(
      c(
        S_lib = popsize_lib - 30,
        E_lib = 15,
        I_lib = 10,
        R_lib = 5
      ),
      fixed = FALSE,
      strata = "lib",
      prior = c(popsize_lib - 30, 15, 10, 5)
    ),
    stem_initializer(
      c(
        S_sln = popsize_sln - 30,
        E_sln = 15,
        I_sln = 10,
        R_sln = 5
      ),
      fixed = FALSE,
      strata = "sln",
      prior = c(popsize_sln - 30, 15, 10, 5)
    )
  )

# same assumptions about transmission commencing in liberia and sierra leone
tcovar <- cbind(time = 0:tmax,
                transmission_lib = c(rep(0, 9), rep(1, (tmax+1) - 9)),
                transmission_sln = c(rep(0, 18), rep(1, (tmax+1) - 18)))

# initial parameter values for the MCMC (these differ slightly from the true values)
parameters = c(beta_guin = 1.25 / ep_guin * mu_guin, 
               beta_lib = 1.35 / ep_lib * mu_lib,
               beta_sln = 1.45 / ep_sln * mu_sln,
               alpha_guin2lib = 0.02 / ep_lib * mu_guin,
               alpha_guin2sln = 0.02 / ep_sln * mu_guin,
               alpha_lib2guin = 0.02 / ep_guin * mu_lib,
               alpha_lib2sln  = 0.02 / ep_sln * mu_lib,
               alpha_sln2guin = 0.02 / ep_guin * mu_sln,
               alpha_sln2lib  = 0.02 / ep_lib * mu_sln,
               effpop_guin = popsize_guin - ep_guin,
               effpop_lib = popsize_lib - ep_lib,
               effpop_sln = popsize_sln - ep_sln,
               omega_guin = 1,
               omega_lib = 1,
               omega_sln = 1,
               mu_guin = mu_guin,
               mu_lib = mu_lib,
               mu_sln = mu_sln,
               rho_guin = 100/150,
               rho_lib = 100/175,
               rho_sln = 100/125,
               phi_guin = 50,
               phi_lib = 50,
               phi_sln = 50)

# declare t0 as a constant
constants <- c(t0 = 0)
t0 <- 0; tmax <- nrow(dat);

# compile the dynamics
dynamics <-
      stem_dynamics(
            rates = rates,
            tmax = tmax,
            timestep = NULL,
            parameters = parameters,
            state_initializer = state_initializer,
            compartments = compartments,
            constants = constants,
            strata = strata,
            tcovar = tcovar,
            messages = F,
            compile_ode = T,
            compile_rates = T,
            compile_lna = T,
            rtol = 1e-5,
            atol = 1e-5,
            step_size = 1e-7
      )

# emission distributions, identical to those used for the simulation
emissions <-
  list(
    emission(
      meas_var = "guin_cases",
      distribution = "negbinomial",
      emission_params = c("phi_guin", "E_guin2I_guin * rho_guin"),
      incidence = TRUE,
      obstimes = seq(1, tmax, by = 1)
    ),
    emission(
      meas_var = "lib_cases",
      distribution = "negbinomial",
      emission_params = c("phi_lib", "E_lib2I_lib * rho_lib"),
      incidence = TRUE,
      obstimes = seq(10, tmax, by = 1)
    ),
    emission(
      meas_var = "sln_cases",
      distribution = "negbinomial",
      emission_params = c("phi_sln", "E_sln2I_sln * rho_sln"),
      incidence = TRUE,
      obstimes = seq(19, tmax, by = 1)
    )
  )

# compile the measurement process, now with the simulated data, and rebuild the stem object
measurement_process <-
  stem_measure(
    data = dat,
    emissions = emissions,
    dynamics = dynamics,
    messages = F
  )

stem_object <- make_stem(dynamics = dynamics, 
                    measurement_process = measurement_process)

#### initialize the inference
popsizes <- c(popsize_guin, popsize_lib, popsize_sln)

# to and from the MCMC estimation scale
to_est_scale_guin <- function(params_nat) {
  
  l_effpop     <- log(popsize_guin - params_nat[["effpop_guin"]])
  l_Neff_x_rho <- l_effpop + log(params_nat[["rho_guin"]])
  l_infec_dur  <- -log(params_nat[["mu_guin"]])
  l_Reff_m1    <- log(expm1(log(params_nat[["beta_guin"]]) + 
                              l_effpop + l_infec_dur))
  
  return(c(
    log_Reff_guin_o = l_Reff_m1 + l_Neff_x_rho,
    log_Rext_g2l = l_effpop + l_infec_dur + log(params_nat[["alpha_guin2lib"]]),
    log_Rext_g2s = l_effpop + l_infec_dur + log(params_nat[["alpha_guin2sln"]]),
    log_Neff_x_rho_guin = l_Neff_x_rho,
    log_omega_d_mu_guin = log(params_nat[["omega_guin"]]) + l_infec_dur,
    log_infec_dur_guin = l_infec_dur,
    logit_rho_guin = qlogis(params_nat[["rho_guin"]]),
    log_sqrt_phi_inv_guin = -0.5 * log(params_nat[["phi_guin"]])
  ))
}

from_est_scale_guin <- function(params_est) {
  
  rho       <- plogis(params_est[["logit_rho_guin"]])
  l_effpop  <- params_est[["log_Neff_x_rho_guin"]] - log(rho)
  l_Reff_m1 <- params_est[["log_Reff_guin_o"]] - params_est[["log_Neff_x_rho_guin"]]
  
  return(c(
    beta_guin = exp(log1p(exp(l_Reff_m1)) - l_effpop -
                      params_est[["log_infec_dur_guin"]]),
    alpha_guin2lib = 
      exp(params_est[["log_Rext_g2l"]] - 
            l_effpop - params_est[["log_infec_dur_guin"]]),
    alpha_guin2sln = 
      exp(params_est[["log_Rext_g2s"]] - 
            l_effpop - params_est[["log_infec_dur_guin"]]),
    effpop_guin = popsize_guin - exp(l_effpop),
    omega_guin = 
      exp(params_est[["log_omega_d_mu_guin"]] - params_est[["log_infec_dur_guin"]]),
    mu_guin = 
      exp(-params_est[["log_infec_dur_guin"]]),
    rho_guin = 
      plogis(params_est[["logit_rho_guin"]]),
    phi_guin = 
      exp(-2 * params_est[["log_sqrt_phi_inv_guin"]])
  ))
}

logprior_guin <- function(params_est) {
  
  rho      <- plogis(params_est[["logit_rho_guin"]])
  l_effpop <- params_est[["log_Neff_x_rho_guin"]] - log(rho)
  
  sum(dnorm(params_est[["log_Reff_guin_o"]] - params_est[["log_Neff_x_rho_guin"]], 
            log(0.5), 1.08, log = TRUE),
      dexp(exp(params_est[["log_Rext_g2l"]]), 40, log = TRUE) +
        params_est[["log_Rext_g2l"]], 
      dexp(exp(params_est[["log_Rext_g2s"]]), 40, log = TRUE) +
        params_est[["log_Rext_g2s"]],
      dnorm(l_effpop, 9.8, 0.62, log = TRUE),
      dnorm(params_est[["log_omega_d_mu_guin"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["log_infec_dur_guin"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["logit_rho_guin"]], 0.85, 0.75, log = TRUE), 
      dexp(exp(params_est[["log_sqrt_phi_inv_guin"]]), 0.69, log = TRUE) + 
        params_est[["log_sqrt_phi_inv_guin"]]
      )
}

to_est_scale_lib <- function(params_nat) {
  
  l_effpop     <- log(popsize_lib - params_nat[["effpop_lib"]])
  l_Neff_x_rho <- l_effpop + log(params_nat[["rho_lib"]])
  l_infec_dur  <- -log(params_nat[["mu_lib"]])
  l_Reff_m1    <- log(expm1(log(params_nat[["beta_lib"]]) + 
                              l_effpop + l_infec_dur))
  
  return(c(
    log_Reff_lib_o = l_Reff_m1 + l_Neff_x_rho,
    log_Rext_l2g = l_effpop + l_infec_dur + log(params_nat[["alpha_lib2guin"]]),
    log_Rext_l2s = l_effpop + l_infec_dur + log(params_nat[["alpha_lib2sln"]]),
    log_Neff_x_rho_lib = l_Neff_x_rho,
    log_omega_d_mu_lib = log(params_nat[["omega_lib"]]) + l_infec_dur,
    log_infec_dur_lib = l_infec_dur,
    logit_rho_lib = qlogis(params_nat[["rho_lib"]]),
    log_sqrt_phi_inv_lib = -0.5 * log(params_nat[["phi_lib"]])
  ))
}

from_est_scale_lib <- function(params_est) {
  
  rho       <- plogis(params_est[["logit_rho_lib"]])
  l_effpop  <- params_est[["log_Neff_x_rho_lib"]] - log(rho)
  l_Reff_m1 <- params_est[["log_Reff_lib_o"]] - params_est[["log_Neff_x_rho_lib"]]
  
  return(c(
    beta_lib = exp(log1p(exp(l_Reff_m1)) - l_effpop -
                      params_est[["log_infec_dur_lib"]]),
    alpha_lib2guin = 
      exp(params_est[["log_Rext_l2g"]] - 
            l_effpop - params_est[["log_infec_dur_lib"]]),
    alpha_lib2sln = 
      exp(params_est[["log_Rext_l2s"]] - 
            l_effpop - params_est[["log_infec_dur_lib"]]),
    effpop_lib = popsize_lib - exp(l_effpop),
    omega_lib = 
      exp(params_est[["log_omega_d_mu_lib"]] - params_est[["log_infec_dur_lib"]]),
    mu_lib = 
      exp(-params_est[["log_infec_dur_lib"]]),
    rho_lib = 
      plogis(params_est[["logit_rho_lib"]]),
    phi_lib = 
      exp(-2 * params_est[["log_sqrt_phi_inv_lib"]])
  ))
}

logprior_lib <- function(params_est) {
  rho      <- plogis(params_est[["logit_rho_lib"]])
  l_effpop <- params_est[["log_Neff_x_rho_lib"]] - log(rho)
  
  sum(dnorm(params_est[["log_Reff_lib_o"]] - params_est[["log_Neff_x_rho_lib"]], 
            log(0.5), 1.08, log = TRUE),
      dexp(exp(params_est[["log_Rext_l2g"]]), 40, log = TRUE) +
        params_est[["log_Rext_l2g"]], 
      dexp(exp(params_est[["log_Rext_l2s"]]), 40, log = TRUE) +
        params_est[["log_Rext_l2s"]],
      dnorm(l_effpop, 10.5, 0.62, log = TRUE),
      dnorm(params_est[["log_omega_d_mu_lib"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["log_infec_dur_lib"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["logit_rho_lib"]], 0.85, 0.75, log = TRUE), 
      dexp(exp(params_est[["log_sqrt_phi_inv_lib"]]), 0.69, log = TRUE) + 
        params_est[["log_sqrt_phi_inv_lib"]]
      )
}

to_est_scale_sln <- function(params_nat) {
  
  l_effpop     <- log(popsize_sln - params_nat[["effpop_sln"]])
  l_Neff_x_rho <- l_effpop + log(params_nat[["rho_sln"]])
  l_infec_dur  <- -log(params_nat[["mu_sln"]])
  l_Reff_m1    <- log(expm1(log(params_nat[["beta_sln"]]) + 
                              l_effpop + l_infec_dur))
  
  return(c(
    log_Reff_sln_o = l_Reff_m1 + l_Neff_x_rho,
    log_Rext_s2g = l_effpop + l_infec_dur + log(params_nat[["alpha_sln2guin"]]),
    log_Rext_s2l = l_effpop + l_infec_dur + log(params_nat[["alpha_sln2lib"]]),
    log_Neff_x_rho_sln = l_Neff_x_rho,
    log_omega_d_mu_sln = log(params_nat[["omega_sln"]]) + l_infec_dur,
    log_infec_dur_sln = l_infec_dur,
    logit_rho_sln = qlogis(params_nat[["rho_sln"]]),
    log_sqrt_phi_inv_sln = -0.5 * log(params_nat[["phi_sln"]])
  ))
}

from_est_scale_sln <- function(params_est) {
  
  rho       <- plogis(params_est[["logit_rho_sln"]])
  l_effpop  <- params_est[["log_Neff_x_rho_sln"]] - log(rho)
  l_Reff_m1 <- params_est[["log_Reff_sln_o"]] - params_est[["log_Neff_x_rho_sln"]]
  
  return(c(
    beta_sln = exp(log1p(exp(l_Reff_m1)) - l_effpop -
                      params_est[["log_infec_dur_sln"]]),
    alpha_sln2guin = 
      exp(params_est[["log_Rext_s2g"]] - 
            l_effpop - params_est[["log_infec_dur_sln"]]),
    alpha_sln2lib = 
      exp(params_est[["log_Rext_s2l"]] - 
            l_effpop - params_est[["log_infec_dur_sln"]]),
    effpop_sln = popsize_sln - exp(l_effpop),
    omega_sln = 
      exp(params_est[["log_omega_d_mu_sln"]] - params_est[["log_infec_dur_sln"]]),
    mu_sln = 
      exp(-params_est[["log_infec_dur_sln"]]),
    rho_sln = 
      plogis(params_est[["logit_rho_sln"]]),
    phi_sln = 
      exp(-2 * params_est[["log_sqrt_phi_inv_sln"]])
  ))
}

logprior_sln <- function(params_est) {
  rho      <- plogis(params_est[["logit_rho_sln"]])
  l_effpop <- params_est[["log_Neff_x_rho_sln"]] - log(rho)
  
  sum(dnorm(params_est[["log_Reff_sln_o"]] - params_est[["log_Neff_x_rho_sln"]], 
            log(0.5), 1.08, log = TRUE),
      dexp(exp(params_est[["log_Rext_s2g"]]), 40, log = TRUE) +
        params_est[["log_Rext_s2g"]], 
      dexp(exp(params_est[["log_Rext_s2l"]]), 40, log = TRUE) +
        params_est[["log_Rext_s2l"]],
      dnorm(l_effpop, 10.6, 0.62, log = TRUE),
      dnorm(params_est[["log_omega_d_mu_sln"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["log_infec_dur_sln"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["logit_rho_sln"]], 0.85, 0.75, log = TRUE), 
      dexp(exp(params_est[["log_sqrt_phi_inv_sln"]]), 0.69, log = TRUE) + 
        params_est[["log_sqrt_phi_inv_sln"]]
      )
}

mcmc_kern <- 
  mcmc_kernel(
    parameter_blocks = 
      list(
        parblock(pars_nat = 
                   c("beta_guin", 
                     "alpha_guin2lib",
                     "alpha_guin2sln",
                     "effpop_guin",
                     "omega_guin",
                     "mu_guin",
                     "rho_guin",
                     "phi_guin"),
                 pars_est = 
                   c("log_Reff_guin_o",
                     "log_Rext_g2l",
                     "log_Rext_g2s",
                     "log_Neff_x_rho_guin",
                     "log_omega_d_mu_guin",
                     "log_infec_dur_guin",
                     "logit_rho_guin",
                     "log_sqrt_phi_inv_guin"),
                 priors = 
                   list(logprior = logprior_guin,
                        to_estimation_scale = to_est_scale_guin,
                        from_estimation_scale = from_est_scale_guin),
                 alg = "mvnss",
                 sigma = diag(c(0.05, 0.1, 0.1, 1, 0.05, 0.05, 0.5, 0.05)),
                 initializer = 
                   function() {
                     from_est_scale_guin(
                       to_est_scale_guin(
                         parameters[c("beta_guin", 
                                      "alpha_guin2lib",
                                      "alpha_guin2sln",
                                      "effpop_guin",
                                      "omega_guin",
                                      "mu_guin",
                                      "rho_guin",
                                      "phi_guin")]) + rnorm(8, 0, 0.1))
                   }, 
                 control = 
                   mvnss_control(scale_constant = 0.5,
                                 scale_cooling = 0.7,
                                 stop_adaptation = 5e4,
                                 step_size = 0.5,
                                 nugget = 1e-5, 
                                 nugget_step_size = 0.001,
                                 nugget_cooling = 0.99)),
        parblock(pars_nat = 
                   c("beta_lib", 
                     "alpha_lib2guin",
                     "alpha_lib2sln",
                     "effpop_lib",
                     "omega_lib",
                     "mu_lib",
                     "rho_lib",
                     "phi_lib"),
                 pars_est = 
                   c("log_Reff_lib_o",
                     "log_Rext_l2g",
                     "log_Rext_l2s",
                     "log_Neff_x_rho_lib",
                     "log_omega_d_mu_lib",
                     "log_infec_dur_lib",
                     "logit_rho_lib",
                     "log_sqrt_phi_inv_lib"),
                 priors = 
                   list(logprior = logprior_lib,
                        to_estimation_scale = to_est_scale_lib,
                        from_estimation_scale = from_est_scale_lib),
                 alg = "mvnss",
                 sigma = diag(c(0.05, 0.1, 0.1, 1, 0.05, 0.05, 0.5, 0.05)),
                 initializer = 
                   function() {
                     from_est_scale_lib(
                       to_est_scale_lib(
                         parameters[c("beta_lib", 
                                      "alpha_lib2guin",
                                      "alpha_lib2sln",
                                      "effpop_lib",
                                      "omega_lib",
                                      "mu_lib",
                                      "rho_lib",
                                      "phi_lib")]) + rnorm(8, 0, 0.1))
                   }, 
                 control = 
                   mvnss_control(scale_constant = 0.5,
                                 scale_cooling = 0.7,
                                 stop_adaptation = 5e4,
                                 step_size = 0.5,
                                 nugget = 1e-5, 
                                 nugget_step_size = 0.001,
                                 nugget_cooling = 0.99)),
        parblock(pars_nat = 
                   c("beta_sln", 
                     "alpha_sln2guin",
                     "alpha_sln2lib",
                     "effpop_sln",
                     "omega_sln",
                     "mu_sln",
                     "rho_sln",
                     "phi_sln"),
                 pars_est = 
                   c("log_Reff_sln_o",
                     "log_Rext_s2g",
                     "log_Rext_s2l",
                     "log_Neff_x_rho_sln",
                     "log_omega_d_mu_sln",
                     "log_infec_dur_sln",
                     "logit_rho_sln",
                     "log_sqrt_phi_inv_sln"),
                 priors = 
                   list(logprior = logprior_sln,
                        to_estimation_scale = to_est_scale_sln,
                        from_estimation_scale = from_est_scale_sln),
                 alg = "mvnss",
                 sigma = diag(c(0.05, 0.1, 0.1, 1, 0.05, 0.05, 0.5, 0.05)),
                 initializer = 
                   function() {
                     from_est_scale_sln(
                       to_est_scale_sln(
                         parameters[c("beta_sln", 
                                      "alpha_sln2guin",
                                      "alpha_sln2lib",
                                      "effpop_sln",
                                      "omega_sln",
                                      "mu_sln",
                                      "rho_sln",
                                      "phi_sln")]) + rnorm(8, 0, 0.1))
                   }, 
                 control = 
                   mvnss_control(scale_constant = 0.5,
                                 scale_cooling = 0.7,
                                 stop_adaptation = 5e4,
                                 step_size = 0.5,
                                 nugget = 1e-5, 
                                 nugget_step_size = 0.001,
                                 nugget_cooling = 0.99))),
    lna_ess_control = lna_control(bracket_update_iter = 5e3,
                                  joint_initdist_update = FALSE))

We are now ready to fit the model. For the purposes of this vignette, we set the inference method to “ode”. Fitting the model via the linear noise approximation (LNA) instead only requires setting the method to “lna”.

# fit the model
res <- 
  fit_stem(stem_object = stem_object,
           method = "ode", # or set to "lna" to fit via the LNA
           iterations = 1.5e5,
           thinning_interval = 50,
           mcmc_kern = mcmc_kern,
           print_progress = 0)

The fit_stem function returns a list containing the MCMC samples, the latent epidemic paths, and the MCMC tuning parameters (e.g., the global scaling parameter adapted during the MCMC). The runtime and the posterior samples can be accessed as follows:

runtime = res$results$runtime
posterior = res$results$posterior # list with posterior objects

Fitting the model to data from the 2014-2015 West Africa outbreak

National-level WHO incidence counts from the 2014-2015 Ebola outbreak in West Africa are included in the package and can be accessed via data("ebola") once the package is loaded. Relative to the simulation code shown above, the only changes are to the outbreak length (73 vs. 70 weeks) and to the priors for the effective population sizes. These priors were tuned by inflating the observed outbreak sizes by crude estimates of the case detection ratio, so that they matched the expected outbreak sizes under our priors for the basic reproduction number. This is explained further in the Web Appendices of Fintzi et al. (2020).

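data("ebola")
dat = ebola; rm(ebola) # just renaming
# rename the count columns to match the emission names used below
colnames(dat)[2:4] = c("guin_cases", "lib_cases", "sln_cases")
tmax <- nrow(dat)      # number of weekly observation times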

popsize_guin <- 11.8e6
popsize_lib <- 4.4e6
popsize_sln <- 7.1e6

log_popsize_guin <- log(popsize_guin)
log_popsize_lib <- log(popsize_lib)
log_popsize_sln <- log(popsize_sln)

# stemr object with separate dynamics for each country --------------------------------------------------
set.seed(12511)
strata <- c("guin", "lib", "sln")
compartments <- list(S = "ALL", E = "ALL", I = "ALL", R = "ALL")

# rates
rates <-
  list(
    rate(
      "beta_guin * (I_guin + transmission_lib * alpha_lib2guin * I_lib + transmission_sln * alpha_sln2guin * I_sln) * (S_guin - effpop_guin)",
      from = "S",
      to = "E",
      strata = "guin",
      lumped = TRUE,
      incidence = T
    ),
    rate(
      "transmission_lib * beta_lib * (I_lib + alpha_guin2lib * I_guin + transmission_sln * alpha_sln2lib * I_sln) * (S_lib - effpop_lib)",
      from = "S",
      to = "E",
      strata = "lib",
      lumped = TRUE,
      incidence = T
    ),
    rate(
      "transmission_sln * beta_sln * (I_sln + alpha_guin2sln * I_guin + transmission_lib * alpha_lib2sln * I_lib) * (S_sln - effpop_sln)",
      from = "S",
      to = "E",
      strata = "sln",
      lumped = TRUE,
      incidence = T
    ),
    rate(
      "omega_guin",
      from = "E",
      to = "I",
      strata = "guin",
      incidence = T
    ),
    rate(
      "transmission_lib * omega_lib",
      from = "E",
      to = "I",
      strata = "lib",
      incidence = T
    ),
    rate(
      "transmission_sln * omega_sln",
      from = "E",
      to = "I",
      strata = "sln",
      incidence = T
    ),
    rate(
      "mu_guin",
      from = "I",
      to = "R",
      strata = "guin",
      incidence = T
    ),
    rate(
      "transmission_lib * mu_lib",
      from = "I",
      to = "R",
      strata = "lib",
      incidence = T
    ),
    rate(
      "transmission_sln * mu_sln",
      from = "I",
      to = "R",
      strata = "sln",
      incidence = T
    )
  )
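To make the cross-border terms in the rates concrete, the sketch below evaluates the Guinea S -> E rate by hand for one hypothetical state. It is a plain-R illustration, not stemr's internal rate evaluation; the helper guin_infection_rate() and the state values are hypothetical, while the parameter values mirror those in the parameters vector defined below. Because effpop_guin stores the population size minus the effective population size, S_guin - effpop_guin counts the susceptibles inside the effective sub-population.

```r
# Hypothetical helper mirroring the Guinea S -> E rate string above:
# beta_guin * (I_guin + transmission_lib * alpha_lib2guin * I_lib +
#              transmission_sln * alpha_sln2guin * I_sln) * (S_guin - effpop_guin)
guin_infection_rate <- function(state, pars, covar) {
  # infectious pressure: local infectives plus scaled infectives from active neighbours
  pressure <- state[["I_guin"]] +
    covar[["transmission_lib"]] * pars[["alpha_lib2guin"]] * state[["I_lib"]] +
    covar[["transmission_sln"]] * pars[["alpha_sln2guin"]] * state[["I_sln"]]
  # susceptibles inside the effective sub-population
  eff_susceptible <- state[["S_guin"]] - pars[["effpop_guin"]]
  pars[["beta_guin"]] * pressure * eff_susceptible
}

popsize <- 11.8e6   # Guinea population size, as in popsize_guin
ep      <- 1.5e4    # effective population size, as in ep_guin
ex_pars <- c(beta_guin      = 1.25 / ep,
             alpha_lib2guin = 0.02 / ep,
             alpha_sln2guin = 0.02 / ep,
             effpop_guin    = popsize - ep)
ex_state <- c(S_guin = popsize - 30, I_guin = 10, I_lib = 10, I_sln = 10)

# neighbours switched off: only local infectives contribute
rate_local <- guin_infection_rate(ex_state, ex_pars,
                                  c(transmission_lib = 0, transmission_sln = 0))
# neighbours switched on: their infectives add pressure scaled by the alphas
rate_all   <- guin_infection_rate(ex_state, ex_pars,
                                  c(transmission_lib = 1, transmission_sln = 1))
c(local = rate_local, with_imports = rate_all)
```

Setting an indicator to 0 removes that country's contribution entirely, which is how the transmission_lib and transmission_sln covariates delay imported pressure until transmission has started in the neighbouring country.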

# initial state

state_initializer <-
  list(
    stem_initializer(
      c(
        S_guin = popsize_guin - 30,
        E_guin = 15,
        I_guin = 10,
        R_guin = 5
      ),
      fixed = FALSE,
      strata = "guin",
      prior = c(popsize_guin - 30, 15, 10, 5)
    ),
    stem_initializer(
      c(
        S_lib = popsize_lib - 30,
        E_lib = 15,
        I_lib = 10,
        R_lib = 5
      ),
      fixed = FALSE,
      strata = "lib",
      prior = c(popsize_lib - 30, 15, 10, 5)
    ),
    stem_initializer(
      c(
        S_sln = popsize_sln - 30,
        E_sln = 15,
        I_sln = 10,
        R_sln = 5
      ),
      fixed = FALSE,
      strata = "sln",
      prior = c(popsize_sln - 30, 15, 10, 5)
    )
  )

# time-varying covariates: indicators that switch on transmission in Liberia and Sierra Leone
tcovar <- cbind(time = 0:tmax,
                transmission_lib = c(rep(0, 9), rep(1, (tmax+1) - 9)),
                transmission_sln = c(rep(0, 18), rep(1, (tmax+1) - 18)))
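Every Liberian rate is multiplied by transmission_lib and every Sierra Leonean rate by transmission_sln, so these indicator columns switch each country's dynamics on at a fixed time index. The sketch below rebuilds the same columns for a hypothetical 73-week horizon (ex_tmax, matching the real data length quoted above) and reports the first time at which each indicator equals 1. It does not show how stemr treats the covariate between the listed times.

```r
# Rebuild the covariate matrix for a hypothetical 73-week horizon
ex_tmax <- 73
ex_tcovar <- cbind(time = 0:ex_tmax,
                   transmission_lib = c(rep(0, 9), rep(1, (ex_tmax + 1) - 9)),
                   transmission_sln = c(rep(0, 18), rep(1, (ex_tmax + 1) - 18)))

# Time at which each indicator first equals 1
lib_start <- ex_tcovar[which(ex_tcovar[, "transmission_lib"] == 1)[1], "time"]
sln_start <- ex_tcovar[which(ex_tcovar[, "transmission_sln"] == 1)[1], "time"]
c(lib_start = lib_start, sln_start = sln_start)
```

Here lib_start is 9 and sln_start is 18, so the first Liberian and Sierra Leonean observation times in the emission definitions (10 and 19) fall one week after each indicator switches on.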

# effective population sizes used to set the initial parameter values
ep_guin <- 1.5e4
ep_lib  <- 3e4
ep_sln  <- 4.5e4

parameters = c(beta_guin = 1.25 / ep_guin,
               beta_lib = 1.5 / ep_lib,
               beta_sln = 1.5 / ep_sln,
               alpha_guin2lib = 0.02 / ep_lib,
               alpha_guin2sln = 0.02 / ep_sln,
               alpha_lib2guin = 0.02 / ep_guin,
               alpha_lib2sln  = 0.02 / ep_sln,
               alpha_sln2guin = 0.02 / ep_guin,
               alpha_sln2lib  = 0.02 / ep_lib,
               effpop_guin = popsize_guin - ep_guin,
               effpop_lib = popsize_lib - ep_lib,
               effpop_sln = popsize_sln - ep_sln,
               omega_guin = 1,
               omega_lib = 1,
               omega_sln = 1,
               mu_guin = 1,
               mu_lib = 1,
               mu_sln = 1,
               rho_guin = 0.75,
               rho_lib = 0.75,
               rho_sln = 0.75,
               phi_guin = 10,
               phi_lib = 10,
               phi_sln = 10)

constants <- c(t0 = 0)
t0 <- 0; tmax <- nrow(dat)

dynamics <-
      stem_dynamics(
            rates = rates,
            tmax = tmax,
            timestep = NULL,
            parameters = parameters,
            state_initializer = state_initializer,
            compartments = compartments,
            constants = constants,
            strata = strata,
            tcovar = tcovar,
            messages = F,
            compile_ode = T,
            compile_rates = T,
            compile_lna = T,
            rtol = 1e-5,
            atol = 1e-5,
            step_size = 1e-6
      )

# emission distributions: negative binomial samples of the reported E -> I incidence
emissions <-
  list(
    emission(
      "guin_cases",
      "negbinomial",
      c("phi_guin", "E_guin2I_guin * rho_guin"),
      incidence = TRUE,
      obstimes = seq(1, tmax, by = 1)
    ),
    emission(
      "lib_cases",
      "negbinomial",
      c("phi_lib", "E_lib2I_lib * rho_lib"),
      incidence = TRUE,
      obstimes = seq(10, tmax, by = 1)
    ),
    emission(
      "sln_cases",
      "negbinomial",
      c("phi_sln", "E_sln2I_sln * rho_sln"),
      incidence = TRUE,
      obstimes = seq(19, tmax, by = 1)
    )
  )
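Each emission models the reported count as a negative binomial draw whose mean is the reporting rate times the true E -> I incidence over the observation interval, with phi_* as the dispersion parameter. Assuming stemr's "negbinomial" parameters follow the same size/mean convention as base R's dnbinom(size = , mu = ) (an assumption based on the parameter order used above, not confirmed here), the variance is mu + mu^2 / phi. The sketch checks this numerically for hypothetical values.

```r
ex_incidence <- 100                 # hypothetical true E -> I incidence in one week
ex_rho <- 0.75                      # reporting rate (value used in `parameters`)
ex_phi <- 10                        # negative binomial size (value used in `parameters`)
ex_mu  <- ex_rho * ex_incidence     # mean of the reported count

# Moments computed directly from the probability mass function
ex_support <- 0:5000
ex_probs   <- dnbinom(ex_support, size = ex_phi, mu = ex_mu)
nb_mean <- sum(ex_support * ex_probs)
nb_var  <- sum((ex_support - nb_mean)^2 * ex_probs)

c(mean = nb_mean, var = nb_var, closed_form_var = ex_mu + ex_mu^2 / ex_phi)
```

With rho = 0.75, phi = 10, and 100 true cases, the numerical and closed-form variances are both 637.5, well above the Poisson variance of 75; smaller values of phi imply more overdispersion.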

# build the measurement process and the stemr object
measurement_process <- 
  stem_measure(data = dat, 
               emissions = emissions, 
               dynamics = dynamics, 
               messages = F)

stem_object <- 
  make_stem(dynamics = dynamics, 
            measurement_process = measurement_process)

#### initialize the MCMC objects
popsizes     <- c(popsize_guin, popsize_lib, popsize_sln)

# to and from the MCMC estimation scale
to_est_scale_guin <- function(params_nat) {
  
  l_effpop     <- log(popsize_guin - params_nat[["effpop_guin"]])
  l_Neff_x_rho <- l_effpop + log(params_nat[["rho_guin"]])
  l_infec_dur  <- -log(params_nat[["mu_guin"]])
  l_Reff_m1    <- log(expm1(log(params_nat[["beta_guin"]]) + 
                              l_effpop + l_infec_dur))
  
  return(c(
    log_Reff_guin_o = l_Reff_m1 + l_Neff_x_rho,
    log_Rext_g2l = l_effpop + l_infec_dur + log(params_nat[["alpha_guin2lib"]]),
    log_Rext_g2s = l_effpop + l_infec_dur + log(params_nat[["alpha_guin2sln"]]),
    log_Neff_x_rho_guin = l_Neff_x_rho,
    log_omega_d_mu_guin = log(params_nat[["omega_guin"]]) + l_infec_dur,
    log_infec_dur_guin = l_infec_dur,
    logit_rho_guin = qlogis(params_nat[["rho_guin"]]),
    log_sqrt_phi_inv_guin = -0.5 * log(params_nat[["phi_guin"]])
  ))
}

from_est_scale_guin <- function(params_est) {
  
  rho       <- plogis(params_est[["logit_rho_guin"]])
  l_effpop  <- params_est[["log_Neff_x_rho_guin"]] - log(rho)
  l_Reff_m1 <- params_est[["log_Reff_guin_o"]] - params_est[["log_Neff_x_rho_guin"]]
  
  return(c(
    beta_guin = exp(log1p(exp(l_Reff_m1)) - l_effpop -
                      params_est[["log_infec_dur_guin"]]),
    alpha_guin2lib = 
      exp(params_est[["log_Rext_g2l"]] - 
            l_effpop - params_est[["log_infec_dur_guin"]]),
    alpha_guin2sln = 
      exp(params_est[["log_Rext_g2s"]] - 
            l_effpop - params_est[["log_infec_dur_guin"]]),
    effpop_guin = popsize_guin - exp(l_effpop),
    omega_guin = 
      exp(params_est[["log_omega_d_mu_guin"]] - params_est[["log_infec_dur_guin"]]),
    mu_guin = 
      exp(-params_est[["log_infec_dur_guin"]]),
    rho_guin = 
      plogis(params_est[["logit_rho_guin"]]),
    phi_guin = 
      exp(-2 * params_est[["log_sqrt_phi_inv_guin"]])
  ))
}
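from_est_scale_guin() inverts to_est_scale_guin() by rebuilding beta_guin from log(R0 - 1), where R0 = beta_guin * N_eff / mu_guin, N_eff = popsize_guin - effpop_guin, and 1 / mu_guin is the mean infectious period. The stripped-down sketch below keeps only the R0, N_eff, infectious-period, and reporting-rate parts of the transform for one country and checks that the round trip recovers the natural-scale values. The functions to_est_r0() and from_est_r0() are hypothetical reduced versions, not part of stemr.

```r
# Reduced, hypothetical version of the R0 part of to_est_scale_guin()
to_est_r0 <- function(beta, neff, mu, rho) {
  l_infec_dur <- -log(mu)                                     # log mean infectious period
  l_R0_m1 <- log(expm1(log(beta) + log(neff) + l_infec_dur))  # log(R0 - 1)
  c(log_Reff_o     = l_R0_m1 + log(neff) + log(rho),
    log_Neff_x_rho = log(neff) + log(rho),
    log_infec_dur  = l_infec_dur,
    logit_rho      = qlogis(rho))
}

# Reduced, hypothetical version of the matching part of from_est_scale_guin()
from_est_r0 <- function(est) {
  rho     <- plogis(est[["logit_rho"]])
  l_neff  <- est[["log_Neff_x_rho"]] - log(rho)
  l_R0_m1 <- est[["log_Reff_o"]] - est[["log_Neff_x_rho"]]
  c(beta = exp(log1p(exp(l_R0_m1)) - l_neff - est[["log_infec_dur"]]),
    neff = exp(l_neff),
    mu   = exp(-est[["log_infec_dur"]]),
    rho  = rho)
}

ex_nat  <- c(beta = 1.25 / 1.5e4, neff = 1.5e4, mu = 1, rho = 0.75)
ex_est  <- to_est_r0(ex_nat[["beta"]], ex_nat[["neff"]], ex_nat[["mu"]], ex_nat[["rho"]])
ex_back <- from_est_r0(ex_est)
```

Because the map works with log(R0 - 1), it implicitly assumes R0 > 1 in each country: for R0 <= 1, expm1() returns a non-positive value and to_est_scale_*() would produce a non-finite result.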

logprior_guin <- function(params_est) {
  
  rho      <- plogis(params_est[["logit_rho_guin"]])
  l_effpop <- params_est[["log_Neff_x_rho_guin"]] - log(rho)
  
  sum(dnorm(params_est[["log_Reff_guin_o"]] - params_est[["log_Neff_x_rho_guin"]], 
            log(0.5), 1.08, log = TRUE),
      dexp(exp(params_est[["log_Rext_g2l"]]), 40, log = TRUE) +
        params_est[["log_Rext_g2l"]], 
      dexp(exp(params_est[["log_Rext_g2s"]]), 40, log = TRUE) +
        params_est[["log_Rext_g2s"]],
      dnorm(l_effpop, 9.6, 0.62, log = TRUE),
      dnorm(params_est[["log_omega_d_mu_guin"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["log_infec_dur_guin"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["logit_rho_guin"]], 0.85, 0.75, log = TRUE), 
      dexp(exp(params_est[["log_sqrt_phi_inv_guin"]]), 0.69, log = TRUE) + 
        params_est[["log_sqrt_phi_inv_guin"]]
      )
}
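Two features of logprior_guin() are easier to see numerically. First, terms of the form dexp(exp(x), rate, log = TRUE) + x are the log density of x = log(Y) when Y ~ Exponential(rate); the trailing + x is the log-Jacobian of the exponential transform, so the prior is stated on Y (an extrinsic reproduction number with prior mean 1/40, or 1 / sqrt(phi) with rate 0.69) while sampling happens on x. Second, the normal prior on log(R0 - 1) with mean log(0.5) and sd 1.08 implies a prior median basic reproduction number of 1.5. The sketch checks both with base R; log_exp_density() is a hypothetical helper.

```r
# Log density of x = log(Y), Y ~ Exponential(rate): log(dexp(e^x)) plus the log-Jacobian x
log_exp_density <- function(x, rate) dexp(exp(x), rate, log = TRUE) + x

# The transformed density should integrate to (approximately) 1 on the log scale
mass <- integrate(function(x) exp(log_exp_density(x, 40)),
                  lower = -30, upper = 5)$value

# Prior 2.5%, 50%, and 97.5% quantiles of R0 implied by log(R0 - 1) ~ N(log(0.5), 1.08)
r0_quantiles <- 1 + qlnorm(c(0.025, 0.5, 0.975), meanlog = log(0.5), sdlog = 1.08)
```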

to_est_scale_lib <- function(params_nat) {
  
  l_effpop     <- log(popsize_lib - params_nat[["effpop_lib"]])
  l_Neff_x_rho <- l_effpop + log(params_nat[["rho_lib"]])
  l_infec_dur  <- -log(params_nat[["mu_lib"]])
  l_Reff_m1    <- log(expm1(log(params_nat[["beta_lib"]]) + 
                              l_effpop + l_infec_dur))
  
  return(c(
    log_Reff_lib_o = l_Reff_m1 + l_Neff_x_rho,
    log_Rext_l2g = l_effpop + l_infec_dur + log(params_nat[["alpha_lib2guin"]]),
    log_Rext_l2s = l_effpop + l_infec_dur + log(params_nat[["alpha_lib2sln"]]),
    log_Neff_x_rho_lib = l_Neff_x_rho,
    log_omega_d_mu_lib = log(params_nat[["omega_lib"]]) + l_infec_dur,
    log_infec_dur_lib = l_infec_dur,
    logit_rho_lib = qlogis(params_nat[["rho_lib"]]),
    log_sqrt_phi_inv_lib = -0.5 * log(params_nat[["phi_lib"]])
  ))
}

from_est_scale_lib <- function(params_est) {
  
  rho       <- plogis(params_est[["logit_rho_lib"]])
  l_effpop  <- params_est[["log_Neff_x_rho_lib"]] - log(rho)
  l_Reff_m1 <- params_est[["log_Reff_lib_o"]] - params_est[["log_Neff_x_rho_lib"]]
  
  return(c(
    beta_lib = exp(log1p(exp(l_Reff_m1)) - l_effpop -
                      params_est[["log_infec_dur_lib"]]),
    alpha_lib2guin = 
      exp(params_est[["log_Rext_l2g"]] - 
            l_effpop - params_est[["log_infec_dur_lib"]]),
    alpha_lib2sln = 
      exp(params_est[["log_Rext_l2s"]] - 
            l_effpop - params_est[["log_infec_dur_lib"]]),
    effpop_lib = popsize_lib - exp(l_effpop),
    omega_lib = 
      exp(params_est[["log_omega_d_mu_lib"]] - params_est[["log_infec_dur_lib"]]),
    mu_lib = 
      exp(-params_est[["log_infec_dur_lib"]]),
    rho_lib = 
      plogis(params_est[["logit_rho_lib"]]),
    phi_lib = 
      exp(-2 * params_est[["log_sqrt_phi_inv_lib"]])
  ))
}

logprior_lib <- function(params_est) {
  rho      <- plogis(params_est[["logit_rho_lib"]])
  l_effpop <- params_est[["log_Neff_x_rho_lib"]] - log(rho)
  
  sum(dnorm(params_est[["log_Reff_lib_o"]] - params_est[["log_Neff_x_rho_lib"]], 
            log(0.5), 1.08, log = TRUE),
      dexp(exp(params_est[["log_Rext_l2g"]]), 40, log = TRUE) +
        params_est[["log_Rext_l2g"]], 
      dexp(exp(params_est[["log_Rext_l2s"]]), 40, log = TRUE) +
        params_est[["log_Rext_l2s"]],
      dnorm(l_effpop, 9.9, 0.62, log = TRUE),
      dnorm(params_est[["log_omega_d_mu_lib"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["log_infec_dur_lib"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["logit_rho_lib"]], 0.85, 0.75, log = TRUE), 
      dexp(exp(params_est[["log_sqrt_phi_inv_lib"]]), 0.69, log = TRUE) + 
        params_est[["log_sqrt_phi_inv_lib"]]
      )
}

to_est_scale_sln <- function(params_nat) {
  
  l_effpop     <- log(popsize_sln - params_nat[["effpop_sln"]])
  l_Neff_x_rho <- l_effpop + log(params_nat[["rho_sln"]])
  l_infec_dur  <- -log(params_nat[["mu_sln"]])
  l_Reff_m1    <- log(expm1(log(params_nat[["beta_sln"]]) + 
                              l_effpop + l_infec_dur))
  
  return(c(
    log_Reff_sln_o = l_Reff_m1 + l_Neff_x_rho,
    log_Rext_s2g = l_effpop + l_infec_dur + log(params_nat[["alpha_sln2guin"]]),
    log_Rext_s2l = l_effpop + l_infec_dur + log(params_nat[["alpha_sln2lib"]]),
    log_Neff_x_rho_sln = l_Neff_x_rho,
    log_omega_d_mu_sln = log(params_nat[["omega_sln"]]) + l_infec_dur,
    log_infec_dur_sln = l_infec_dur,
    logit_rho_sln = qlogis(params_nat[["rho_sln"]]),
    log_sqrt_phi_inv_sln = -0.5 * log(params_nat[["phi_sln"]])
  ))
}

from_est_scale_sln <- function(params_est) {
  
  rho       <- plogis(params_est[["logit_rho_sln"]])
  l_effpop  <- params_est[["log_Neff_x_rho_sln"]] - log(rho)
  l_Reff_m1 <- params_est[["log_Reff_sln_o"]] - params_est[["log_Neff_x_rho_sln"]]
  
  return(c(
    beta_sln = exp(log1p(exp(l_Reff_m1)) - l_effpop -
                      params_est[["log_infec_dur_sln"]]),
    alpha_sln2guin = 
      exp(params_est[["log_Rext_s2g"]] - 
            l_effpop - params_est[["log_infec_dur_sln"]]),
    alpha_sln2lib = 
      exp(params_est[["log_Rext_s2l"]] - 
            l_effpop - params_est[["log_infec_dur_sln"]]),
    effpop_sln = popsize_sln - exp(l_effpop),
    omega_sln = 
      exp(params_est[["log_omega_d_mu_sln"]] - params_est[["log_infec_dur_sln"]]),
    mu_sln = 
      exp(-params_est[["log_infec_dur_sln"]]),
    rho_sln = 
      plogis(params_est[["logit_rho_sln"]]),
    phi_sln = 
      exp(-2 * params_est[["log_sqrt_phi_inv_sln"]])
  ))
}

logprior_sln <- function(params_est) {
  rho      <- plogis(params_est[["logit_rho_sln"]])
  l_effpop <- params_est[["log_Neff_x_rho_sln"]] - log(rho)
  
  sum(dnorm(params_est[["log_Reff_sln_o"]] - params_est[["log_Neff_x_rho_sln"]], 
            log(0.5), 1.08, log = TRUE),
      dexp(exp(params_est[["log_Rext_s2g"]]), 40, log = TRUE) +
        params_est[["log_Rext_s2g"]], 
      dexp(exp(params_est[["log_Rext_s2l"]]), 40, log = TRUE) +
        params_est[["log_Rext_s2l"]],
      dnorm(l_effpop, 10.7, 0.62, log = TRUE),
      dnorm(params_est[["log_omega_d_mu_sln"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["log_infec_dur_sln"]], 0, 0.3, log = TRUE),
      dnorm(params_est[["logit_rho_sln"]], 0.85, 0.75, log = TRUE), 
      dexp(exp(params_est[["log_sqrt_phi_inv_sln"]]), 0.69, log = TRUE) + 
        params_est[["log_sqrt_phi_inv_sln"]]
      )
}
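The effective-population-size priors in logprior_guin(), logprior_lib(), and logprior_sln() differ only in their log-scale means (9.6, 9.9, and 10.7, each with sd 0.62); these are the hyperparameters described above as tuned using crude case detection ratios. The sketch translates them into natural-scale prior quantiles of N_eff = popsize - effpop with qlnorm().

```r
# Log-normal priors on N_eff, as specified in the logprior_* functions above
neff_prior <- rbind(guin = c(meanlog = 9.6,  sdlog = 0.62),
                    lib  = c(meanlog = 9.9,  sdlog = 0.62),
                    sln  = c(meanlog = 10.7, sdlog = 0.62))

# 2.5%, 50%, and 97.5% prior quantiles of N_eff for each country
neff_quantiles <- t(apply(neff_prior, 1, function(p)
  qlnorm(c(0.025, 0.5, 0.975), meanlog = p[["meanlog"]], sdlog = p[["sdlog"]])))
colnames(neff_quantiles) <- c("q2.5", "median", "q97.5")
round(neff_quantiles)
```

The prior medians are exp(9.6) ≈ 14,800 for Guinea, exp(9.9) ≈ 19,900 for Liberia, and exp(10.7) ≈ 44,400 for Sierra Leone.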

mcmc_kern <- 
  mcmc_kernel(
    parameter_blocks = 
      list(
        parblock(pars_nat = 
                   c("beta_guin", 
                     "alpha_guin2lib",
                     "alpha_guin2sln",
                     "effpop_guin",
                     "omega_guin",
                     "mu_guin",
                     "rho_guin",
                     "phi_guin"),
                 pars_est = 
                   c("log_Reff_guin_o",
                     "log_Rext_g2l",
                     "log_Rext_g2s",
                     "log_Neff_x_rho_guin",
                     "log_omega_d_mu_guin",
                     "log_infec_dur_guin",
                     "logit_rho_guin",
                     "log_sqrt_phi_inv_guin"),
                 priors = 
                   list(logprior = logprior_guin,
                        to_estimation_scale = to_est_scale_guin,
                        from_estimation_scale = from_est_scale_guin),
                 alg = "mvnss",
                 sigma = diag(c(0.05, 0.1, 0.1, 1, 0.05, 0.05, 0.5, 0.05)),
                 initializer = 
                   function() {
                     from_est_scale_guin(
                       to_est_scale_guin(
                         parameters[c("beta_guin", 
                                      "alpha_guin2lib",
                                      "alpha_guin2sln",
                                      "effpop_guin",
                                      "omega_guin",
                                      "mu_guin",
                                      "rho_guin",
                                      "phi_guin")]) + rnorm(8, 0, 0.1))
                   }, 
                 control = 
                   mvnss_control(scale_constant = 0.5,
                                 scale_cooling = 0.7,
                                 stop_adaptation = 5e4,
                                 step_size = 0.5,
                                 nugget = 1e-5, 
                                 nugget_step_size = 0.001,
                                 nugget_cooling = 0.99)),
        parblock(pars_nat = 
                   c("beta_lib", 
                     "alpha_lib2guin",
                     "alpha_lib2sln",
                     "effpop_lib",
                     "omega_lib",
                     "mu_lib",
                     "rho_lib",
                     "phi_lib"),
                 pars_est = 
                   c("log_Reff_lib_o",
                     "log_Rext_l2g",
                     "log_Rext_l2s",
                     "log_Neff_x_rho_lib",
                     "log_omega_d_mu_lib",
                     "log_infec_dur_lib",
                     "logit_rho_lib",
                     "log_sqrt_phi_inv_lib"),
                 priors = 
                   list(logprior = logprior_lib,
                        to_estimation_scale = to_est_scale_lib,
                        from_estimation_scale = from_est_scale_lib),
                 alg = "mvnss",
                 sigma = diag(c(0.05, 0.1, 0.1, 1, 0.05, 0.05, 0.5, 0.05)),
                 initializer = 
                   function() {
                     from_est_scale_lib(
                       to_est_scale_lib(
                         parameters[c("beta_lib", 
                                      "alpha_lib2guin",
                                      "alpha_lib2sln",
                                      "effpop_lib",
                                      "omega_lib",
                                      "mu_lib",
                                      "rho_lib",
                                      "phi_lib")]) + rnorm(8, 0, 0.1))
                   }, 
                 control = 
                   mvnss_control(scale_constant = 0.5,
                                 scale_cooling = 0.7,
                                 stop_adaptation = 5e4,
                                 step_size = 0.5,
                                 nugget = 1e-5, 
                                 nugget_step_size = 0.001,
                                 nugget_cooling = 0.99)),
        parblock(pars_nat = 
                   c("beta_sln", 
                     "alpha_sln2guin",
                     "alpha_sln2lib",
                     "effpop_sln",
                     "omega_sln",
                     "mu_sln",
                     "rho_sln",
                     "phi_sln"),
                 pars_est = 
                   c("log_Reff_sln_o",
                     "log_Rext_s2g",
                     "log_Rext_s2l",
                     "log_Neff_x_rho_sln",
                     "log_omega_d_mu_sln",
                     "log_infec_dur_sln",
                     "logit_rho_sln",
                     "log_sqrt_phi_inv_sln"),
                 priors = 
                   list(logprior = logprior_sln,
                        to_estimation_scale = to_est_scale_sln,
                        from_estimation_scale = from_est_scale_sln),
                 alg = "mvnss",
                 sigma = diag(c(0.05, 0.1, 0.1, 1, 0.05, 0.05, 0.5, 0.05)),
                 initializer = 
                   function() {
                     from_est_scale_sln(
                       to_est_scale_sln(
                         parameters[c("beta_sln", 
                                      "alpha_sln2guin",
                                      "alpha_sln2lib",
                                      "effpop_sln",
                                      "omega_sln",
                                      "mu_sln",
                                      "rho_sln",
                                      "phi_sln")]) + rnorm(8, 0, 0.1))
                   }, 
                 control = 
                   mvnss_control(scale_constant = 0.5,
                                 scale_cooling = 0.7,
                                 stop_adaptation = 5e4,
                                 step_size = 0.5,
                                 nugget = 1e-5, 
                                 nugget_step_size = 0.001,
                                 nugget_cooling = 0.99))),
    lna_ess_control = lna_control(bracket_update_iter = 5e3,
                                  joint_initdist_update = FALSE))
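Each parblock's initializer draws starting values by perturbing the parameters vector on the estimation scale (rnorm(8, 0, 0.1)) and mapping the result back. Because logs and logits are unconstrained, any perturbation maps back to valid natural-scale values, for example 0 < rho < 1 and phi > 0. The sketch illustrates this with two hypothetical parameters and a deterministic grid of deliberately large perturbations in place of rnorm(), so it does not disturb the random-number stream set by set.seed() above; to_est(), from_est(), and jitter_init() are hypothetical.

```r
# Hypothetical two-parameter transforms: reporting rate rho in (0, 1), dispersion phi > 0
to_est   <- function(p) c(logit_rho        = qlogis(p[["rho"]]),
                          log_sqrt_phi_inv = -0.5 * log(p[["phi"]]))
from_est <- function(e) c(rho = plogis(e[["logit_rho"]]),
                          phi = exp(-2 * e[["log_sqrt_phi_inv"]]))

# Perturb on the estimation scale, then map back to the natural scale
jitter_init <- function(p, noise) from_est(to_est(p) + noise)

ex_start <- c(rho = 0.75, phi = 10)

# Fixed grid of large perturbations (the vignette itself uses rnorm(8, 0, 0.1))
ex_noise <- expand.grid(d_rho = seq(-5, 5, by = 1), d_phi = seq(-5, 5, by = 1))
ex_inits <- t(apply(ex_noise, 1, function(d) jitter_init(ex_start, d)))
summary(ex_inits)
```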


# fit the model
res <- 
  fit_stem(stem_object = stem_object,
           method = "ode",
           iterations = 1.5e5, 
           thinning_interval = 100,
           mcmc_kern = mcmc_kern)

References

Fintzi, J., Wakefield, J., & Minin, V. N. (2020). A linear noise approximation for stochastic epidemic models fit to partially observed incidence counts. arXiv preprint arXiv:2001.05099.