
Action selection & goals

What the agent wants, and how it picks an action. StateGoal and ObservationGoal are the goals you hand the Agent (the continuous-state answer to pymdp's C); Preference and EFESelector are the machinery underneath expected-free-energy selection.

StateGoal dataclass

StateGoal(
    target: ArrayLike, *, precision=None, effort=None
)

A state-space objective: reach a target state (the LQR / fixed-sensor regime).

The complete spec for the state-tracking path: the target plus the LQR cost weights it implies. precision is LQR's state weight Q and defaults to the identity; effort is its action weight R, left as None here because the action dimension p isn't known until the Agent pairs this goal with a model (the Agent then fills in the identity). The Agent dispatches a StateGoal to an LQRSelector. Not a pytree: it is used at construction time only, and the Agent extracts a Preference from it for the selector.

Source code in src/cpomdp/selection.py
def __init__(self, target: ArrayLike, *, precision=None, effort=None) -> None:
    target = jnp.asarray(target, dtype=float)
    object.__setattr__(self, "target", target)
    n = target.shape[0]
    object.__setattr__(
        self,
        "precision",
        jnp.eye(n) if precision is None else jnp.asarray(precision, dtype=float),
    )
    object.__setattr__(
        self,
        "effort",
        None if effort is None else jnp.asarray(effort, dtype=float),
    )
    self._validate()
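
The Q and R weights described above can be read against the standard quadratic LQR stage cost. The sketch below is a minimal illustration, assuming the usual form (x - x*)^T Q (x - x*) + u^T R u. lqr_stage_cost is a hypothetical helper and not part of cpomdp, NumPy stands in for jax.numpy so the sketch runs without JAX, and how LQRSelector actually consumes the weights is not shown in this section.

```python
import numpy as np


def lqr_stage_cost(state, action, target, precision=None, effort=None):
    """Hypothetical helper: (x - x*)^T Q (x - x*) + u^T R u for one step."""
    state = np.asarray(state, dtype=float)
    action = np.asarray(action, dtype=float)
    target = np.asarray(target, dtype=float)
    # Mirror the documented defaults: Q is the identity when precision is None.
    Q = np.eye(target.shape[0]) if precision is None else np.asarray(precision, dtype=float)
    # R can only default once the action dimension p is known, as with the Agent.
    R = np.eye(action.shape[0]) if effort is None else np.asarray(effort, dtype=float)
    err = state - target
    return float(err @ Q @ err + action @ R @ action)


# Same state error and action, scored with default and with heavier state weights.
cost_default = lqr_stage_cost([1.0, 0.0], [0.5], target=[0.0, 0.0])
cost_weighted = lqr_stage_cost(
    [1.0, 0.0], [0.5], target=[0.0, 0.0], precision=np.diag([4.0, 1.0])
)
```

Raising an entry of precision makes error along that state dimension more expensive relative to action effort, which is the trade-off the two weights control.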

ObservationGoal dataclass

ObservationGoal(
    target,
    action_bounds,
    *,
    precision=None,
    n_candidates=21,
    horizon=1,
)

An observation-space objective: prefer to observe a target (the EFE regime).

The complete spec for the information-seeking path: the preferred observation, how sharply it is preferred (precision, which defaults to the identity), and the action-search configuration that the EFESelector fixes up front. action_bounds is the action box, n_candidates its resolution, and horizon its lookahead depth. The Agent dispatches an ObservationGoal to an EFESelector. Not a pytree: it is used at construction time only, and the Agent extracts a Preference from it.

Source code in src/cpomdp/selection.py
def __init__(
    self, target, action_bounds, *, precision=None, n_candidates=21, horizon=1
) -> None:
    target = jnp.asarray(target, dtype=float)
    object.__setattr__(self, "target", target)
    m = target.shape[0]
    object.__setattr__(
        self,
        "precision",
        jnp.eye(m) if precision is None else jnp.asarray(precision, dtype=float),
    )
    object.__setattr__(self, "action_bounds", action_bounds)
    object.__setattr__(self, "n_candidates", n_candidates)
    object.__setattr__(self, "horizon", horizon)
    self._validate()

Preference dataclass

Preference(
    goal: ArrayLike, precision: ArrayLike | None = None
)

What the agent wants: a goal and how sharply it is preferred.

Single-mode for v0.3 — one Gaussian preference. The disjunctive mixture case (visit one of several goals) is RFC-002, deferred; this type is the seam that a mixture Preference plugs into.

precision is unused by LQRSelector (it is baked into the controller's Riccati solve at construction); it is carried here for the EFE pragmatic term added in Phase 1A.

Source code in src/cpomdp/selection.py
def __init__(self, goal: ArrayLike, precision: ArrayLike | None = None) -> None:
    goal = jnp.asarray(goal, dtype=float)
    object.__setattr__(self, "goal", goal)
    n = goal.shape[0]
    object.__setattr__(
        self,
        "precision",
        jnp.eye(n) if precision is None else jnp.asarray(precision, dtype=float),
    )
    self._validate()
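
To make "how sharply it is preferred" concrete, here is a minimal sketch assuming the single Gaussian preference is scored by its negative log-density up to a constant, 0.5 (o - goal)^T P (o - goal). preference_penalty is a hypothetical helper for illustration; the actual EFE pragmatic term in cpomdp is not shown in this section and may differ. NumPy stands in for jax.numpy.

```python
import numpy as np


def preference_penalty(observation, goal, precision=None):
    """Hypothetical helper: Gaussian negative log-density, constant dropped."""
    goal = np.asarray(goal, dtype=float)
    diff = np.asarray(observation, dtype=float) - goal
    # Same default as the constructor above: identity precision.
    P = np.eye(goal.shape[0]) if precision is None else np.asarray(precision, dtype=float)
    return 0.5 * float(diff @ P @ diff)


# The same one-unit miss costs ten times more under a ten times sharper preference.
loose = preference_penalty([1.0], goal=[0.0])
sharp = preference_penalty([1.0], goal=[0.0], precision=[[10.0]])
```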

tree_flatten

tree_flatten()

Leaves: (goal, precision); no static aux. Lets jit/vmap take a Preference.

Source code in src/cpomdp/selection.py
def tree_flatten(self):
    """Leaves: (goal, precision); no static aux. Lets jit/vmap take a Preference."""
    return (self.goal, self.precision), None

tree_unflatten classmethod

tree_unflatten(aux_data, children)

Rebuild without re-validating — the leaves may be tracers.

Source code in src/cpomdp/selection.py
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild without re-validating — the leaves may be tracers."""
    goal, precision = children
    obj = object.__new__(cls)
    object.__setattr__(obj, "goal", goal)
    object.__setattr__(obj, "precision", precision)
    return obj
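
The reason tree_unflatten bypasses __init__ can be shown without JAX. The sketch below is an illustration under assumptions: ToyPreference is a hypothetical stand-in for Preference with a simplified shape check, and plain object() placeholders play the role of the tracers JAX substitutes for leaves. In JAX itself a class like this is typically registered with jax.tree_util.register_pytree_node_class; how cpomdp registers Preference is not shown here.

```python
import numpy as np


class ToyPreference:
    """Hypothetical stand-in for Preference: validation lives in __init__ only."""

    def __init__(self, goal, precision):
        self.goal = np.asarray(goal, dtype=float)
        self.precision = np.asarray(precision, dtype=float)
        n = self.goal.shape[0]
        if self.precision.shape != (n, n):
            raise ValueError("precision must have shape (n, n)")

    def tree_flatten(self):
        # Leaves first, static aux data second (none here).
        return (self.goal, self.precision), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Skip __init__ so abstract leaves are never coerced or shape-checked.
        obj = object.__new__(cls)
        obj.goal, obj.precision = children
        return obj


pref = ToyPreference([1.0, 2.0], np.eye(2))
children, aux = pref.tree_flatten()

# Stand-ins for tracers: opaque objects that cannot be converted to float arrays.
placeholders = tuple(object() for _ in children)
rebuilt = ToyPreference.tree_unflatten(aux, placeholders)

# Going through __init__ with the same placeholders fails in the coercion step.
try:
    ToyPreference(*placeholders)
    init_accepts_placeholders = True
except (TypeError, ValueError):
    init_accepts_placeholders = False
```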

EFESelector

EFESelector(
    model: LinearGaussianModel,
    *,
    n_candidates: int,
    action_bounds: tuple[float, float],
    horizon: int = 1,
)

EFE action selection over a front-loaded candidate grid, horizon-aware.

At horizon = 1 (default) it minimises one-step G over the grid. At horizon > 1 it scores constant-action policies (each grid action held for H steps) via policy_efe and returns the first action of the best one (receding-horizon). Per-cycle cost is a single attributable number, cost_per_cycle = n_candidates * horizon.

Honest caveat: horizon selects the best constant action, not the best sequence. A genuinely sequential epistemic policy (move to sense, then exploit) needs a varying sequence that the constant-action family cannot express, so at H > 1 the selector can still behave myopically on such tasks. True varying-sequence search is the deferred v0.4 GradientEFESelector seam.

Source code in src/cpomdp/selection.py
def __init__(
    self,
    model: LinearGaussianModel,
    *,
    n_candidates: int,
    action_bounds: tuple[float, float],
    horizon: int = 1,
) -> None:
    if model.control is None:
        raise ValueError(
            "EFESelector needs a model with a control matrix; an action has no "
            "effect on a control-free (pure-tracking) model."
        )
    p = model.control.shape[1]
    if p != 1:
        raise ValueError(
            f"EFESelector searches a 1-D action grid (p=1); got p={p}. "
            f"Multi-dimensional action search is the deferred v0.4 "
            f"GradientEFESelector seam — pass a custom selector for p>1."
        )
    lo, hi = action_bounds
    if not lo < hi:
        raise ValueError(
            f"action_bounds must be (lo, hi) with lo < hi, got {action_bounds}"
        )
    if n_candidates < 2:
        raise ValueError(
            f"n_candidates must be at least 2 to search, got {n_candidates}"
        )
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")
    self._model = model
    self._horizon = horizon
    self._candidates = jnp.linspace(lo, hi, n_candidates)[:, None]
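
The shapes behind the candidate grid and the constant-action policies can be sketched as follows. This mirrors the linspace and repeat calls shown in the source listings, with NumPy standing in for jax.numpy; the bounds, n_candidates and horizon values are illustrative assumptions, not library defaults (apart from 21, the ObservationGoal default).

```python
import numpy as np

# Illustrative configuration (assumed values).
lo, hi = -1.0, 1.0
n_candidates, horizon = 21, 3

# One column of evenly spaced 1-D actions, shape (n_candidates, 1).
candidates = np.linspace(lo, hi, n_candidates)[:, None]

# Grid resolution: the gap between neighbouring candidate actions.
spacing = (hi - lo) / (n_candidates - 1)

# Each candidate held for `horizon` steps, shape (n_candidates, horizon, 1).
policies = np.repeat(candidates[:, None, :], horizon, axis=1)

# Step evaluations per cycle, as documented: n_candidates * horizon.
cost_per_cycle = n_candidates * horizon
```

Every policy row repeats one action, which is why no row can express "move first, then exploit": the search space grows with n_candidates, not with the number of distinct sequences.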

n_candidates property

n_candidates: int

The per-cycle EFE-evaluation count — attributable work (RFC-001).

horizon property

horizon: int

The lookahead depth — constant-action steps scored per candidate.

cost_per_cycle property

cost_per_cycle: int

Per-cycle step-evals = n_candidates * horizon.

select

select(
    belief: Belief, preference: Preference
) -> Float64[Array, p]

The grid action minimising G over the horizon (the per-cycle work).

At horizon = 1 one vmap of the one-step kernel + argmin. At horizon > 1 one vmap of policy_efe over the constant-action policies + argmin, returning the first (= constant) action of the best policy.

Source code in src/cpomdp/selection.py
def select(self, belief: Belief, preference: Preference) -> Float64[Array, "p"]:
    """The grid action minimising ``G`` over the horizon (the per-cycle work).

    At ``horizon = 1`` one ``vmap`` of the one-step kernel + ``argmin``. At
    ``horizon > 1`` one ``vmap`` of ``policy_efe`` over the constant-action policies
    + ``argmin``, returning the first (= constant) action of the best policy.
    """
    if self._horizon == 1:
        g = jax.vmap(
            lambda a: expected_free_energy(self._model, belief, a, preference)[0]
        )(self._candidates)
        return self._candidates[self._argmin(g)]
    # H>1: each candidate becomes a constant-action policy (held for H steps).
    policies = jnp.repeat(self._candidates[:, None, :], self._horizon, axis=1)
    g = jax.vmap(lambda pol: policy_efe(self._model, belief, pol, preference)[0])(
        policies
    )
    return self._candidates[self._argmin(g)]  # first (= constant) action
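
The grid-then-argmin pattern in select can be tried in isolation. The sketch below is a toy under assumptions: toy_cost is a hypothetical stand-in for G (a squared distance between a predicted mean and a goal, with no epistemic term), select_on_grid is an illustrative helper, np.argmin stands in for the private _argmin helper whose body is not shown, and a Python loop replaces vmap.

```python
import numpy as np


def toy_cost(action, mean, goal):
    """Hypothetical stand-in for G: squared miss after applying the action."""
    predicted = mean + action[0]
    return (predicted - goal) ** 2


def select_on_grid(candidates, cost_fn):
    """Score every candidate action and return the cheapest, shape (p,)."""
    g = np.array([cost_fn(a) for a in candidates])
    return candidates[np.argmin(g)]


candidates = np.linspace(-1.0, 1.0, 21)[:, None]

# Goal reachable inside the action box: the nearest grid action is picked.
best_inside = select_on_grid(candidates, lambda a: toy_cost(a, mean=0.2, goal=0.5))

# Goal far outside the box: the search saturates at the upper bound.
best_clipped = select_on_grid(candidates, lambda a: toy_cost(a, mean=0.2, goal=5.0))
```

The result is always one of the grid actions, so accuracy is limited by the grid spacing and the answer can never leave action_bounds.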