
Model & belief

The generative model the agent assumes, and the Gaussian belief it carries over the hidden state.

LinearGaussianModel dataclass

LinearGaussianModel(
    dynamics: ArrayLike,
    sensor_model: ArrayLike,
    dynamics_noise: ArrayLike,
    sensor_noise: ArrayLike,
    prior: Belief,
    control: ArrayLike | None = None,
    observation: ObservationModel | None = None,
    process_noise: DynamicsNoise | None = None,
    structure: ModelStructure | None = None,
)

A linear-Gaussian state-space model — the agent's generative model.

This is the agent's assumed account of how a hidden state evolves and produces observations, with linear maps and additive Gaussian noise:

next_state  = dynamics @ state + control @ action + dynamics noise
observation = sensor_model @ state               + sensor noise

The noise terms are zero-mean Gaussians with covariances dynamics_noise and sensor_noise; the initial state is drawn from prior.
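Sampling from the generative story makes it concrete. The sketch below is a minimal JAX illustration that uses plain arrays rather than the LinearGaussianModel class. The function name simulate and the example matrices (a position/velocity tracker with a 1-D push) are assumptions chosen for illustration. It draws the initial state from the prior, then alternates observation and transition with jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def simulate(key, dynamics, control, sensor_model, dynamics_noise, sensor_noise,
             prior_mean, prior_cov, actions):
    """Hypothetical helper: sample T states and observations from the model.

    actions has shape (T, p). Returns states x_0..x_{T-1}, shape (T, n),
    and the matching observations, shape (T, m).
    """
    n, m = dynamics.shape[0], sensor_model.shape[0]
    k0, key = jax.random.split(key)
    # Initial state drawn from the prior N(prior_mean, prior_cov).
    x0 = jax.random.multivariate_normal(k0, prior_mean, prior_cov)

    def step(x, inputs):
        k, u = inputs
        kv, kw = jax.random.split(k)
        # observation = sensor_model @ state + sensor noise ~ N(0, sensor_noise)
        y = sensor_model @ x + jax.random.multivariate_normal(kv, jnp.zeros(m), sensor_noise)
        # next_state = dynamics @ state + control @ action + noise ~ N(0, dynamics_noise)
        x_next = (dynamics @ x + control @ u
                  + jax.random.multivariate_normal(kw, jnp.zeros(n), dynamics_noise))
        return x_next, (x, y)

    keys = jax.random.split(key, actions.shape[0])
    _, (states, observations) = jax.lax.scan(step, x0, (keys, actions))
    return states, observations


# Example: 2-D position/velocity state, 1-D position reading, 1-D push action.
dynamics = jnp.array([[1.0, 0.1], [0.0, 1.0]])
control = jnp.array([[0.0], [0.1]])
sensor_model = jnp.array([[1.0, 0.0]])
dynamics_noise = 0.01 * jnp.eye(2)
sensor_noise = jnp.array([[0.25]])
actions = jnp.ones((50, 1))
states, observations = simulate(jax.random.PRNGKey(0), dynamics, control, sensor_model,
                                dynamics_noise, sensor_noise, jnp.zeros(2), jnp.eye(2), actions)
```

Dropping the control term (and the action input) gives the pure filtering case described below.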

Parameters are named by role rather than by the traditional control-theory letters. This avoids a collision with discrete active inference (pymdp), where the same letters mean different things: there, for example, B is the transition model and C encodes preferences over observations. The "also known as" column lists the terms used in other fields, so readers from those backgrounds can still find the right parameter. (The letters remain available as the .A/.B/.C/.Q/.R alias properties for backend use.)

================ ====== ========================== ===== ====================
role name        letter meaning                    shape also known as
================ ====== ========================== ===== ====================
dynamics         A      state -> next state        (n,n) state-transition
control          B      action -> state (optional) (n,p) input/control matrix
sensor_model     C      state -> expected reading  (m,n) observation/emission
dynamics_noise   Q      dynamics-noise covariance  (n,n) process noise
sensor_noise     R      sensor-noise covariance    (m,m) observation noise
prior            --     initial belief over state  n-D   Belief / D (pymdp)
================ ====== ========================== ===== ====================

Dimensions: n = state, m = observation, p = action. A model with no control is a pure filtering (tracking) model.
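The shape column amounts to a small set of consistency checks. The following hedged sketch uses a hypothetical check_shapes helper; it is not the class's own _validate, whose body is not shown on this page. The dimensions are read off the matrices, and every other array is checked against them.

```python
import jax.numpy as jnp


def check_shapes(dynamics, sensor_model, dynamics_noise, sensor_noise,
                 prior_mean, control=None):
    """Hypothetical helper: check each array against the table; return (n, m, p)."""
    n = dynamics.shape[0]       # state dimension
    m = sensor_model.shape[0]   # observation dimension
    expected = {
        "dynamics": (dynamics, (n, n)),
        "sensor_model": (sensor_model, (m, n)),
        "dynamics_noise": (dynamics_noise, (n, n)),
        "sensor_noise": (sensor_noise, (m, m)),
        "prior mean": (prior_mean, (n,)),
    }
    for name, (array, shape) in expected.items():
        if array.shape != shape:
            raise ValueError(f"{name} must have shape {shape}, got {array.shape}")
    if control is None:
        return n, m, 0  # no control: a pure filtering (tracking) model
    if control.ndim != 2 or control.shape[0] != n:
        raise ValueError(f"control must have shape (n, p), got {control.shape}")
    return n, m, control.shape[1]
```

For a 2-D state with a 1-D reading, check_shapes returns (2, 1, 1) when a (2, 1) control matrix is given and (2, 1, 0) without one.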

Three optional fields extend the model. With all three left at their default of None, it is the plain fixed-matrix model:

  • observation -- a cpomdp.observation.ObservationModel, for state-dependent sensor noise R(x).
  • process_noise -- a cpomdp.dynamics.DynamicsNoise, for state-dependent process noise Q(x).
  • structure -- a cpomdp.structure.ModelStructure, declaring the factor / Markov-blanket partition.
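The difference between fixed and state-dependent sensor noise works like this. With a fixed sensor_noise, every reading has the same covariance R. With state-dependent sensing, the covariance is a function R(x) evaluated at the current state. The sketch shows only that idea, using plain functions. sensor_noise_at, observe and the particular form of R(x) are hypothetical and do not reflect the interface of cpomdp.observation.ObservationModel.

```python
import jax
import jax.numpy as jnp

sensor_model = jnp.array([[1.0, 0.0]])
base_noise = jnp.array([[0.25]])


def sensor_noise_at(x):
    """Hypothetical R(x): readings get noisier as |x[0]| grows."""
    return base_noise * (1.0 + x[0] ** 2)


def observe(key, x):
    # Same linear reading as the fixed model, but the noise covariance
    # is evaluated at the current state instead of being a constant.
    noise = jax.random.multivariate_normal(
        key, jnp.zeros(sensor_model.shape[0]), sensor_noise_at(x))
    return sensor_model @ x + noise


y = observe(jax.random.PRNGKey(1), jnp.array([2.0, 0.5]))
```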

Source code in src/cpomdp/types.py
def __init__(
    self,
    dynamics: ArrayLike,
    sensor_model: ArrayLike,
    dynamics_noise: ArrayLike,
    sensor_noise: ArrayLike,
    prior: Belief,
    control: ArrayLike | None = None,
    observation: ObservationModel | None = None,
    process_noise: DynamicsNoise | None = None,
    structure: ModelStructure | None = None,
) -> None:
    object.__setattr__(self, "dynamics", jnp.asarray(dynamics, dtype=float))
    object.__setattr__(self, "sensor_model", jnp.asarray(sensor_model, dtype=float))
    object.__setattr__(
        self, "dynamics_noise", jnp.asarray(dynamics_noise, dtype=float)
    )
    object.__setattr__(self, "sensor_noise", jnp.asarray(sensor_noise, dtype=float))
    object.__setattr__(self, "prior", prior)
    object.__setattr__(
        self,
        "control",
        None if control is None else jnp.asarray(control, dtype=float),
    )
    object.__setattr__(self, "observation", observation)
    object.__setattr__(self, "process_noise", process_noise)
    object.__setattr__(self, "structure", structure)
    self._validate()
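The object.__setattr__ calls suggest that LinearGaussianModel is a frozen dataclass with a hand-written __init__; this is an inference from the code above. The frozen flag blocks ordinary attribute assignment, so the constructor writes each field through object.__setattr__ after coercing it to a float array. The following minimal, self-contained sketch shows the same pattern with a hypothetical two-field TinyModel:

```python
from __future__ import annotations

from dataclasses import FrozenInstanceError, dataclass

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


@dataclass(frozen=True, init=False)
class TinyModel:
    """Hypothetical two-field stand-in for LinearGaussianModel."""

    dynamics: jax.Array
    control: jax.Array | None

    def __init__(self, dynamics: ArrayLike, control: ArrayLike | None = None) -> None:
        # frozen=True makes `self.dynamics = ...` raise, so bypass it once here.
        object.__setattr__(self, "dynamics", jnp.asarray(dynamics, dtype=float))
        # Optional fields stay None instead of becoming an array.
        object.__setattr__(
            self, "control", None if control is None else jnp.asarray(control, dtype=float)
        )


model = TinyModel([[1, 0], [0, 1]])  # integer lists are stored as float arrays
```

After construction, model.dynamics = ... raises FrozenInstanceError, which is what keeps these models immutable values.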

n_states property

n_states: int

Dimension of the hidden state (n).

n_observations property

n_observations: int

Dimension of an observation (m).

n_controls property

n_controls: int

Dimension of an action (p); 0 if the model has no control.

A property

A: Float64[Array, 'n n']

A: the state-transition matrix (alias of dynamics).

B property

B: Float64[Array, 'n p'] | None

B: the control matrix (alias of control); None if uncontrolled.

C property

C: Float64[Array, 'm n']

C: the observation matrix (alias of sensor_model).

Q property

Q: Float64[Array, 'n n']

Q: the process-noise covariance (alias of dynamics_noise).

R property

R: Float64[Array, 'm m']

R: the observation-noise covariance (alias of sensor_noise).
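The letter aliases let backend code follow textbook notation. For example, the standard Kalman prediction step could be written against model.A, model.B and model.Q. The sketch below takes the arrays directly so that it runs on its own. predict is a hypothetical helper, not a cpomdp function.

```python
import jax.numpy as jnp


def predict(mean, cov, A, Q, B=None, action=None):
    """Hypothetical Kalman time update written in the conventional letters."""
    # Push the mean through the dynamics, adding the control term if present.
    mean_next = A @ mean if B is None else A @ mean + B @ action
    # Propagate the uncertainty and add process noise: A P A^T + Q.
    cov_next = A @ cov @ A.T + Q
    return mean_next, cov_next
```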

tree_flatten

tree_flatten() -> tuple[
    tuple[_ModelLeaf, ...], ModelStructure | None
]

Leaves for JAX: every matrix plus the prior belief; structure is aux.

control, observation and process_noise are included as (possibly None) children; an uncontrolled / fixed-sensor / fixed-Q model contributes no leaf there and the None is restored on rebuild. A non-None observation/process_noise is itself a pytree and recurses into its own leaves. structure (declarative metadata, no array leaves) rides in the static aux_data, so two models differing only in structure are different pytrees and a jit keyed on the model re-specialises when it changes.
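These flattening rules can be checked on a toy pytree. The class below is a hypothetical stand-in, not LinearGaussianModel. It has one array child, one optional child, and a static tag in aux_data. The example shows three things: a None child contributes no leaf, that child comes back as None on rebuild, and changing only the aux data yields a different tree structure, so jax.jit traces again.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Toy:
    """Hypothetical pytree: `a` and `b` are children, `tag` is static aux data."""

    def __init__(self, a, b, tag):
        self.a, self.b, self.tag = a, b, tag

    def tree_flatten(self):
        return (self.a, self.b), self.tag

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, aux_data)


with_b = Toy(jnp.ones(2), jnp.zeros(3), "x")
without_b = Toy(jnp.ones(2), None, "x")
print(len(jax.tree_util.tree_leaves(with_b)))     # 2
print(len(jax.tree_util.tree_leaves(without_b)))  # 1: the None child adds no leaf

# tree_map rebuilds the object; the None child is restored as None.
doubled = jax.tree_util.tree_map(lambda v: 2 * v, without_b)

traces = []


@jax.jit
def total(t):
    traces.append(t.tag)  # Python side effect: runs only while tracing
    return jnp.sum(t.a)


total(Toy(jnp.ones(2), None, "x"))
total(Toy(jnp.zeros(2), None, "x"))  # same structure and shapes: cached
total(Toy(jnp.ones(2), None, "y"))   # aux data changed: traced again
print(traces)  # ['x', 'y']
```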

Source code in src/cpomdp/types.py
def tree_flatten(self) -> tuple[tuple[_ModelLeaf, ...], ModelStructure | None]:
    """Leaves for JAX: every matrix plus the ``prior`` belief; ``structure`` is aux.

    ``control``, ``observation`` and ``process_noise`` are included as (possibly
    ``None``) children; an uncontrolled / fixed-sensor / fixed-Q model contributes
    no leaf there and the ``None`` is restored on rebuild. A non-``None``
    ``observation``/``process_noise`` is itself a pytree and recurses into its own
    leaves. ``structure`` (declarative metadata, no array leaves) rides in the
    static aux_data, so two models differing only in ``structure`` are different
    pytrees and a jit keyed on the model re-specialises when it changes.
    """
    children = (
        self.dynamics,
        self.sensor_model,
        self.dynamics_noise,
        self.sensor_noise,
        self.prior,
        self.control,
        self.observation,
        self.process_noise,
    )
    return children, self.structure

tree_unflatten classmethod

tree_unflatten(
    aux_data: ModelStructure | None,
    children: tuple[_ModelLeaf, ...],
) -> LinearGaussianModel

Rebuild from leaves without validating — the leaves may be tracers.

aux_data is the static structure (or None), restored as-is.
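A toy class whose constructor runs a check that needs concrete values shows why skipping validation matters. Under jax.jit, JAX rebuilds each argument from tracers through tree_unflatten. A version that re-runs the constructor at that point fails, while one that bypasses it with object.__new__ (as the source below does) traces normally. Checked and Revalidating are hypothetical names for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Checked:
    """Hypothetical pytree that validates on direct construction only."""

    def __init__(self, cov):
        cov = jnp.asarray(cov, dtype=float)
        # bool() needs a concrete value, so this check cannot run on a tracer.
        if not bool(jnp.allclose(cov, cov.T)):
            raise ValueError("cov must be symmetric")
        self.cov = cov

    def tree_flatten(self):
        return (self.cov,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Skip __init__: the children may be tracers.
        obj = object.__new__(cls)
        obj.cov = children[0]
        return obj


@register_pytree_node_class
class Revalidating(Checked):
    """Same data, but re-runs __init__ whenever JAX rebuilds it."""

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0])


@jax.jit
def total_variance(c):
    return jnp.trace(c.cov)


print(total_variance(Checked(jnp.eye(3))))  # 3.0
try:
    total_variance(Revalidating(jnp.eye(3)))
except Exception as err:  # bool() on a tracer raises a concretization error
    print(type(err).__name__)
```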

Source code in src/cpomdp/types.py
@classmethod
def tree_unflatten(
    cls,
    aux_data: ModelStructure | None,
    children: tuple[_ModelLeaf, ...],
) -> "LinearGaussianModel":
    """Rebuild from leaves without validating — the leaves may be tracers.

    ``aux_data`` is the static ``structure`` (or ``None``), restored as-is.
    """
    obj = object.__new__(cls)
    fields = (
        "dynamics",
        "sensor_model",
        "dynamics_noise",
        "sensor_noise",
        "prior",
        "control",
        "observation",
        "process_noise",
    )
    for name, value in zip(fields, children, strict=True):
        object.__setattr__(obj, name, value)
    object.__setattr__(obj, "structure", aux_data)
    return obj

Belief dataclass

Belief(mean: ArrayLike, cov: ArrayLike)

A Gaussian belief over a continuous state.

In active inference an agent never knows the hidden state directly — it holds a probability distribution over what the state might be. For the linear-Gaussian case that distribution is always a Gaussian, fully described by two things:

  • mean -- the centre, the best single estimate. A 1-D vector of length n.
  • cov -- the covariance, the uncertainty. An n x n matrix; its diagonal is the variance per state dimension, its off-diagonals the correlations between them.
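Take, for instance, a hypothetical two-dimensional belief. The variances sit on the diagonal of cov. The correlation between the two dimensions is the off-diagonal entry divided by the product of the standard deviations:

```python
import jax.numpy as jnp

mean = jnp.array([1.0, -2.0])   # best single estimate of each state dimension
cov = jnp.array([[4.0, 1.2],
                 [1.2, 1.0]])   # symmetric n x n covariance

variances = jnp.diag(cov)       # per-dimension uncertainty: [4., 1.]
std_devs = jnp.sqrt(variances)  # [2., 1.]
# Correlation of dimensions 0 and 1: 1.2 / (2.0 * 1.0) = 0.6
correlation = cov[0, 1] / (std_devs[0] * std_devs[1])
```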

Beliefs are immutable values: updating a belief produces a new Belief rather than mutating an existing one. Inputs are accepted as anything array-like (lists, tuples, arrays) and stored as float jax.Array.
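The standard Kalman measurement update is one familiar case of an update that yields a new value: it combines a Gaussian belief with a linear-Gaussian reading. The sketch uses a NamedTuple as a hypothetical immutable stand-in for Belief. GaussianBelief and measurement_update are illustrative names, not cpomdp functions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GaussianBelief(NamedTuple):
    """Hypothetical immutable stand-in for Belief."""
    mean: jax.Array
    cov: jax.Array


def measurement_update(belief, observation, sensor_model, sensor_noise):
    """Standard Kalman measurement update; returns a new belief."""
    C, P = sensor_model, belief.cov
    S = C @ P @ C.T + sensor_noise           # innovation covariance
    K = jnp.linalg.solve(S, C @ P).T         # Kalman gain P C^T S^{-1}
    mean = belief.mean + K @ (observation - C @ belief.mean)
    cov = P - K @ C @ P                      # (I - K C) P
    return GaussianBelief(mean, cov)         # the input belief is left untouched


prior = GaussianBelief(jnp.array([0.0]), jnp.array([[1.0]]))
posterior = measurement_update(prior, jnp.array([2.0]),
                               jnp.array([[1.0]]), jnp.array([[1.0]]))
# posterior.mean is close to [1.0] and posterior.cov to [[0.5]]; prior is unchanged.
```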

A Belief is a registered JAX pytree (its leaves are mean and cov), so it passes through jit/vmap/grad as data. JAX rebuilds it from its leaves without re-running validation: the shape and symmetry checks run only on direct construction, at the trust boundary. Positive semi-definiteness is also enforced at the trust boundary, but not by Belief itself (see DECISIONS.md ADR-002).
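Because JAX rebuilds a pytree without re-running its checks, beliefs can be batched. Stacking two beliefs leaf by leaf produces leaves with a leading batch axis, which the shape check would reject on direct construction; jax.vmap then maps over that axis. ToyBelief below is a hypothetical stand-in for Belief that performs only a shape check.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyBelief:
    """Hypothetical stand-in for Belief with a shape check on construction."""

    def __init__(self, mean, cov):
        mean = jnp.asarray(mean, dtype=float)
        cov = jnp.asarray(cov, dtype=float)
        # Direct construction only: mean must be (n,) and cov (n, n).
        if mean.ndim != 1 or cov.shape != (mean.shape[0], mean.shape[0]):
            raise ValueError("expected mean (n,) and cov (n, n)")
        self.mean, self.cov = mean, cov

    def tree_flatten(self):
        return (self.mean, self.cov), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # No validation: JAX may pass tracers, batched arrays or placeholders.
        obj = object.__new__(cls)
        obj.mean, obj.cov = children
        return obj


b1 = ToyBelief([0.0, 1.0], jnp.eye(2))
b2 = ToyBelief([2.0, 3.0], 4.0 * jnp.eye(2))
# Stack leaf by leaf: mean becomes (2, n) and cov (2, n, n). __init__ would
# reject these shapes, but tree_unflatten rebuilds the object without checking.
batch = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), b1, b2)
# vmap strips the batch axis again, so each call sees one (n,)/(n, n) belief.
total_variances = jax.vmap(lambda b: jnp.trace(b.cov))(batch)
print(total_variances)  # [2. 8.]
```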

Source code in src/cpomdp/types.py
def __init__(self, mean: ArrayLike, cov: ArrayLike) -> None:
    object.__setattr__(self, "mean", jnp.asarray(mean, dtype=float))
    object.__setattr__(self, "cov", jnp.asarray(cov, dtype=float))
    self._validate()

ndim property

ndim: int

Dimensionality of the state — the length of the mean vector.

tree_flatten

tree_flatten() -> tuple[
    tuple[Float64[Array, "n"], Float64[Array, "n n"]], None
]

Leaves for JAX: (mean, cov), no static aux data.

Source code in src/cpomdp/types.py
def tree_flatten(
    self,
) -> tuple[tuple[Float64[Array, "n"], Float64[Array, "n n"]], None]:
    """Leaves for JAX: ``(mean, cov)``, no static aux data."""
    return (self.mean, self.cov), None

tree_unflatten classmethod

tree_unflatten(
    aux_data: None,
    children: tuple[
        Float64[Array, "n"], Float64[Array, "n n"]
    ],
) -> Belief

Rebuild from leaves without validating — the leaves may be tracers.

Source code in src/cpomdp/types.py
@classmethod
def tree_unflatten(
    cls,
    aux_data: None,
    children: tuple[Float64[Array, "n"], Float64[Array, "n n"]],
) -> "Belief":
    """Rebuild from leaves without validating — the leaves may be tracers."""
    mean, cov = children
    obj = object.__new__(cls)
    object.__setattr__(obj, "mean", mean)
    object.__setattr__(obj, "cov", cov)
    return obj