
Backends

The inference engine is swappable behind the InferenceBackend protocol. KalmanBackend is the default fast path; RxInferBackend — imported from cpomdp.backends.rxinfer and gated behind the optional rxinfer extra — re-derives the same answers through Julia and exists as an independent correctness oracle [^bagaev2023rxinfer].

InferenceBackend

Bases: Protocol

A swappable inference engine for a LinearGaussianModel.

A backend is built from a model: any expensive, data-independent work (front-loading — see DECISIONS.md ADR-002) happens at construction, so the per-step infer_states stays cheap. Each call advances the belief one recursive filter step: the current belief goes in as the prior and the updated belief comes back as the posterior.

The Protocol is structural: any class with a matching infer_states is a backend, and no shared base class is required. This is the abstraction wall: the native Kalman fast path and the RxInfer oracle are interchangeable behind it, and neither backend's implementation details (JAX, juliacall, …) leak into this signature.
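
A minimal sketch of what "structural" means in practice, using only the standard library's typing.Protocol. Belief, InferenceBackend, IdentityBackend and run_one_step below are hypothetical stand-ins, not the cpomdp classes. The toy backend never inherits from the protocol, yet it still satisfies it. Note that @runtime_checkable only checks that an infer_states attribute exists, not its signature; signature checking is left to a static type checker.

```python
from dataclasses import dataclass
from typing import Optional, Protocol, runtime_checkable

import numpy as np


@dataclass(frozen=True)
class Belief:
    """Hypothetical stand-in for cpomdp's Belief: a Gaussian N(mean, cov)."""

    mean: np.ndarray
    cov: np.ndarray


@runtime_checkable
class InferenceBackend(Protocol):
    """Structural interface: anything with a matching infer_states qualifies."""

    def infer_states(
        self,
        observation: np.ndarray,
        prior: Belief,
        action: Optional[np.ndarray] = None,
    ) -> Belief: ...


class IdentityBackend:
    """Toy backend with no base class; it ignores the data and echoes the prior."""

    def infer_states(
        self,
        observation: np.ndarray,
        prior: Belief,
        action: Optional[np.ndarray] = None,
    ) -> Belief:
        return prior


def run_one_step(backend: InferenceBackend, y: np.ndarray, prior: Belief) -> Belief:
    # Callers depend only on the protocol, never on a concrete backend class.
    return backend.infer_states(y, prior)


prior = Belief(mean=np.zeros(2), cov=np.eye(2))
assert isinstance(IdentityBackend(), InferenceBackend)  # structural match
assert run_one_step(IdentityBackend(), np.zeros(1), prior) is prior
```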

infer_states

infer_states(
    observation: ArrayLike,
    prior: Belief,
    action: ArrayLike | None = None,
) -> Belief

Advance the belief by one filter step: prior in, posterior out.

Given the current belief (prior) and a new observation (plus the action just taken, if the model has a control matrix), return the updated belief.
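
The prior-in, posterior-out contract means a caller filters a stream by feeding each posterior back in as the next prior. A self-contained sketch with a hypothetical 1-D random-walk backend (ScalarKalman and this scalar Belief are illustrative only and are not part of cpomdp):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Belief:
    """Hypothetical scalar Gaussian belief N(mean, cov)."""

    mean: float
    cov: float


class ScalarKalman:
    """Hypothetical backend for x' = x + w, y = x + v, w ~ N(0, q), v ~ N(0, r)."""

    def __init__(self, q: float, r: float) -> None:
        self.q, self.r = q, r

    def infer_states(
        self, observation: float, prior: Belief, action: Optional[float] = None
    ) -> Belief:
        # Predict: a random walk keeps the mean and inflates the variance by q.
        mean_pred, cov_pred = prior.mean, prior.cov + self.q
        # Update: move toward the observation by the Kalman gain.
        gain = cov_pred / (cov_pred + self.r)
        return Belief(
            mean=mean_pred + gain * (observation - mean_pred),
            cov=(1.0 - gain) * cov_pred,
        )


backend = ScalarKalman(q=0.0, r=1.0)
belief = Belief(mean=0.0, cov=1.0)
for y in [2.0, 2.0, 2.0]:
    belief = backend.infer_states(y, belief)  # this posterior is the next prior
# With q = 0 this matches the batch posterior: mean ≈ 1.5, cov ≈ 0.25.
```

With q = 0 the state never moves, so three recursive steps reproduce the one-shot Bayesian update of N(0, 1) on three unit-variance readings of 2.0 (posterior precision 1 + 3 = 4), which makes a convenient sanity check for any backend.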

Source code in src/cpomdp/backends/base.py
def infer_states(
    self,
    observation: ArrayLike,
    prior: Belief,
    action: ArrayLike | None = None,
) -> Belief:
    """Advance the belief by one filter step: ``prior`` in, posterior out.

    Given the current belief (``prior``) and a new ``observation`` (plus the
    ``action`` just taken, if the model has a control matrix), return the
    updated belief.
    """
    ...

KalmanBackend

KalmanBackend(
    model: LinearGaussianModel,
    *,
    steady_state: bool = False,
    tol: float = 1e-12,
    max_iter: int = 1000,
)

Exact Kalman-filter inference for a LinearGaussianModel.

Implements the InferenceBackend protocol: constructed from a model, then advances a belief one step at a time (prior in, posterior out) via the standard predict/update recursion.
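
For reference, the textbook predict/update recursion is written below. The symbol mapping is an assumption based on the argument names in the source: A = dynamics, B = control, H = sensor_model, Q = dynamics_noise, R = sensor_noise, with (μ, Σ) the prior belief, u the action and y the observation. The kernels may use an algebraically equivalent covariance update (for example, the Joseph form).

```latex
\begin{aligned}
\mu^- &= A\mu + Bu, & \Sigma^- &= A\Sigma A^\top + Q, \\
K &= \Sigma^- H^\top \left(H \Sigma^- H^\top + R\right)^{-1}, \\
\mu^+ &= \mu^- + K\left(y - H\mu^-\right), & \Sigma^+ &= (I - KH)\,\Sigma^-.
\end{aligned}
```

With fixed Q and R, K and Σ⁺ depend only on the covariance, never on y or u. That is why steady-state mode can freeze them while the mean update still runs every step.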

Two modes:

  • Per-step (default): recomputes the Kalman gain and covariance every step from the incoming belief. Correct for any linear-Gaussian model, including transient (pre-convergence) behaviour. This is the analytic oracle the rest of the toolbox is validated against.
  • Steady-state (steady_state=True): solves the covariance recursion once at construction to a fixed point, then reuses the frozen gain and covariance every step. Cheap (no per-step covariance maths), but only valid for time-invariant models with regular, complete observations. Raises ValueError at construction if the sensor or process noise is state-dependent (a state-dependent R(x) or Q(x) has no constant fixed point), and RuntimeError if the recursion does not converge within max_iter (for example, because the model is not stabilisable/detectable).
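
A NumPy sketch of what "solves the covariance recursion once at construction to a fixed point" amounts to. converge_to_steady_state is a hypothetical free function, not the backend's private _converge_to_steady_state, but it follows the documented semantics: stop when successive posterior covariances differ by less than tol in absolute terms, and raise RuntimeError after max_iter iterations otherwise. It assumes fixed (state-independent) Q and R.

```python
import numpy as np


def converge_to_steady_state(A, H, Q, R, tol=1e-12, max_iter=1000):
    """Hypothetical sketch: iterate predict/update on the covariance alone.

    Returns the frozen gain K and posterior covariance P at the fixed point.
    """
    n = A.shape[0]
    # Positive-definite start; for stabilisable/detectable models the fixed
    # point does not depend on it.
    P = np.eye(n)
    for _ in range(max_iter):
        P_pred = A @ P @ A.T + Q                      # predict
        S = H @ P_pred @ H.T + R                      # innovation covariance
        K = np.linalg.solve(S, H @ P_pred).T          # P_pred H^T S^-1 (S, P_pred symmetric)
        P_next = (np.eye(n) - K @ H) @ P_pred         # update
        if np.max(np.abs(P_next - P)) < tol:          # absolute change, successive covariances
            return K, P_next
        P = P_next
    raise RuntimeError(f"covariance recursion did not converge in {max_iter} iterations")


# Scalar random walk observed directly, unit noises: A = H = Q = R = 1.
one = np.eye(1)
K, P = converge_to_steady_state(one, one, one, one)
# Both converge to (sqrt(5) - 1) / 2 ≈ 0.618.
```

An unstable, unobserved state (A = 2, H = 0) has no fixed point: the covariance grows without bound and the same function raises RuntimeError instead.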

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | LinearGaussianModel | The linear-Gaussian generative model to filter under. | required |
| steady_state | bool | If True, precompute and freeze the steady-state gain. | False |
| tol | float | Convergence tolerance for the steady-state fixed point (absolute, on successive covariances). | 1e-12 |
| max_iter | int | Cap on steady-state iterations before giving up. | 1000 |

Source code in src/cpomdp/backends/kalman.py
def __init__(
    self,
    model: LinearGaussianModel,
    *,
    steady_state: bool = False,
    tol: float = 1e-12,
    max_iter: int = 1000,
) -> None:
    self.model = model
    self.steady_state = steady_state
    if steady_state:
        sensor_fixed = model.observation is None or model.observation.is_fixed
        process_fixed = model.process_noise is None or model.process_noise.is_fixed
        if not (sensor_fixed and process_fixed):
            raise ValueError(
                "steady_state=True needs fixed sensor and process noise; a "
                "state-dependent R(x) or Q(x) has no constant fixed point — "
                "use steady_state=False."
            )
        self._steady_gain, self._steady_cov = self._converge_to_steady_state(
            tol, max_iter
        )

infer_states

infer_states(
    observation: ArrayLike,
    prior: Belief,
    action: ArrayLike | None = None,
) -> Belief

Advance the belief by one filter step.

Runs one predict/update cycle: step the prior through the dynamics (applying action if the model has a control matrix), then correct the prediction toward observation using the Kalman gain. In steady-state mode the gain and covariance are the frozen fixed-point values; otherwise they are recomputed from prior.cov on this step.

The numeric work is delegated to the jit-compiled module kernels (_gain_and_posterior_cov, _posterior_mean); this method stays the eager orchestrator that validates inputs and wraps the result in a Belief.
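
The two kernel names come from the source, but their bodies are not shown on this page. The sketch below is a hypothetical NumPy re-implementation that takes arguments in the same order as the call sites in infer_states (the real kernels are jit-compiled JAX functions). It splits one step into a covariance half, which never touches the data, and a mean half.

```python
import numpy as np


def gain_and_posterior_cov(A, H, Q, R, prior_cov):
    """Hypothetical sketch of the covariance half of one Kalman step."""
    cov_pred = A @ prior_cov @ A.T + Q                # predicted covariance
    S = H @ cov_pred @ H.T + R                        # innovation covariance
    gain = np.linalg.solve(S, H @ cov_pred).T         # cov_pred H^T S^-1
    cov_post = (np.eye(A.shape[0]) - gain @ H) @ cov_pred
    return gain, cov_post


def posterior_mean(A, H, prior_mean, control_term, gain, observation):
    """Hypothetical sketch of the mean half: predict, then correct."""
    mean_pred = A @ prior_mean + control_term         # mu^- = A mu + B u
    return mean_pred + gain @ (observation - H @ mean_pred)


one = np.eye(1)  # scalar model: A = H = Q = R = 1
K, P = gain_and_posterior_cov(one, one, one, one, prior_cov=one)
m = posterior_mean(one, one, np.zeros(1), np.zeros(1), K, observation=np.array([2.0]))
# K ≈ [[0.667]], P ≈ [[0.667]], m ≈ [1.333]
```

Keeping the covariance half separate is what lets steady-state mode swap in frozen (gain, cov_post) values while reusing the mean half unchanged.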

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observation | ArrayLike | The latest sensor reading, shape (m,). | required |
| prior | Belief | The current belief, treated as this step's previous posterior. Never mutated. | required |
| action | ArrayLike \| None | The action just taken, shape (p,). Required iff the model has a control matrix; ignored (pass None) for pure filtering. | None |

Returns:

| Type | Description |
|---|---|
| Belief | The posterior belief, returned as a new Belief; the prior is left untouched. |

Raises:

| Type | Description |
|---|---|
| ValueError | If observation is not shape (m,), prior is not a belief over the model's n-D state, the model has a control matrix but action is None, or action is not shape (p,). (All enforced in validate_step_inputs.) |
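
The checks listed under Raises are plain shape tests. A hypothetical sketch of them follows. check_step_shapes is not the library's validate_step_inputs: it takes the dimensions n, m and p explicitly (p is None when there is no control matrix) instead of reading them off a model, because the model attributes the real helper uses are not shown here.

```python
from typing import Optional, Tuple

import numpy as np


def check_step_shapes(
    n: int,
    m: int,
    p: Optional[int],
    observation,
    prior_mean,
    prior_cov,
    action=None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Hypothetical re-statement of the documented step-input checks.

    n, m, p are the state, observation and action dimensions; p is None when
    the model has no control matrix. Returns the coerced (observation, action).
    """
    observation = np.asarray(observation, dtype=float)
    if observation.shape != (m,):
        raise ValueError(f"observation must have shape ({m},), got {observation.shape}")
    if np.shape(prior_mean) != (n,) or np.shape(prior_cov) != (n, n):
        raise ValueError(f"prior must be a belief over a {n}-D state")
    if p is None:
        return observation, None  # pure filtering: any action is ignored
    if action is None:
        raise ValueError("the model has a control matrix, so action is required")
    action = np.asarray(action, dtype=float)
    if action.shape != (p,):
        raise ValueError(f"action must have shape ({p},), got {action.shape}")
    return observation, action
```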

Source code in src/cpomdp/backends/kalman.py
def infer_states(
    self,
    observation: ArrayLike,
    prior: Belief,
    action: ArrayLike | None = None,
) -> Belief:
    """Advance the belief by one filter step.

    Runs one predict/update cycle: step the prior through the dynamics
    (applying ``action`` if the model has a control matrix), then correct the
    prediction toward ``observation`` using the Kalman gain. In steady-state
    mode the gain and covariance are the frozen fixed-point values; otherwise
    they are recomputed from ``prior.cov`` on this step.

    The numeric work is delegated to the ``jit``-compiled module kernels
    (``_gain_and_posterior_cov``, ``_posterior_mean``); this method stays the
    eager orchestrator that validates inputs and wraps the result in a
    ``Belief``.

    Args:
        observation: The latest sensor reading, shape ``(m,)``.
        prior: The current belief, treated as this step's previous posterior.
            Never mutated.
        action: The action just taken, shape ``(p,)``. Required iff the model
            has a control matrix; ignored (pass ``None``) for pure filtering.

    Returns:
        The posterior belief — a new ``Belief``; the prior is left untouched.

    Raises:
        ValueError: If ``observation`` is not shape ``(m,)``, ``prior`` is not
            a belief over the model's ``n``-D state, the model has a control
            matrix but ``action`` is ``None``, or ``action`` is not shape
            ``(p,)``. (All enforced in ``validate_step_inputs``.)
    """
    model = self.model
    observation, action = validate_step_inputs(model, observation, prior, action)
    control = model.control
    if control is None:
        control_term = jnp.zeros(model.n_states)
    else:
        # validate_step_inputs guarantees a non-None action when control exists
        assert action is not None
        control_term = control @ action

    sensor_is_fixed = model.observation is None or model.observation.is_fixed
    process_is_fixed = model.process_noise is None or model.process_noise.is_fixed

    # μ⁻ is needed only to linearize a state-dependent sensor and/or process
    # noise; the fully-fixed hot path computes no extra matvec.
    mean_pred = (
        model.dynamics @ prior.mean + control_term
        if not (sensor_is_fixed and process_is_fixed)
        else prior.mean  # placeholder, unused on the fixed path
    )

    if sensor_is_fixed:
        # fixed sensor: direct reads, byte-identical hot path (no linearize).
        sensor_model, sensor_noise = model.sensor_model, model.sensor_noise
    else:
        # state-dependent R(x), linearized at μ⁻ (the EFE kernel's point).
        sensor_model, sensor_noise = model.observation.linearize(mean_pred)

    if process_is_fixed:
        dynamics_noise = model.dynamics_noise
    else:
        # state-dependent Q(x), evaluated at μ⁻ — the dual of the R(x) gate.
        dynamics_noise = model.process_noise.noise_at(mean_pred)

    if self.steady_state:
        gain, cov_post = self._steady_gain, self._steady_cov  # frozen
    else:
        gain, cov_post = _gain_and_posterior_cov(
            model.dynamics,
            sensor_model,
            dynamics_noise,
            sensor_noise,
            prior.cov,
        )

    mean_post = _posterior_mean(
        model.dynamics,
        sensor_model,
        prior.mean,
        control_term,
        gain,
        observation,
    )

    return Belief(mean=mean_post, cov=cov_post)
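
When process noise is state-dependent, the code above evaluates Q(x) at the predicted mean μ⁻ and then runs the ordinary update with that matrix. A minimal sketch of the prediction half under that scheme; the noise function (a floor plus a term growing with the squared norm of the state) and both helper names are purely hypothetical.

```python
import numpy as np


def hypothetical_q(x):
    """Illustrative Q(x): a noise floor plus a term growing with |x|^2."""
    return (0.01 + 0.1 * float(x @ x)) * np.eye(len(x))


def predict_with_state_dependent_q(A, prior_mean, prior_cov, control_term, q_fn):
    mean_pred = A @ prior_mean + control_term         # mu^- = A mu + B u
    Q = q_fn(mean_pred)                               # evaluated at mu^-, not at the prior mean
    cov_pred = A @ prior_cov @ A.T + Q
    return mean_pred, cov_pred


mean_pred, cov_pred = predict_with_state_dependent_q(
    np.eye(2), np.array([1.0, 0.0]), np.eye(2), np.zeros(2), hypothetical_q
)
```

Because μ⁻ is only needed to evaluate a state-dependent R(x) or Q(x), the fully fixed path in infer_states skips that extra matrix-vector product.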

RxInferBackend

RxInferBackend(model: LinearGaussianModel)

Linear-Gaussian filtering via RxInfer.jl — the oracle backend.

Satisfies the InferenceBackend protocol: built from a model, advances a belief one step at a time. No steady-state mode — that belongs to the native fast path; this backend exists for correctness, not speed. The first instance built in a process loads the Julia runtime; later ones reuse it.
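
RxInferBackend needs a Julia runtime, so it cannot appear in a self-contained snippet. The sketch below shows the differential-testing pattern an oracle backend enables, with a second, independently derived NumPy implementation (the information-form update) standing in for RxInfer; it says nothing about how RxInfer computes its answer. Both functions are hypothetical, and neither applies a control term.

```python
import numpy as np


def step_covariance_form(A, H, Q, R, mean, cov, y):
    """Standard Kalman predict/update (the 'fast path' in this sketch)."""
    mean_pred, cov_pred = A @ mean, A @ cov @ A.T + Q
    S = H @ cov_pred @ H.T + R
    K = np.linalg.solve(S, H @ cov_pred).T
    return mean_pred + K @ (y - H @ mean_pred), (np.eye(len(mean)) - K @ H) @ cov_pred


def step_information_form(A, H, Q, R, mean, cov, y):
    """Same step derived independently via precisions (the 'oracle' here)."""
    mean_pred, cov_pred = A @ mean, A @ cov @ A.T + Q
    R_inv = np.linalg.inv(R)
    cov_post = np.linalg.inv(np.linalg.inv(cov_pred) + H.T @ R_inv @ H)
    mean_post = cov_post @ (np.linalg.solve(cov_pred, mean_pred) + H.T @ R_inv @ y)
    return mean_post, cov_post


rng = np.random.default_rng(0)
A = np.array([[1.0, 0.1], [0.0, 1.0]])  # constant-velocity dynamics
H = np.array([[1.0, 0.0]])              # observe position only
Q, R = 0.01 * np.eye(2), np.array([[0.25]])
fast = oracle = (np.zeros(2), np.eye(2))
for _ in range(20):
    y = rng.normal(size=1)
    fast = step_covariance_form(A, H, Q, R, *fast, y)
    oracle = step_information_form(A, H, Q, R, *oracle, y)
    # Each posterior must agree before it becomes the next prior.
    np.testing.assert_allclose(fast[0], oracle[0], atol=1e-9)
    np.testing.assert_allclose(fast[1], oracle[1], atol=1e-9)
```

The two updates share no correction code, so step-by-step agreement catches errors that a single implementation checked only against itself would miss.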

Source code in src/cpomdp/backends/rxinfer.py
def __init__(self, model: LinearGaussianModel) -> None:
    self.model = model
    self._jl = _julia()

infer_states

infer_states(
    observation: ArrayLike,
    prior: Belief,
    action: ArrayLike | None = None,
) -> Belief

Advance the belief one filter step: prior in, posterior out.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observation | ArrayLike | Latest sensor reading, shape (m,). | required |
| prior | Belief | Current belief; never mutated. | required |
| action | ArrayLike \| None | Action just taken, shape (p,). Required iff the model has a control matrix; pass None for pure filtering. | None |

Raises:

| Type | Description |
|---|---|
| ValueError | On a shape/None mismatch (see validate_step_inputs). |

Source code in src/cpomdp/backends/rxinfer.py
def infer_states(
    self,
    observation: ArrayLike,
    prior: Belief,
    action: ArrayLike | None = None,
) -> Belief:
    """Advance the belief one filter step: prior in, posterior out.

    Args:
        observation: Latest sensor reading, shape ``(m,)``.
        prior: Current belief; never mutated.
        action: Action just taken, shape ``(p,)``. Required iff the model has
            a control matrix; pass ``None`` for pure filtering.

    Raises:
        ValueError: On a shape/None mismatch (see ``validate_step_inputs``).
    """
    model = self.model
    observation, action = validate_step_inputs(model, observation, prior, action)
    control = model.control
    if control is None:
        control_term = jnp.zeros(model.n_states)
    else:
        # validate_step_inputs guarantees a non-None action when control exists
        assert action is not None
        control_term = control @ action

    # juliacall speaks numpy, not jax.Array, so coerce every array as it
    # crosses into Julia and coerce the posteriors back on the way out.
    mean_post, cov_post = self._jl.cpomdp_run_step(
        np.asarray(observation),
        np.asarray(model.dynamics),
        np.asarray(model.sensor_model),
        np.asarray(model.dynamics_noise),
        np.asarray(model.sensor_noise),
        np.asarray(prior.mean),
        np.asarray(prior.cov),
        np.asarray(control_term),
    )

    return Belief(mean=np.asarray(mean_post), cov=np.asarray(cov_post))

[^bagaev2023rxinfer]: Dmitry Bagaev, Albert Podusenko, and Bert de Vries. RxInfer: a Julia package for reactive real-time Bayesian inference. Journal of Open Source Software, 8(84):5161, 2023. doi:10.21105/joss.05161.