API Reference

smcjax

Sequential Monte Carlo and particle filtering in JAX.

LiuWestPosterior

Bases: NamedTuple

Full output of a Liu-West particle filter run.

Extends ParticleFilterPosterior with parameter samples. The Liu-West filter (Liu & West, 2001) jointly estimates latent states and static parameters using kernel density smoothing.

Attributes:

marginal_loglik (Scalar): Scalar estimate of log p(y_{1:T}).
filtered_particles (Float[Array, 'ntime num_particles state_dim']): Particle values at each time step, shape (ntime, num_particles, state_dim).
filtered_log_weights (Float[Array, 'ntime num_particles']): Unnormalized log weights at each step, shape (ntime, num_particles).
ancestors (Int[Array, 'ntime num_particles']): Resampled ancestor indices at each time step, shape (ntime, num_particles).
ess (Float[Array, ' ntime']): Effective sample size at each time step, shape (ntime,).
log_evidence_increments (Float[Array, ' ntime']): Per-step log marginal likelihood increments, shape (ntime,). These sum to marginal_loglik.
filtered_params (Float[Array, 'ntime num_particles param_dim']): Parameter samples at each time step, shape (ntime, num_particles, param_dim).

Source code in smcjax/containers.py
class LiuWestPosterior(NamedTuple):
    r"""Full output of a Liu-West particle filter run.

    Extends :class:`ParticleFilterPosterior` with parameter samples.
    The Liu-West filter (Liu & West, 2001) jointly estimates latent
    states and static parameters using kernel density smoothing.

    Attributes:
        marginal_loglik: Scalar estimate of
            :math:`\log p(y_{1:T})`.
        filtered_particles: Particle values at each time step,
            shape ``(ntime, num_particles, state_dim)``.
        filtered_log_weights: Unnormalized log weights at each step,
            shape ``(ntime, num_particles)``.
        ancestors: Resampled ancestor indices at each time step,
            shape ``(ntime, num_particles)``.
        ess: Effective sample size at each time step,
            shape ``(ntime,)``.
        log_evidence_increments: Per-step log marginal likelihood
            increments, shape ``(ntime,)``.  These sum to
            ``marginal_loglik``.
        filtered_params: Parameter samples at each time step,
            shape ``(ntime, num_particles, param_dim)``.
    """

    marginal_loglik: Scalar
    filtered_particles: Float[Array, 'ntime num_particles state_dim']
    filtered_log_weights: Float[Array, 'ntime num_particles']
    ancestors: Int[Array, 'ntime num_particles']
    ess: Float[Array, ' ntime']
    log_evidence_increments: Float[Array, ' ntime']
    filtered_params: Float[Array, 'ntime num_particles param_dim']

ParticleFilterPosterior

Bases: NamedTuple

Full output of a particle filter run.

Follows the Dynamax PosteriorGSSMFiltered convention of storing the marginal log-likelihood as a scalar summary alongside the time-indexed arrays.

Attributes:

marginal_loglik (Scalar): Scalar estimate of log p(y_{1:T}).
filtered_particles (Float[Array, 'ntime num_particles state_dim']): Particle values at each time step, shape (ntime, num_particles, state_dim).
filtered_log_weights (Float[Array, 'ntime num_particles']): Unnormalized log weights at each time step, shape (ntime, num_particles).
ancestors (Int[Array, 'ntime num_particles']): Resampled ancestor indices at each time step, shape (ntime, num_particles).
ess (Float[Array, ' ntime']): Effective sample size at each time step, shape (ntime,).
log_evidence_increments (Float[Array, ' ntime']): Per-step log marginal likelihood increments, shape (ntime,). These sum to marginal_loglik.

Source code in smcjax/containers.py
class ParticleFilterPosterior(NamedTuple):
    r"""Full output of a particle filter run.

    Follows the Dynamax ``PosteriorGSSMFiltered`` convention of storing
    the marginal log-likelihood as a scalar summary alongside the
    time-indexed arrays.

    Attributes:
        marginal_loglik: Scalar estimate of
            :math:`\log p(y_{1:T})`.
        filtered_particles: Particle values at each time step,
            shape ``(ntime, num_particles, state_dim)``.
        filtered_log_weights: Unnormalized log weights at each time step,
            shape ``(ntime, num_particles)``.
        ancestors: Resampled ancestor indices at each time step,
            shape ``(ntime, num_particles)``.
        ess: Effective sample size at each time step,
            shape ``(ntime,)``.
        log_evidence_increments: Per-step log marginal likelihood
            increments, shape ``(ntime,)``.  These sum to
            ``marginal_loglik``.
    """

    marginal_loglik: Scalar
    filtered_particles: Float[Array, 'ntime num_particles state_dim']
    filtered_log_weights: Float[Array, 'ntime num_particles']
    ancestors: Int[Array, 'ntime num_particles']
    ess: Float[Array, ' ntime']
    log_evidence_increments: Float[Array, ' ntime']

ParticleFilterResult

Bases: Protocol

Structural type for any particle filter posterior.

Both ParticleFilterPosterior and LiuWestPosterior satisfy this protocol, so diagnostic functions can accept either without type errors.

Attributes:

marginal_loglik (Scalar): Scalar estimate of log p(y_{1:T}).
filtered_particles (Float[Array, 'ntime num_particles state_dim']): Particle values at each time step, shape (ntime, num_particles, state_dim).
filtered_log_weights (Float[Array, 'ntime num_particles']): Normalised log weights at each step, shape (ntime, num_particles).
ancestors (Int[Array, 'ntime num_particles']): Resampled ancestor indices at each time step, shape (ntime, num_particles).
ess (Float[Array, ' ntime']): Effective sample size at each time step, shape (ntime,).
log_evidence_increments (Float[Array, ' ntime']): Per-step log marginal likelihood increments, shape (ntime,).

Source code in smcjax/containers.py
@runtime_checkable
class ParticleFilterResult(Protocol):
    r"""Structural type for any particle filter posterior.

    Both :class:`ParticleFilterPosterior` and
    :class:`LiuWestPosterior` satisfy this protocol, so diagnostic
    functions can accept either without type errors.

    Attributes:
        marginal_loglik: Scalar estimate of
            :math:`\log p(y_{1:T})`.
        filtered_particles: Particle values at each time step,
            shape ``(ntime, num_particles, state_dim)``.
        filtered_log_weights: Normalised log weights at each step,
            shape ``(ntime, num_particles)``.
        ancestors: Resampled ancestor indices at each time step,
            shape ``(ntime, num_particles)``.
        ess: Effective sample size at each time step,
            shape ``(ntime,)``.
        log_evidence_increments: Per-step log marginal likelihood
            increments, shape ``(ntime,)``.
    """

    marginal_loglik: Scalar
    filtered_particles: Float[Array, 'ntime num_particles state_dim']
    filtered_log_weights: Float[Array, 'ntime num_particles']
    ancestors: Int[Array, 'ntime num_particles']
    ess: Float[Array, ' ntime']
    log_evidence_increments: Float[Array, ' ntime']
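Because ParticleFilterResult is a runtime_checkable Protocol, conformance is structural: any object exposing the listed attributes passes an isinstance check, whether or not it comes from smcjax. A minimal self-contained illustration, using a reduced two-attribute stand-in protocol rather than the real one:

```python
from typing import Protocol, runtime_checkable

import jax.numpy as jnp


@runtime_checkable
class ResultLike(Protocol):
    """Reduced stand-in for the ParticleFilterResult protocol above."""

    marginal_loglik: float
    ess: jnp.ndarray


class CustomPosterior:
    """Unrelated class that happens to expose the right attributes."""

    def __init__(self) -> None:
        self.marginal_loglik = -12.3
        self.ess = jnp.ones(5)


# Structural typing: isinstance only checks that the attributes exist.
print(isinstance(CustomPosterior(), ResultLike))  # True
print(isinstance(object(), ResultLike))           # False
```

The same mechanism lets user-defined posterior containers flow through the diagnostic functions below without subclassing anything.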

ParticleState

Bases: NamedTuple

State of a particle cloud at a single time step.

Attributes:

particles (Float[Array, 'num_particles state_dim']): Particle values, shape (num_particles, state_dim).
log_weights (Float[Array, ' num_particles']): Unnormalized log importance weights, shape (num_particles,).
log_marginal_likelihood (Scalar): Running log marginal likelihood estimate.

Source code in smcjax/containers.py
class ParticleState(NamedTuple):
    r"""State of a particle cloud at a single time step.

    Attributes:
        particles: Particle values, shape ``(num_particles, state_dim)``.
        log_weights: Unnormalized log importance weights,
            shape ``(num_particles,)``.
        log_marginal_likelihood: Running log marginal likelihood estimate.
    """

    particles: Float[Array, 'num_particles state_dim']
    log_weights: Float[Array, ' num_particles']
    log_marginal_likelihood: Scalar
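The filters below carry log_weights in normalized form (the scan bodies note the invariant that logsumexp of the weights is zero). The log_normalize helper they call is not shown on this page; a sketch of the behaviour the listings rely on (returning the normalized log weights together with the log of the unnormalized sum) could look like this, though the actual smcjax implementation may differ:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_normalize(log_w):
    # Normalize log weights and also return log(sum of unnormalized weights),
    # which the filters use when computing per-step evidence increments.
    log_sum = logsumexp(log_w)
    return log_w - log_sum, log_sum


log_w = jnp.log(jnp.array([1.0, 2.0, 3.0, 4.0]))
log_w_norm, log_sum = log_normalize(log_w)
# exp(log_sum) recovers the total weight (here 10.0), and the
# normalized weights exponentiate and sum to one.
```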

auxiliary_filter(key, initial_sampler, transition_sampler, log_observation_fn, log_auxiliary_fn, emissions, num_particles, resampling_fn=systematic, resampling_threshold=0.5)

Run an auxiliary particle filter (Pitt & Shephard, 1999).

Parameters:

key (PRNGKeyT, required): JAX PRNG key.
initial_sampler (Callable, required): Function (key, num_particles) -> particles that draws from the initial state distribution p(z_1).
transition_sampler (Callable, required): Function (key, state) -> state that draws from the transition distribution p(z_t | z_{t-1}). Will be vmap-ped over the particle dimension internally.
log_observation_fn (Callable, required): Function (emission, state) -> log_prob that evaluates the observation log-density log p(y_t | z_t). Will be vmap-ped over the particle dimension (second argument) internally.
log_auxiliary_fn (Callable, required): Function (emission, state) -> log_prob that evaluates the look-ahead log-density log g(y_{t+1} | x_t). Will be vmap-ped over the particle dimension (second argument) internally. When this returns zero for all inputs the APF reduces to the bootstrap filter.
emissions (Float[Array, 'ntime emission_dim'], required): Observed emissions, shape (T, D).
num_particles (int, required): Number of particles N.
resampling_fn (Callable, default: systematic): Resampling algorithm matching the Blackjax signature (key, weights, num_samples) -> indices. Defaults to blackjax.smc.resampling.systematic.
resampling_threshold (float, default: 0.5): Fraction of num_particles below which resampling is triggered (e.g. 0.5 means resample when ESS < 0.5 * N).

Returns:

ParticleFilterPosterior: containing filtered particles, log weights, ancestor indices, the marginal log-likelihood estimate, and ESS trace.

Source code in smcjax/auxiliary.py
def auxiliary_filter(
    key: PRNGKeyT,
    initial_sampler: Callable,
    transition_sampler: Callable,
    log_observation_fn: Callable,
    log_auxiliary_fn: Callable,
    emissions: Float[Array, 'ntime emission_dim'],
    num_particles: int,
    resampling_fn: Callable = systematic,
    resampling_threshold: float = 0.5,
) -> ParticleFilterPosterior:
    r"""Run an auxiliary particle filter (Pitt & Shephard, 1999).

    Args:
        key: JAX PRNG key.
        initial_sampler: Function ``(key, num_particles) -> particles``
            that draws from the initial state distribution
            :math:`p(z_1)`.
        transition_sampler: Function ``(key, state) -> state`` that
            draws from the transition distribution
            :math:`p(z_t \mid z_{t-1})`.  Will be ``vmap``-ped over
            the particle dimension internally.
        log_observation_fn: Function
            ``(emission, state) -> log_prob`` that evaluates the
            observation log-density :math:`\log p(y_t \mid z_t)`.
            Will be ``vmap``-ped over the particle dimension (second
            argument) internally.
        log_auxiliary_fn: Function
            ``(emission, state) -> log_prob`` that evaluates the
            look-ahead log-density
            :math:`\log g(y_{t+1} \mid x_t)`.
            Will be ``vmap``-ped over the particle dimension (second
            argument) internally.  When this returns zero for all
            inputs the APF reduces to the bootstrap filter.
        emissions: Observed emissions, shape ``(T, D)``.
        num_particles: Number of particles :math:`N`.
        resampling_fn: Resampling algorithm matching the Blackjax
            signature ``(key, weights, num_samples) -> indices``.
            Defaults to :func:`~blackjax.smc.resampling.systematic`.
        resampling_threshold: Fraction of ``num_particles`` below
            which resampling is triggered (e.g. 0.5 means resample
            when ``ESS < 0.5 * N``).

    Returns:
        :class:`~smcjax.containers.ParticleFilterPosterior` containing
        filtered particles, log weights, ancestor indices, the
        marginal log-likelihood estimate, and ESS trace.
    """
    key, init_key = jr.split(key)
    log_n = jnp.log(jnp.asarray(num_particles, dtype=jnp.float64))

    # --- Initialise at t=0 -------------------------------------------------
    (
        particles_0,
        log_w_0,
        log_ev_0,
        ess_0,
        identity_ancestors,
        init_state,
    ) = _init_standard(
        init_key,
        initial_sampler,
        log_observation_fn,
        emissions[0],
        num_particles,
        log_n,
    )

    # --- Scan body for t = 1, ..., T-1 -------------------------------------
    def _step(
        carry: ParticleState,
        args: tuple[PRNGKeyT, Float[Array, ' emission_dim']],
    ) -> tuple[ParticleState, tuple[Array, Array, Array, Array, Array]]:
        state, (step_key, y_t) = carry, args
        k1, k2 = jr.split(step_key)
        # Invariant: state.log_weights are normalized (logsumexp = 0).

        # 1. First-stage weights: combine current weights with
        #    look-ahead g(y_{t+1} | x_t)
        log_aux = vmap(lambda z: log_auxiliary_fn(y_t, z))(state.particles)
        log_first_stage = state.log_weights + log_aux

        # Normalise first-stage weights for resampling
        log_first_norm, log_first_sum = log_normalize(log_first_stage)

        # 2. Conditionally resample using first-stage weights
        threshold = resampling_threshold * num_particles
        do_resample, ancestors = _conditional_resample(
            k1,
            log_first_norm,
            resampling_fn,
            threshold,
            num_particles,
            identity_ancestors,
        )
        resampled_particles = state.particles[ancestors]

        # Store the look-ahead values for ancestors (needed for
        # second-stage correction)
        log_aux_ancestors = log_aux[ancestors]

        # 3. Propagate through transition
        keys = jr.split(k2, num_particles)
        propagated = vmap(transition_sampler)(keys, resampled_particles)

        # 4. Second-stage weights: observation / look-ahead adjustment
        log_obs = vmap(lambda z: log_observation_fn(y_t, z))(propagated)
        log_second_stage = log_obs - log_aux_ancestors

        # Compute evidence increment and normalize.
        # If resampled: first-stage weights were used for resampling,
        #   the evidence increment is the product of two factors:
        #   (a) E_W[g] = sum_i W_i * g_i  (first-stage normaliser)
        #   (b) (1/N) sum_j w_j^(2)       (mean second-stage weight)
        # If not resampled: standard importance weighting,
        #   increment = logsumexp(log_w_old + log_obs)
        log_w_unnorm = jnp.where(
            do_resample,
            log_second_stage,
            state.log_weights + log_obs,
        )
        log_w_norm, log_sum = log_normalize(log_w_unnorm)

        # log_first_sum = logsumexp(log_w_norm_old + log_aux)
        #   = log(sum W_i g_i) = log E_W[g]  (no 1/N needed)
        # log_sum for second stage = logsumexp(log_second_stage)
        #   so mean = log_sum - log_n
        log_ev_inc_resample = log_first_sum + log_sum - log_n
        log_ev_inc_no_resample = log_sum
        log_ev_inc = jnp.where(
            do_resample, log_ev_inc_resample, log_ev_inc_no_resample
        )

        new_state = ParticleState(
            particles=propagated,
            log_weights=log_w_norm,
            log_marginal_likelihood=(
                state.log_marginal_likelihood + log_ev_inc
            ),
        )
        ess_t: Array = jnp.asarray(compute_ess(log_w_norm))
        return new_state, (
            propagated,
            log_w_norm,
            ancestors,
            ess_t,
            log_ev_inc,
        )

    # Run the scan over t = 1 ... T-1
    step_keys = jr.split(key, emissions.shape[0] - 1)
    (
        final_state,
        (
            particles_rest,
            log_w_rest,
            ancestors_rest,
            ess_rest,
            log_ev_inc_rest,
        ),
    ) = lax.scan(_step, init_state, (step_keys, emissions[1:]))

    # --- Combine t=0 with t=1..T-1 -----------------------------------------
    all_particles = _prepend(particles_0, particles_rest)
    all_log_w = _prepend(log_w_0, log_w_rest)
    all_ancestors = _prepend(identity_ancestors, ancestors_rest)
    ess_0_arr: Array = jnp.asarray(ess_0)
    all_ess = _prepend(ess_0_arr, ess_rest)
    all_log_ev_inc = _prepend(jnp.asarray(log_ev_0), log_ev_inc_rest)

    return ParticleFilterPosterior(
        marginal_loglik=final_state.log_marginal_likelihood,
        filtered_particles=all_particles,
        filtered_log_weights=all_log_w,
        ancestors=all_ancestors,
        ess=all_ess,
        log_evidence_increments=all_log_ev_inc,
    )
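The docstring notes that a log_auxiliary_fn returning zero everywhere reduces the APF to the bootstrap filter. The evidence-increment algebra in the listing can be checked numerically without smcjax: with g == 1 the first-stage normaliser log E_W[g] is zero (the carried weights are normalized), so the two-stage increment collapses to the bootstrap one. A small jax.numpy sanity check with made-up values:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

N = 8
# Arbitrary normalized current log weights (scan invariant: logsumexp == 0).
raw = jnp.arange(1.0, N + 1.0)
log_w = jnp.log(raw) - logsumexp(jnp.log(raw))
# Stand-in observation log-densities for one step.
log_obs = jnp.sin(jnp.arange(N, dtype=jnp.float32))
log_n = jnp.log(float(N))

# Zero look-ahead: g(y_{t+1} | x_t) == 1 for every particle.
log_aux = jnp.zeros(N)

# First-stage normaliser log E_W[g] = logsumexp(log_w + log_aux) = 0 here.
log_first_sum = logsumexp(log_w + log_aux)
# Second-stage weights: observation / look-ahead adjustment.
log_second = log_obs - log_aux
apf_increment = log_first_sum + logsumexp(log_second) - log_n

# Bootstrap evidence increment after resampling to uniform weights.
bootstrap_increment = logsumexp(log_obs) - log_n
```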

bootstrap_filter(key, initial_sampler, transition_sampler, log_observation_fn, emissions, num_particles, resampling_fn=systematic, resampling_threshold=0.5)

Run a bootstrap (SIR) particle filter.

Parameters:

key (PRNGKeyT, required): JAX PRNG key.
initial_sampler (Callable, required): Function (key, num_particles) -> particles that draws from the initial state distribution p(z_1).
transition_sampler (Callable, required): Function (key, state) -> state that draws from the transition distribution p(z_t | z_{t-1}). Will be vmap-ped over the particle dimension internally.
log_observation_fn (Callable, required): Function (emission, state) -> log_prob that evaluates the observation log-density log p(y_t | z_t). Will be vmap-ped over the particle dimension (second argument) internally.
emissions (Float[Array, 'ntime emission_dim'], required): Observed emissions, shape (T, D).
num_particles (int, required): Number of particles N.
resampling_fn (Callable, default: systematic): Resampling algorithm matching the Blackjax signature (key, weights, num_samples) -> indices. Defaults to smcjax.resampling.systematic.
resampling_threshold (float, default: 0.5): Fraction of num_particles below which resampling is triggered (e.g. 0.5 means resample when ESS < 0.5 * N).

Returns:

ParticleFilterPosterior: containing filtered particles, log weights, ancestor indices, the marginal log-likelihood estimate, and ESS trace.

Source code in smcjax/bootstrap.py
def bootstrap_filter(
    key: PRNGKeyT,
    initial_sampler: Callable,
    transition_sampler: Callable,
    log_observation_fn: Callable,
    emissions: Float[Array, 'ntime emission_dim'],
    num_particles: int,
    resampling_fn: Callable = systematic,
    resampling_threshold: float = 0.5,
) -> ParticleFilterPosterior:
    r"""Run a bootstrap (SIR) particle filter.

    Args:
        key: JAX PRNG key.
        initial_sampler: Function ``(key, num_particles) -> particles``
            that draws from the initial state distribution
            :math:`p(z_1)`.
        transition_sampler: Function ``(key, state) -> state`` that draws
            from the transition distribution
            :math:`p(z_t \mid z_{t-1})`.  Will be ``vmap``-ped over the
            particle dimension internally.
        log_observation_fn: Function
            ``(emission, state) -> log_prob`` that evaluates the
            observation log-density :math:`\log p(y_t \mid z_t)`.
            Will be ``vmap``-ped over the particle dimension (second
            argument) internally.
        emissions: Observed emissions, shape ``(T, D)``.
        num_particles: Number of particles :math:`N`.
        resampling_fn: Resampling algorithm matching the Blackjax
            signature ``(key, weights, num_samples) -> indices``.
            Defaults to :func:`~smcjax.resampling.systematic`.
        resampling_threshold: Fraction of ``num_particles`` below which
            resampling is triggered (e.g. 0.5 means resample when
            ``ESS < 0.5 * N``).

    Returns:
        :class:`~smcjax.containers.ParticleFilterPosterior` containing
        filtered particles, log weights, ancestor indices, the
        marginal log-likelihood estimate, and ESS trace.
    """
    key, init_key = jr.split(key)
    log_n = jnp.log(jnp.asarray(num_particles, dtype=jnp.float64))

    # --- Initialise at t=0 -------------------------------------------------
    (
        particles_0,
        log_w_0,
        log_ev_0,
        ess_0,
        identity_ancestors,
        init_state,
    ) = _init_standard(
        init_key,
        initial_sampler,
        log_observation_fn,
        emissions[0],
        num_particles,
        log_n,
    )

    # --- Scan body for t = 1, ..., T-1 -------------------------------------
    def _step(
        carry: ParticleState,
        args: tuple[PRNGKeyT, Float[Array, ' emission_dim']],
    ) -> tuple[ParticleState, tuple[Array, Array, Array, Array, Array]]:
        state, (step_key, y_t) = carry, args
        k1, k2 = jr.split(step_key)
        # Invariant: state.log_weights are normalized (logsumexp = 0).

        # 1. Conditionally resample
        threshold = resampling_threshold * num_particles
        do_resample, ancestors = _conditional_resample(
            k1,
            state.log_weights,
            resampling_fn,
            threshold,
            num_particles,
            identity_ancestors,
        )
        resampled_particles = state.particles[ancestors]

        # 2. Propagate through transition
        keys = jr.split(k2, num_particles)
        propagated = vmap(transition_sampler)(keys, resampled_particles)

        # 3. Weight by observation likelihood
        log_obs = vmap(lambda z: log_observation_fn(y_t, z))(propagated)

        # Compute evidence increment and normalize.
        # If resampled: weights were reset to uniform (1/N), so
        #   increment = logsumexp(log_obs) - log(N)
        # If not resampled: old normalized weights W_i sum to 1, so
        #   increment = logsumexp(log_W + log_obs)
        log_w_unnorm = jnp.where(
            do_resample,
            log_obs,
            state.log_weights + log_obs,
        )
        log_w_norm, log_sum = log_normalize(log_w_unnorm)
        log_ev_inc = jnp.where(
            do_resample,
            log_sum - log_n,
            log_sum,
        )

        new_state = ParticleState(
            particles=propagated,
            log_weights=log_w_norm,
            log_marginal_likelihood=(
                state.log_marginal_likelihood + log_ev_inc
            ),
        )
        ess_t: Array = jnp.asarray(compute_ess(log_w_norm))
        return new_state, (
            propagated,
            log_w_norm,
            ancestors,
            ess_t,
            log_ev_inc,
        )

    # Run the scan over t = 1 ... T-1
    step_keys = jr.split(key, emissions.shape[0] - 1)
    (
        final_state,
        (
            particles_rest,
            log_w_rest,
            ancestors_rest,
            ess_rest,
            log_ev_inc_rest,
        ),
    ) = lax.scan(_step, init_state, (step_keys, emissions[1:]))

    # --- Combine t=0 with t=1..T-1 -----------------------------------------
    all_particles = _prepend(particles_0, particles_rest)
    all_log_w = _prepend(log_w_0, log_w_rest)
    all_ancestors = _prepend(identity_ancestors, ancestors_rest)
    ess_0_arr: Array = jnp.asarray(ess_0)
    all_ess = _prepend(ess_0_arr, ess_rest)
    all_log_ev_inc = _prepend(jnp.asarray(log_ev_0), log_ev_inc_rest)

    return ParticleFilterPosterior(
        marginal_loglik=final_state.log_marginal_likelihood,
        filtered_particles=all_particles,
        filtered_log_weights=all_log_w,
        ancestors=all_ancestors,
        ess=all_ess,
        log_evidence_increments=all_log_ev_inc,
    )
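Both filters trigger resampling when the effective sample size falls below resampling_threshold * num_particles. The ESS of normalized weights is the standard 1 / sum_i W_i^2 quantity; a log-space sketch of that formula (the actual compute_ess in smcjax may be implemented differently) shows the two extremes:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def compute_ess(log_w_norm):
    # Effective sample size 1 / sum_i W_i^2, computed stably in log space.
    return jnp.exp(-logsumexp(2.0 * log_w_norm))


N = 100
# Uniform weights: every particle contributes, ESS == N.
uniform = jnp.full(N, -jnp.log(float(N)))
# Degenerate weights: one particle carries all the mass, ESS ~= 1.
degenerate = jnp.log(jnp.eye(N)[0] + 1e-30)
```

With the default resampling_threshold of 0.5, a step would resample here whenever the ESS drops below 50 particles.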

crps(predictions, observation)

Compute the Continuous Ranked Probability Score.

CRPS is a proper scoring rule for probabilistic forecasts:

.. math::

    \text{CRPS} = \mathbb{E}|Y - y| - \tfrac{1}{2}\,\mathbb{E}|Y - Y'|

where Y, Y' are iid predictive samples and y is the observation.

Parameters:

predictions (Float[Array, ' num_samples'], required): iid samples from the predictive distribution.
observation (Scalar, required): Observed scalar value.

Returns:

Scalar: CRPS (lower is better, zero for perfect prediction).

Source code in smcjax/diagnostics.py
def crps(
    predictions: Float[Array, ' num_samples'],
    observation: Scalar,
) -> Scalar:
    r"""Compute the Continuous Ranked Probability Score.

    CRPS is a proper scoring rule for probabilistic forecasts:

    .. math::

        \text{CRPS} = \mathbb{E}|Y - y|
                     - \tfrac{1}{2}\,\mathbb{E}|Y - Y'|

    where :math:`Y, Y'` are iid predictive samples and :math:`y`
    is the observation.

    Args:
        predictions: iid samples from the predictive distribution.
        observation: Observed scalar value.

    Returns:
        Scalar CRPS (lower is better, zero for perfect prediction).
    """
    obs = jnp.asarray(observation)
    n = predictions.shape[0]
    # E|Y - y|
    term1 = jnp.mean(jnp.abs(predictions - obs))
    # E|Y - Y'| via sort-based O(N log N) identity:
    #   E|Y-Y'| = (2 / N^2) * sum_i (2i - N + 1) * Y_{(i)}
    y_sorted = jnp.sort(predictions)
    i = jnp.arange(n, dtype=predictions.dtype)
    term2 = 2.0 * jnp.sum((2.0 * i - n + 1.0) * y_sorted) / (n * n)
    return jnp.asarray(term1 - 0.5 * term2)
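The comment in the listing invokes the sort-based identity E|Y - Y'| = (2 / N^2) * sum_i (2i - N + 1) * Y_(i) over 0-indexed order statistics, which replaces the O(N^2) pairwise mean with an O(N log N) sort. A quick check with simulated predictive samples confirms the two agree:

```python
import jax
import jax.numpy as jnp

y = jax.random.normal(jax.random.PRNGKey(0), (64,))
n = y.shape[0]

# Brute-force O(N^2) pairwise mean E|Y - Y'|.
pairwise = jnp.mean(jnp.abs(y[:, None] - y[None, :]))

# Sort-based identity used inside crps().
y_sorted = jnp.sort(y)
i = jnp.arange(n, dtype=y.dtype)
sort_based = 2.0 * jnp.sum((2.0 * i - n + 1.0) * y_sorted) / (n * n)

# Full CRPS against a hypothetical observation.
obs = 0.3
crps_val = jnp.mean(jnp.abs(y - obs)) - 0.5 * sort_based
```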

cumulative_log_score(posterior)

Compute the cumulative one-step-ahead predictive log-score.

The log-evidence increments log p(y_t | y_{1:t-1}) are already one-step-ahead predictive log-densities. This function returns their running sum:

.. math::

    S_t = \sum_{s=1}^{t} \log p(y_s \mid y_{1:s-1})

so that S_T equals the total marginal log-likelihood. Comparing S_t across models gives a time-resolved predictive comparison that, unlike Bayes factors, is less sensitive to the prior.

Parameters:

posterior (ParticleFilterResult, required): Particle filter posterior output.

Returns:

Float[Array, ' ntime']: Cumulative log-scores, shape (ntime,).

Source code in smcjax/diagnostics.py
def cumulative_log_score(
    posterior: ParticleFilterResult,
) -> Float[Array, ' ntime']:
    r"""Compute the cumulative one-step-ahead predictive log-score.

    The log-evidence increments :math:`\log p(y_t \mid y_{1:t-1})`
    are already one-step-ahead predictive log-densities.  This
    function returns their running sum:

    .. math::

        S_t = \sum_{s=1}^{t} \log p(y_s \mid y_{1:s-1})

    so that :math:`S_T` equals the total marginal log-likelihood.
    Comparing :math:`S_t` across models gives a time-resolved
    predictive comparison that, unlike Bayes factors, is less
    sensitive to the prior.

    Args:
        posterior: Particle filter posterior output.

    Returns:
        Cumulative log-scores, shape ``(ntime,)``.
    """
    return jnp.cumsum(posterior.log_evidence_increments)
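Since the function is a running sum, its relationship to marginal_loglik is easy to see on a hand-written increment array (the values below are made up for illustration; with a real posterior you would pass posterior.log_evidence_increments):

```python
import jax.numpy as jnp

# Hypothetical per-step predictive log-densities log p(y_t | y_{1:t-1}).
increments = jnp.array([-1.2, -0.8, -1.5, -0.9])

scores = jnp.cumsum(increments)  # S_t for t = 1..T
total = scores[-1]               # S_T == sum of increments == marginal loglik

# Model comparison: the difference of two models' cumulative score
# traces gives a time-resolved predictive comparison.
```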

diagnose(posterior, ess_threshold=0.1, diversity_threshold=0.1, pareto_k_threshold=0.7)

Summarise filter health and flag potential problems.

Runs a battery of diagnostics and returns a dictionary with scalar summaries and a list of plain-text warnings. The thresholds are configurable; the defaults flag situations where the particle approximation is likely unreliable.

Parameters:

posterior (ParticleFilterResult, required): Particle filter posterior output.
ess_threshold (float, default: 0.1): Fraction of N below which ESS triggers a warning.
diversity_threshold (float, default: 0.1): Diversity below which a warning is triggered.
pareto_k_threshold (float, default: 0.7): Pareto-k above which a warning is triggered.

Returns:

dict[str, Any] with keys:

- min_ess: minimum ESS across all time steps
- min_diversity: minimum particle diversity
- max_pareto_k: maximum Pareto-k across time steps
- min_tail_ess: minimum tail-ESS across time steps
- ess_below_threshold: count of steps where ESS < ess_threshold * N
- warnings: list of diagnostic warning strings
Source code in smcjax/diagnostics.py
def diagnose(
    posterior: ParticleFilterResult,
    ess_threshold: float = 0.1,
    diversity_threshold: float = 0.1,
    pareto_k_threshold: float = 0.7,
) -> dict[str, Any]:
    r"""Summarise filter health and flag potential problems.

    Runs a battery of diagnostics and returns a dictionary with
    scalar summaries and a list of plain-text warnings.  The
    thresholds are configurable; the defaults flag situations
    where the particle approximation is likely unreliable.

    Args:
        posterior: Particle filter posterior output.
        ess_threshold: Fraction of N below which ESS triggers a
            warning.
        diversity_threshold: Diversity below which a warning is
            triggered.
        pareto_k_threshold: Pareto-k above which a warning is
            triggered.

    Returns:
        Dictionary with keys:

        - ``min_ess``: minimum ESS across all time steps
        - ``min_diversity``: minimum particle diversity
        - ``max_pareto_k``: maximum Pareto-k across time steps
        - ``min_tail_ess``: minimum tail-ESS across time steps
        - ``ess_below_threshold``: count of steps where
            ESS < ``ess_threshold * N``
        - ``warnings``: list of diagnostic warning strings
    """
    n = posterior.filtered_particles.shape[1]
    ess_vals = posterior.ess
    diversity = particle_diversity(posterior)
    k_hat = pareto_k_diagnostic(posterior)
    t_ess = tail_ess(posterior)

    min_ess = float(jnp.min(ess_vals))
    min_div = float(jnp.min(diversity))
    max_k = float(jnp.max(k_hat))
    min_t_ess = float(jnp.min(t_ess))
    ess_count = int(jnp.sum(ess_vals < ess_threshold * n))

    warnings: list[str] = []
    if min_ess < ess_threshold * n:
        warnings.append(
            f'ESS dropped below {ess_threshold:.0%} of N '
            f'at {ess_count} step(s) (min ESS = {min_ess:.1f})'
        )
    if min_div < diversity_threshold:
        warnings.append(
            f'Particle diversity fell below '
            f'{diversity_threshold:.0%} (min = {min_div:.3f})'
        )
    if max_k > pareto_k_threshold:
        warnings.append(
            f'Pareto-k exceeded {pareto_k_threshold} '
            f'(max k = {max_k:.3f}); importance weights '
            f'have infinite variance at some steps'
        )

    return {
        'min_ess': min_ess,
        'min_diversity': min_div,
        'max_pareto_k': max_k,
        'min_tail_ess': min_t_ess,
        'ess_below_threshold': ess_count,
        'warnings': warnings,
    }

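The ESS-based warning logic can be illustrated with a plain-Python sketch of the :math:(\sum_i w_i)^2 / \sum_i w_i^2 formula that diagnose applies per step. The ess helper and the toy weight vectors below are illustrative, not part of smcjax:

```python
def ess(weights):
    """Effective sample size (sum w)^2 / (sum w^2) for normalized weights."""
    return sum(weights) ** 2 / sum(w * w for w in weights)

n = 4
balanced = [0.25, 0.25, 0.25, 0.25]   # every particle contributes: ESS = 4
collapsed = [0.97, 0.01, 0.01, 0.01]  # one particle dominates: ESS near 1

# With ess_threshold = 0.5, only the collapsed step is flagged,
# mirroring the `ess_below_threshold` count returned by `diagnose`.
ess_threshold = 0.5
flagged = sum(
    1 for e in (ess(balanced), ess(collapsed)) if e < ess_threshold * n
)
```

A degenerate weight vector drags the ESS toward 1 regardless of N, which is why the count of flagged steps is reported rather than the ESS alone.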
liu_west_filter(key, initial_sampler, transition_sampler, log_observation_fn, log_auxiliary_fn, param_initial_sampler, emissions, num_particles, shrinkage=0.95, resampling_fn=systematic, resampling_threshold=0.5)

Run a Liu-West particle filter (Liu & West, 2001).

Jointly estimates latent states and static parameters using auxiliary particle filtering with kernel density smoothing for parameter propagation.

Parameters:

Name Type Description Default
key PRNGKeyT

JAX PRNG key.

required
initial_sampler Callable

Function (key, num_particles) -> particles that draws from the initial state distribution.

required
transition_sampler Callable

Function (key, state, params) -> state that draws from the transition distribution. Unlike the bootstrap/auxiliary filters, this receives parameters.

required
log_observation_fn Callable

Function (emission, state, params) -> log_prob that evaluates the observation log-density. Receives parameters.

required
log_auxiliary_fn Callable

Function (emission, state, params) -> log_prob that evaluates the look-ahead log-density. Receives parameters.

required
param_initial_sampler Callable

Function (key, num_particles) -> params that draws from the prior parameter distribution. Returns array of shape (num_particles, param_dim).

required
emissions Float[Array, 'ntime emission_dim']

Observed emissions, shape (T, D).

required
num_particles int

Number of particles :math:N.

required
shrinkage float

Shrinkage parameter :math:a \in (0, 1). Controls the balance between the kernel smoothing exploration and prior concentration. Higher values give tighter parameter posteriors.

.. warning::

The shrinkage parameter has no generative
interpretation: it introduces artificial dynamics
into the parameter evolution that do not correspond
to any probabilistic model.  Results can be
sensitive to this choice.  We recommend running the
filter under several values (e.g. 0.95, 0.975,
0.99) and reporting the range of posterior and
evidence estimates.
0.95
resampling_fn Callable

Resampling algorithm. Defaults to systematic.

systematic
resampling_threshold float

ESS fraction triggering resampling.

0.5

Returns:

Type Description
LiuWestPosterior

:class:~smcjax.containers.LiuWestPosterior containing filtered particles, parameters, log weights, ancestor indices, the marginal log-likelihood estimate, and ESS trace.

Source code in smcjax/liu_west.py
def liu_west_filter(
    key: PRNGKeyT,
    initial_sampler: Callable,
    transition_sampler: Callable,
    log_observation_fn: Callable,
    log_auxiliary_fn: Callable,
    param_initial_sampler: Callable,
    emissions: Float[Array, 'ntime emission_dim'],
    num_particles: int,
    shrinkage: float = 0.95,
    resampling_fn: Callable = systematic,
    resampling_threshold: float = 0.5,
) -> LiuWestPosterior:
    r"""Run a Liu-West particle filter (Liu & West, 2001).

    Jointly estimates latent states and static parameters using
    auxiliary particle filtering with kernel density smoothing for
    parameter propagation.

    Args:
        key: JAX PRNG key.
        initial_sampler: Function ``(key, num_particles) -> particles``
            that draws from the initial state distribution.
        transition_sampler: Function ``(key, state, params) -> state``
            that draws from the transition distribution.  Unlike the
            bootstrap/auxiliary filters, this receives parameters.
        log_observation_fn: Function
            ``(emission, state, params) -> log_prob`` that evaluates
            the observation log-density.  Receives parameters.
        log_auxiliary_fn: Function
            ``(emission, state, params) -> log_prob`` that evaluates
            the look-ahead log-density.  Receives parameters.
        param_initial_sampler: Function
            ``(key, num_particles) -> params`` that draws from the
            prior parameter distribution.  Returns array of shape
            ``(num_particles, param_dim)``.
        emissions: Observed emissions, shape ``(T, D)``.
        num_particles: Number of particles :math:`N`.
        shrinkage: Shrinkage parameter :math:`a \in (0, 1)`.
            Controls the balance between the kernel smoothing
            exploration and prior concentration.  Higher values
            give tighter parameter posteriors.

            .. warning::

                The shrinkage parameter has no generative
                interpretation: it introduces artificial dynamics
                into the parameter evolution that do not correspond
                to any probabilistic model.  Results can be
                sensitive to this choice.  We recommend running the
                filter under several values (e.g. 0.95, 0.975,
                0.99) and reporting the range of posterior and
                evidence estimates.
        resampling_fn: Resampling algorithm.  Defaults to systematic.
        resampling_threshold: ESS fraction triggering resampling.

    Returns:
        :class:`~smcjax.containers.LiuWestPosterior` containing
        filtered particles, parameters, log weights, ancestor indices,
        the marginal log-likelihood estimate, and ESS trace.
    """
    key, init_key = jr.split(key)
    log_n = jnp.log(jnp.asarray(num_particles, dtype=jnp.float64))
    a = jnp.asarray(shrinkage, dtype=jnp.float64)
    h_sq = 1.0 - a**2

    (
        particles_0,
        params_0,
        log_w_0,
        log_ev_0,
        ess_0,
        identity_ancestors,
    ) = _init_liu_west(
        init_key,
        initial_sampler,
        param_initial_sampler,
        log_observation_fn,
        emissions[0],
        num_particles,
    )

    # --- Scan body for t = 1, ..., T-1 -------------------------------------
    def _step(
        carry: _Carry,
        args: tuple[PRNGKeyT, Array],
    ) -> tuple[_Carry, tuple[Array, Array, Array, Array, Array, Array]]:
        particles, params, log_weights, log_ml = carry
        step_key, y_t = args
        k1, k2, k3 = jr.split(step_key, 3)

        # Weighted parameter moments for kernel smoothing
        w = normalize(log_weights)
        param_mean = jnp.sum(w[:, None] * params, axis=0)
        param_dev = params - param_mean[None, :]
        param_cov = jnp.einsum('n,nd,ne->de', w, param_dev, param_dev)

        # Shrunk means: m_i = a * phi_i + (1-a) * phi_bar
        shrunk = a * params + (1.0 - a) * param_mean[None, :]

        # 1. First-stage weights using shrunk params
        log_aux = vmap(lambda z, p: log_auxiliary_fn(y_t, z, p))(
            particles, shrunk
        )
        log_first_norm, log_first_sum = log_normalize(log_weights + log_aux)

        # 2. Conditionally resample
        threshold = resampling_threshold * num_particles
        do_resample, ancestors = _conditional_resample(
            k1,
            log_first_norm,
            resampling_fn,
            threshold,
            num_particles,
            identity_ancestors,
        )

        # 3. Propagate params via kernel smoothing + propagate states
        param_dim = params.shape[1]
        # Jitter prevents NaN from cholesky on singular covariance
        # (e.g. when all particles share the same parameter value).
        jitter = 1e-8 * jnp.eye(param_dim)
        chol = jnp.linalg.cholesky(h_sq * param_cov + jitter)
        eps = jr.normal(k2, (num_particles, param_dim))
        new_params = shrunk[ancestors] + eps @ chol.T

        keys = jr.split(k3, num_particles)
        propagated = vmap(transition_sampler)(
            keys,
            particles[ancestors],
            new_params,
        )

        # 4. Second-stage weights
        log_obs = vmap(lambda z, p: log_observation_fn(y_t, z, p))(
            propagated, new_params
        )
        log_w_unnorm = jnp.where(
            do_resample,
            log_obs - log_aux[ancestors],
            log_weights + log_obs,
        )
        log_w_norm, log_sum = log_normalize(log_w_unnorm)

        log_ev_inc = jnp.where(
            do_resample,
            log_first_sum + log_sum - log_n,
            log_sum,
        )

        new_carry = (propagated, new_params, log_w_norm, log_ml + log_ev_inc)
        ess_t = jnp.asarray(compute_ess(log_w_norm))
        return new_carry, (
            propagated,
            new_params,
            log_w_norm,
            ancestors,
            ess_t,
            log_ev_inc,
        )

    init_carry: _Carry = (particles_0, params_0, log_w_0, log_ev_0)
    step_keys = jr.split(key, emissions.shape[0] - 1)

    (
        final_carry,
        (
            particles_rest,
            params_rest,
            log_w_rest,
            ancestors_rest,
            ess_rest,
            log_ev_inc_rest,
        ),
    ) = lax.scan(_step, init_carry, (step_keys, emissions[1:]))

    # --- Combine t=0 with t=1..T-1 -----------------------------------------
    _, _, _, final_log_ml = final_carry

    return LiuWestPosterior(
        marginal_loglik=final_log_ml,
        filtered_particles=_prepend(particles_0, particles_rest),
        filtered_log_weights=_prepend(log_w_0, log_w_rest),
        ancestors=_prepend(identity_ancestors, ancestors_rest),
        ess=_prepend(jnp.asarray(ess_0), ess_rest),
        log_evidence_increments=_prepend(
            jnp.asarray(log_ev_0), log_ev_inc_rest
        ),
        filtered_params=_prepend(params_0, params_rest),
    )

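The choice h² = 1 − a² inside the scan body is what makes the kernel smoothing variance-preserving: shrinking each particle toward the weighted mean contracts the parameter variance by a factor a², and the added kernel noise restores exactly that amount. A quick numeric check in plain Python (the sigma_sq value is illustrative):

```python
a = 0.95                 # shrinkage, as in the `shrinkage` argument
h_sq = 1.0 - a ** 2      # kernel variance factor, as in the filter
sigma_sq = 2.5           # some current parameter variance (illustrative)

# Variance of the shrunk term a * phi_i (contracted by a^2) ...
var_shrunk = a ** 2 * sigma_sq
# ... plus the kernel noise variance h^2 * sigma_sq restores it exactly.
var_propagated = var_shrunk + h_sq * sigma_sq

assert abs(var_propagated - sigma_sq) < 1e-12
```

This identity holds for any a in (0, 1); what a does change is the correlation between a particle's old and new parameter values, which is why results remain sensitive to the choice, as the warning above notes.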
log_bayes_factor(log_ml_1, log_ml_2)

Compute the log Bayes factor between two models.

.. math::

\log BF_{12} = \log p(y_{1:T} \mid M_1)
             - \log p(y_{1:T} \mid M_2)

Positive values favour model 1; negative values favour model 2.

.. warning::

Marginal likelihoods are sensitive to the prior in ways that
the posterior is not.  Bayes factors evaluate priors, not
posteriors (Gelman, 2023).  With weakly informative priors the
marginal likelihood is dominated by prior tails that have
little effect on posterior inference, so a Bayes factor can
reverse sign under prior changes that leave the posterior
essentially unchanged.  Consider complementing Bayes factors
with predictive comparisons (e.g. cumulative log-scores from
:func:`cumulative_log_score` or CRPS from :func:`crps`) and
use :func:`replicated_log_ml` to quantify Monte Carlo
variability.

Parameters:

Name Type Description Default
log_ml_1 Scalar

Log marginal likelihood of model 1.

required
log_ml_2 Scalar

Log marginal likelihood of model 2.

required

Returns:

Type Description
Scalar

Scalar log Bayes factor.

Source code in smcjax/diagnostics.py
def log_bayes_factor(
    log_ml_1: Scalar,
    log_ml_2: Scalar,
) -> Scalar:
    r"""Compute the log Bayes factor between two models.

    .. math::

        \log BF_{12} = \log p(y_{1:T} \mid M_1)
                     - \log p(y_{1:T} \mid M_2)

    Positive values favour model 1; negative values favour model 2.

    .. warning::

        Marginal likelihoods are sensitive to the prior in ways that
        the posterior is not.  Bayes factors evaluate priors, not
        posteriors (Gelman, 2023).  With weakly informative priors the
        marginal likelihood is dominated by prior tails that have
        little effect on posterior inference, so a Bayes factor can
        reverse sign under prior changes that leave the posterior
        essentially unchanged.  Consider complementing Bayes factors
        with predictive comparisons (e.g. cumulative log-scores from
        :func:`cumulative_log_score` or CRPS from :func:`crps`) and
        use :func:`replicated_log_ml` to quantify Monte Carlo
        variability.

    Args:
        log_ml_1: Log marginal likelihood of model 1.
        log_ml_2: Log marginal likelihood of model 2.

    Returns:
        Scalar log Bayes factor.
    """
    return jnp.asarray(log_ml_1) - jnp.asarray(log_ml_2)

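A minimal usage sketch; the log-ML values are made up for illustration:

```python
import math

log_ml_1 = -412.3   # hypothetical log marginal likelihood, model 1
log_ml_2 = -415.8   # hypothetical log marginal likelihood, model 2

log_bf = log_ml_1 - log_ml_2   # what log_bayes_factor computes
bf = math.exp(log_bf)          # Bayes factor on the natural scale

# log_bf > 0, so the evidence favours model 1 -- subject to the prior
# sensitivity and Monte Carlo caveats in the warning above.
```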
log_ml_increments(posterior)

Extract per-step log marginal likelihood increments.

The marginal log-likelihood can be decomposed as:

.. math::

\log p(y_{1:T}) = \sum_{t=1}^T
    \log p(y_t \mid y_{1:t-1})

This function returns the individual increments, which diagnose which observations are hardest for the model.

Parameters:

Name Type Description Default
posterior ParticleFilterResult

Particle filter posterior output.

required

Returns:

Type Description
Float[Array, ' ntime']

Per-step evidence increments, shape (ntime,). These sum to posterior.marginal_loglik.

Source code in smcjax/diagnostics.py
def log_ml_increments(
    posterior: ParticleFilterResult,
) -> Float[Array, ' ntime']:
    r"""Extract per-step log marginal likelihood increments.

    The marginal log-likelihood can be decomposed as:

    .. math::

        \log p(y_{1:T}) = \sum_{t=1}^T
            \log p(y_t \mid y_{1:t-1})

    This function returns the individual increments, which diagnose
    which observations are hardest for the model.

    Args:
        posterior: Particle filter posterior output.

    Returns:
        Per-step evidence increments, shape ``(ntime,)``.  These sum
        to ``posterior.marginal_loglik``.
    """
    return posterior.log_evidence_increments

log_normalize(log_weights)

Normalize log weights and return the log normalizing constant.

Parameters:

Name Type Description Default
log_weights Float[Array, ' num_particles']

Unnormalized log importance weights.

required

Returns:

Type Description
tuple[Float[Array, ' num_particles'], Scalar]

A tuple (log_normalized, log_normalizer) where log_normalized has logsumexp == 0 and log_normalizer is logsumexp(log_weights).

Source code in smcjax/weights.py
def log_normalize(
    log_weights: Float[Array, ' num_particles'],
) -> tuple[Float[Array, ' num_particles'], Scalar]:
    """Normalize log weights and return the log normalizing constant.

    Args:
        log_weights: Unnormalized log importance weights.

    Returns:
        A tuple ``(log_normalized, log_normalizer)`` where
        *log_normalized* has ``logsumexp == 0`` and
        *log_normalizer* is ``logsumexp(log_weights)``.
    """
    log_normalizer = jnp.logaddexp.reduce(log_weights)  # type: ignore[union-attr]
    log_normalized = log_weights - log_normalizer
    return log_normalized, log_normalizer

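The same normalization can be sketched in plain Python with a max-shifted logsumexp. This is a mirror of the documented behaviour, not the smcjax implementation itself:

```python
import math

def log_normalize_py(log_weights):
    """Stable logsumexp normalization: subtract the max before exponentiating."""
    m = max(log_weights)
    log_normalizer = m + math.log(sum(math.exp(lw - m) for lw in log_weights))
    return [lw - log_normalizer for lw in log_weights], log_normalizer

log_norm, log_z = log_normalize_py([-1.0, -2.0, -3.0])
total = sum(math.exp(lw) for lw in log_norm)   # normalized weights sum to one
```

Subtracting the maximum before exponentiating avoids overflow/underflow when the log weights are large in magnitude, which is routine inside a particle filter.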
normalize(log_weights)

Exponentiate and normalize log weights.

Parameters:

Name Type Description Default
log_weights Float[Array, ' num_particles']

Unnormalized log importance weights.

required

Returns:

Type Description
Float[Array, ' num_particles']

Normalized weights that sum to one.

Source code in smcjax/weights.py
def normalize(
    log_weights: Float[Array, ' num_particles'],
) -> Float[Array, ' num_particles']:
    """Exponentiate and normalize log weights.

    Args:
        log_weights: Unnormalized log importance weights.

    Returns:
        Normalized weights that sum to one.
    """
    log_norm, _ = log_normalize(log_weights)
    return jnp.exp(log_norm)

param_weighted_mean(posterior)

Compute the weighted mean of parameter particles at each step.

Parameters:

Name Type Description Default
posterior LiuWestPosterior

Liu-West filter posterior output.

required

Returns:

Type Description
Float[Array, 'ntime param_dim']

Weighted parameter means, shape (ntime, param_dim).

Source code in smcjax/diagnostics.py
def param_weighted_mean(
    posterior: LiuWestPosterior,
) -> Float[Array, 'ntime param_dim']:
    r"""Compute the weighted mean of parameter particles at each step.

    Args:
        posterior: Liu-West filter posterior output.

    Returns:
        Weighted parameter means, shape ``(ntime, param_dim)``.
    """
    return _weighted_mean_field(
        posterior.filtered_log_weights,
        posterior.filtered_params,
    )

param_weighted_quantile(posterior, q)

Compute weighted quantiles of parameter particles at each step.

Parameters:

Name Type Description Default
posterior LiuWestPosterior

Liu-West filter posterior output.

required
q Float[Array, ' num_quantiles']

Quantile levels in [0, 1], e.g. jnp.array([0.025, 0.975]) for a 95% credible interval.

required

Returns:

Type Description
Float[Array, 'ntime num_quantiles param_dim']

Weighted quantiles, shape (ntime, num_quantiles, param_dim).

Source code in smcjax/diagnostics.py
def param_weighted_quantile(
    posterior: LiuWestPosterior,
    q: Float[Array, ' num_quantiles'],
) -> Float[Array, 'ntime num_quantiles param_dim']:
    r"""Compute weighted quantiles of parameter particles at each step.

    Args:
        posterior: Liu-West filter posterior output.
        q: Quantile levels in [0, 1], e.g. ``jnp.array([0.025, 0.975])``
            for a 95% credible interval.

    Returns:
        Weighted quantiles, shape ``(ntime, num_quantiles, param_dim)``.
    """
    return _weighted_quantile_field(
        posterior.filtered_log_weights,
        posterior.filtered_params,
        q,
    )

pareto_k_diagnostic(posterior)

Compute the Pareto-k diagnostic at each time step.

Fits a generalised Pareto distribution (GPD) to the upper tail of the importance weights using the Zhang and Stephens (2009) profile-likelihood estimator with the weakly informative prior from Vehtari, Simpson, Gelman, Yao, and Gabry (2024).

The shape parameter :math:\hat{k} indicates reliability:

  • :math:\hat{k} < 0.5: good, finite variance of the IS estimate
  • :math:0.5 \le \hat{k} < 0.7: marginal
  • :math:0.7 \le \hat{k} < 1.0: unreliable (infinite variance)
  • :math:\hat{k} \ge 1.0: very unreliable (infinite mean)

The tail size is ceil(min(0.2 * N, 3 * sqrt(N))) order statistics, matching the conventions of ArviZ and NumPyro.

Parameters:

Name Type Description Default
posterior ParticleFilterResult

Particle filter posterior output.

required

Returns:

Type Description
Float[Array, ' ntime']

Per-step Pareto-k estimates, shape (ntime,).

Source code in smcjax/diagnostics.py
def pareto_k_diagnostic(
    posterior: ParticleFilterResult,
) -> Float[Array, ' ntime']:
    r"""Compute the Pareto-k diagnostic at each time step.

    Fits a generalised Pareto distribution (GPD) to the upper tail
    of the importance weights using the Zhang and Stephens (2009)
    profile-likelihood estimator with the weakly informative prior
    from Vehtari, Simpson, Gelman, Yao, and Gabry (2024).

    The shape parameter :math:`\hat{k}` indicates reliability:

    - :math:`\hat{k} < 0.5`: good, finite variance of the IS estimate
    - :math:`0.5 \le \hat{k} < 0.7`: marginal
    - :math:`0.7 \le \hat{k} < 1.0`: unreliable (infinite variance)
    - :math:`\hat{k} \ge 1.0`: very unreliable (infinite mean)

    The tail size is ``ceil(min(0.2 * N, 3 * sqrt(N)))`` order
    statistics, matching the conventions of ArviZ and NumPyro.

    Args:
        posterior: Particle filter posterior output.

    Returns:
        Per-step Pareto-k estimates, shape ``(ntime,)``.
    """
    return vmap(_fit_pareto_k)(posterior.filtered_log_weights)

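The tail-size rule can be checked directly; the pareto_tail_size helper below is illustrative, not part of smcjax:

```python
import math

def pareto_tail_size(n):
    """Order statistics used for the GPD fit: ceil(min(0.2*N, 3*sqrt(N)))."""
    return math.ceil(min(0.2 * n, 3.0 * math.sqrt(n)))

# The 20% rule binds for small N; 3*sqrt(N) takes over for large N.
sizes = {n: pareto_tail_size(n) for n in (100, 1000, 10000)}
```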
particle_diversity(posterior)

Compute the fraction of unique particles at each time step.

Particle diversity measures path degeneracy: a value near 1 means most particles are distinct, while near 0 means heavy duplication after resampling.

Uses an indicator-based method (not jnp.unique) for JIT compatibility: counts the fraction of particles that differ from their predecessor in the sorted order.

Parameters:

Name Type Description Default
posterior ParticleFilterResult

Particle filter posterior output.

required

Returns:

Type Description
Float[Array, ' ntime']

Diversity fraction in [0, 1] at each time step, shape (ntime,).

Source code in smcjax/diagnostics.py
def particle_diversity(
    posterior: ParticleFilterResult,
) -> Float[Array, ' ntime']:
    r"""Compute the fraction of unique particles at each time step.

    Particle diversity measures path degeneracy: a value near 1 means
    most particles are distinct, while near 0 means heavy duplication
    after resampling.

    Uses an indicator-based method (not ``jnp.unique``) for JIT
    compatibility: counts the fraction of particles that differ from
    their predecessor in the sorted order.

    Args:
        posterior: Particle filter posterior output.

    Returns:
        Diversity fraction in [0, 1] at each time step,
        shape ``(ntime,)``.
    """
    ancestors = posterior.ancestors  # (ntime, num_particles)
    num_particles = ancestors.shape[1]

    def _diversity_one_step(
        anc: Int[Array, ' num_particles'],
    ) -> Float[Array, '']:
        """Count fraction of unique ancestors at one time step."""
        sorted_anc = jnp.sort(anc)
        # First element is always unique; subsequent are unique if
        # different from predecessor
        is_unique = jnp.concatenate(
            [
                jnp.array([True]),
                sorted_anc[1:] != sorted_anc[:-1],
            ]
        )
        return jnp.sum(is_unique) / num_particles

    return vmap(_diversity_one_step)(ancestors)

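The sorted-predecessor trick translates directly into plain Python (an illustrative mirror of the computation, not the smcjax code):

```python
def diversity_py(ancestors):
    """Fraction of unique ancestor indices via sorted-neighbour comparison."""
    s = sorted(ancestors)
    # First element is always unique; the rest are unique iff they differ
    # from their predecessor in sorted order.
    unique = 1 + sum(1 for prev, cur in zip(s, s[1:]) if cur != prev)
    return unique / len(ancestors)

assert diversity_py([3, 3, 3, 0, 0]) == 2 / 5   # heavy duplication
assert diversity_py([0, 1, 2, 3, 4]) == 1.0     # every ancestor distinct
```

Sorting turns uniqueness counting into a fixed-shape neighbour comparison, which is what makes the real implementation JIT-compatible where jnp.unique is not.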
posterior_predictive_sample(key, posterior, transition_sampler, emission_sampler, num_samples=None)

Draw one-step-ahead posterior predictive samples.

At each time step :math:t, we:

  1. Resample particle indices from the normalised weights.
  2. Propagate each resampled state through transition_sampler.
  3. Draw an emission from emission_sampler.

This gives iid samples from the posterior predictive :math:p(y_{t+1} \mid y_{1:t}), which can be compared with the actual observation :math:y_{t+1} for posterior predictive checking (Gelman et al., 2013, ch. 6).

Parameters:

Name Type Description Default
key PRNGKeyT

JAX PRNG key.

required
posterior ParticleFilterResult

Particle filter posterior output.

required
transition_sampler Callable

Function (key, state) -> state.

required
emission_sampler Callable

Function (key, state) -> emission.

required
num_samples int | None

Number of predictive draws per time step. Defaults to the number of particles.

None

Returns:

Type Description
Float[Array, 'ntime num_samples emission_dim']

Predictive samples, shape (ntime, num_samples, emission_dim).

Source code in smcjax/diagnostics.py
def posterior_predictive_sample(
    key: PRNGKeyT,
    posterior: ParticleFilterResult,
    transition_sampler: Callable,
    emission_sampler: Callable,
    num_samples: int | None = None,
) -> Float[Array, 'ntime num_samples emission_dim']:
    r"""Draw one-step-ahead posterior predictive samples.

    At each time step :math:`t`, we:

    1. Resample particle indices from the normalised weights.
    2. Propagate each resampled state through ``transition_sampler``.
    3. Draw an emission from ``emission_sampler``.

    This gives iid samples from the posterior predictive
    :math:`p(y_{t+1} \mid y_{1:t})`, which can be compared with
    the actual observation :math:`y_{t+1}` for posterior predictive
    checking (Gelman et al., 2013, ch. 6).

    Args:
        key: JAX PRNG key.
        posterior: Particle filter posterior output.
        transition_sampler: Function ``(key, state) -> state``.
        emission_sampler: Function ``(key, state) -> emission``.
        num_samples: Number of predictive draws per time step.
            Defaults to the number of particles.

    Returns:
        Predictive samples, shape
        ``(ntime, num_samples, emission_dim)``.
    """
    ntime, n_particles = posterior.filtered_log_weights.shape
    if num_samples is None:
        num_samples = n_particles

    def _sample_one_step(
        log_weights_t: Float[Array, ' num_particles'],
        particles_t: Float[Array, 'num_particles state_dim'],
        step_key: PRNGKeyT,
    ) -> Float[Array, 'num_samples emission_dim']:
        """Draw predictive samples at one time step."""
        k1, k2, k3 = jr.split(step_key, 3)
        weights = jnp.exp(log_weights_t - jnp.max(log_weights_t))
        weights = weights / jnp.sum(weights)
        indices = jr.choice(k1, n_particles, shape=(num_samples,), p=weights)
        resampled = particles_t[indices]
        # Propagate through transition
        trans_keys = jr.split(k2, num_samples)
        propagated = vmap(transition_sampler)(trans_keys, resampled)
        # Draw emissions
        emit_keys = jr.split(k3, num_samples)
        return vmap(emission_sampler)(emit_keys, propagated)

    step_keys = jr.split(key, ntime)
    return vmap(_sample_one_step)(
        posterior.filtered_log_weights,
        posterior.filtered_particles,
        step_keys,
    )

replicated_log_ml(key, filter_fn, num_replicates)

Run a particle filter multiple times to assess log-ML variability.

Uses :func:jax.vmap over PRNG keys for efficient parallel evaluation. The resulting distribution of log-ML estimates quantifies Monte Carlo uncertainty in the evidence.

Parameters:

Name Type Description Default
key PRNGKeyT

JAX PRNG key.

required
filter_fn Callable[[PRNGKeyT], Scalar]

Function (key) -> scalar that runs a particle filter and returns the marginal log-likelihood.

required
num_replicates int

Number of independent filter runs.

required

Returns:

Type Description
Float[Array, ' num_replicates']

Array of log-ML estimates, shape (num_replicates,).

Source code in smcjax/diagnostics.py
def replicated_log_ml(
    key: PRNGKeyT,
    filter_fn: Callable[[PRNGKeyT], Scalar],
    num_replicates: int,
) -> Float[Array, ' num_replicates']:
    r"""Run a particle filter multiple times to assess log-ML variability.

    Uses :func:`jax.vmap` over PRNG keys for efficient parallel
    evaluation.  The resulting distribution of log-ML estimates
    quantifies Monte Carlo uncertainty in the evidence.

    Args:
        key: JAX PRNG key.
        filter_fn: Function ``(key) -> scalar`` that runs a particle
            filter and returns the marginal log-likelihood.
        num_replicates: Number of independent filter runs.

    Returns:
        Array of log-ML estimates, shape ``(num_replicates,)``.
    """
    keys = jr.split(key, num_replicates)
    return jnp.asarray(vmap(filter_fn)(keys))

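A typical post-processing step is to summarise the spread of the replicated estimates. The log-ML values below are made up for illustration:

```python
import statistics

# Hypothetical output of replicated_log_ml with num_replicates = 5.
log_mls = [-412.1, -411.8, -412.6, -411.9, -412.4]

mc_mean = statistics.mean(log_mls)
mc_sd = statistics.stdev(log_mls)

# If mc_sd is comparable to the log Bayes factor under study, the model
# comparison is not resolved at this particle count.
```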
tail_ess(posterior, q=0.05)

Compute tail effective sample size at each time step.

Tail-ESS measures how well the weighted particle approximation represents the tails of the distribution (Vehtari, Gelman, Simpson, Carpenter, and Burkner, 2020). We compute the ESS for the indicator function :math:I(w_i \ge w_{(q)}), i.e. how well the largest weights are distributed.

Specifically, for normalised weights :math:w_i we compute the ESS of the weights truncated below the :math:(1-q) quantile:

.. math::

\text{tail-ESS} = \frac{
    \bigl(\sum_{i : w_i \ge c} w_i \bigr)^2
}{
    \sum_{i : w_i \ge c} w_i^2
}

where :math:c is the :math:(1-q) quantile of the weights.

Parameters:

Name Type Description Default
posterior ParticleFilterResult

Particle filter posterior output.

required
q float

Tail probability. Default 0.05, so we examine the top 5% weight mass.

0.05

Returns:

Type Description
Float[Array, ' ntime']

Tail-ESS at each time step, shape (ntime,).

Source code in smcjax/diagnostics.py
def tail_ess(
    posterior: ParticleFilterResult,
    q: float = 0.05,
) -> Float[Array, ' ntime']:
    r"""Compute tail effective sample size at each time step.

    Tail-ESS measures how well the weighted particle approximation
    represents the tails of the distribution (Vehtari, Gelman,
    Simpson, Carpenter, and Burkner, 2020).  We compute the ESS
    for the indicator function :math:`I(w_i \ge w_{(q)})`, i.e.
    how well the largest weights are distributed.

    Specifically, for normalised weights :math:`w_i` we compute
    the ESS of the weights truncated below the :math:`(1-q)`
    quantile:

    .. math::

        \text{tail-ESS} = \frac{
            \bigl(\sum_{i : w_i \ge c} w_i \bigr)^2
        }{
            \sum_{i : w_i \ge c} w_i^2
        }

    where :math:`c` is the :math:`(1-q)` quantile of the weights.

    Args:
        posterior: Particle filter posterior output.
        q: Tail probability.  Default 0.05, so we examine the top
            5% weight mass.

    Returns:
        Tail-ESS at each time step, shape ``(ntime,)``.
    """

    def _tail_ess_one_step(
        log_weights: Float[Array, ' num_particles'],
    ) -> Float[Array, '']:
        """Compute tail-ESS for one time step."""
        weights = jnp.exp(log_weights - jnp.max(log_weights))
        weights = weights / jnp.sum(weights)
        # Threshold: (1-q) quantile of weights
        sorted_w = jnp.sort(weights)
        cum_w = jnp.cumsum(sorted_w)
        threshold = sorted_w[jnp.searchsorted(cum_w, 1.0 - q)]
        # Tail weights
        tail_mask = weights >= threshold
        tail_w = jnp.where(tail_mask, weights, 0.0)
        sum_w = jnp.sum(tail_w)
        sum_w2 = jnp.sum(tail_w**2)
        return jnp.asarray(
            jnp.where(
                sum_w2 > 0.0,
                sum_w**2 / sum_w2,
                jnp.float64(0.0),
            )
        )

    return vmap(_tail_ess_one_step)(posterior.filtered_log_weights)
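As a sanity check, the per-step logic above can be sketched in plain NumPy (the helper name `tail_ess_single` is hypothetical, for illustration only): a single dominant weight yields a tail-ESS of 1, while uniform weights recover the full particle count.

```python
import numpy as np

def tail_ess_single(log_weights, q=0.05):
    """Tail-ESS for one vector of unnormalized log weights.

    Mirrors the per-step computation above, written in NumPy.
    """
    w = np.exp(log_weights - np.max(log_weights))  # stable exponentiation
    w = w / np.sum(w)
    sorted_w = np.sort(w)
    cum_w = np.cumsum(sorted_w)
    # Smallest weight at which the cumulative sorted mass reaches 1 - q.
    threshold = sorted_w[np.searchsorted(cum_w, 1.0 - q)]
    tail_w = np.where(w >= threshold, w, 0.0)
    s, s2 = tail_w.sum(), (tail_w**2).sum()
    return s**2 / s2 if s2 > 0.0 else 0.0

# One particle dominates: the tail is effectively a single sample.
print(tail_ess_single(np.array([0.0] * 9 + [10.0])))  # ≈ 1.0

# Uniform weights: every particle clears the threshold.
print(tail_ess_single(np.zeros(10)))  # ≈ 10.0
```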

weighted_mean(posterior)

Compute the weighted mean of particles at each time step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `posterior` | `ParticleFilterResult` | Particle filter posterior output. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, 'ntime state_dim']` | Weighted means, shape `(ntime, state_dim)`. |

Source code in smcjax/diagnostics.py
def weighted_mean(
    posterior: ParticleFilterResult,
) -> Float[Array, 'ntime state_dim']:
    r"""Compute the weighted mean of particles at each time step.

    Args:
        posterior: Particle filter posterior output.

    Returns:
        Weighted means, shape ``(ntime, state_dim)``.
    """
    return _weighted_mean_field(
        posterior.filtered_log_weights,
        posterior.filtered_particles,
    )
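The underlying computation is a stable softmax over the log weights followed by a weighted average over the particle axis (presumably what the private `_weighted_mean_field` helper maps over time). A minimal NumPy sketch for a single time step:

```python
import numpy as np

def weighted_mean_step(log_weights, particles):
    """Weighted mean for one time step.

    log_weights: (num_particles,) unnormalized log weights.
    particles:   (num_particles, state_dim) particle values.
    """
    w = np.exp(log_weights - np.max(log_weights))  # stable softmax
    w = w / w.sum()
    return w @ particles  # (state_dim,)

particles = np.array([[0.0], [2.0]])
print(weighted_mean_step(np.zeros(2), particles))                    # [1.]
print(weighted_mean_step(np.array([0.0, np.log(3.0)]), particles))   # [1.5]
```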

weighted_quantile(posterior, q)

Compute weighted quantiles of particles at each time step.

Uses a sorted resampling approach for JIT compatibility: sorts particles, computes cumulative weights, and interpolates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `posterior` | `ParticleFilterResult` | Particle filter posterior output. | *required* |
| `q` | `Float[Array, ' num_quantiles']` | Quantile levels in [0, 1], e.g. `jnp.array([0.025, 0.975])` for a 95% credible interval. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, 'ntime num_quantiles state_dim']` | Weighted quantiles, shape `(ntime, num_quantiles, state_dim)`. |

Source code in smcjax/diagnostics.py
def weighted_quantile(
    posterior: ParticleFilterResult,
    q: Float[Array, ' num_quantiles'],
) -> Float[Array, 'ntime num_quantiles state_dim']:
    r"""Compute weighted quantiles of particles at each time step.

    Uses a sorted resampling approach for JIT compatibility:
    sorts particles, computes cumulative weights, and interpolates.

    Args:
        posterior: Particle filter posterior output.
        q: Quantile levels in [0, 1], e.g. ``jnp.array([0.025, 0.975])``
            for a 95% credible interval.

    Returns:
        Weighted quantiles, shape ``(ntime, num_quantiles, state_dim)``.
    """
    return _weighted_quantile_field(
        posterior.filtered_log_weights,
        posterior.filtered_particles,
        q,
    )
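The "sort, accumulate weights, interpolate" approach can be sketched for a single time step and scalar state dimension. The exact interpolation convention inside `_weighted_quantile_field` may differ; this sketch uses plain linear interpolation of the quantile levels against the cumulative weight mass.

```python
import numpy as np

def weighted_quantile_step(log_weights, x, q):
    """Weighted quantiles for one scalar state dimension.

    Sorts particles, accumulates normalized weights, and linearly
    interpolates the quantile levels against the cumulative mass.
    """
    w = np.exp(log_weights - np.max(log_weights))
    w = w / w.sum()
    order = np.argsort(x)
    x_sorted, cum_w = x[order], np.cumsum(w[order])
    return np.interp(q, cum_w, x_sorted)

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Median of uniformly weighted particles falls between 2 and 3.
print(weighted_quantile_step(np.zeros(5), x, np.array([0.5])))  # [2.5]
```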

weighted_variance(posterior)

Compute the weighted variance of particles at each time step.

Uses the formula :math:`V = \sum_i w_i (x_i - \mu)^2` where :math:`\mu` is the weighted mean.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `posterior` | `ParticleFilterResult` | Particle filter posterior output. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, 'ntime state_dim']` | Weighted variances, shape `(ntime, state_dim)`. |

Source code in smcjax/diagnostics.py
def weighted_variance(
    posterior: ParticleFilterResult,
) -> Float[Array, 'ntime state_dim']:
    r"""Compute the weighted variance of particles at each time step.

    Uses the formula :math:`V = \sum_i w_i (x_i - \mu)^2` where
    :math:`\mu` is the weighted mean.

    Args:
        posterior: Particle filter posterior output.

    Returns:
        Weighted variances, shape ``(ntime, state_dim)``.
    """
    means = weighted_mean(posterior)
    deviations = posterior.filtered_particles - means[:, None, :]
    return _weighted_mean_field(
        posterior.filtered_log_weights,
        deviations**2,
    )
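Putting the two moments together for a single time step: the variance is the weighted mean of squared deviations from the weighted mean. A small NumPy sketch of this computation:

```python
import numpy as np

def weighted_variance_step(log_weights, particles):
    """Weighted variance V = sum_i w_i (x_i - mu)^2 for one step."""
    w = np.exp(log_weights - np.max(log_weights))
    w = w / w.sum()
    mu = w @ particles                 # weighted mean, shape (state_dim,)
    return w @ (particles - mu) ** 2   # weighted second moment about mu

# Two equally weighted particles at 0 and 2: mean 1, variance 1.
particles = np.array([[0.0], [2.0]])
print(weighted_variance_step(np.zeros(2), particles))  # [1.]
```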