feedbax's People

Contributors

mlprt

feedbax's Issues

`ModelInput` and passing intervenor parameters to submodels

All Feedbax models (subclasses of AbstractModel) have the signature (input: ModelInput, state: StateT, *, key: PRNGKeyArray) -> StateT. StateT was originally bound to AbstractState in AbstractModel; it is now bound to PyTree[Array] (see #24).

In AbstractStagedModel subclasses (#19) we perform a sequence of state operations by passing subsets of input and state to model components, and using the return values to make out-of-place updates to state.

An AbstractModel is generally a PyTree containing other AbstractModel nodes; i.e. Feedbax models are hierarchical. Typically, the outermost node in the model PyTree is an instance of Iterator, which is essentially a loop over a single step of the model (e.g. a SimpleFeedback instance) where all of the actual state operations happen.

The input to the outermost model node is not selected from the input to another AbstractModel that contains it, because there is none. Instead, its input is provided by an instance of AbstractTask. This task information is any trial-by-trial data that is unconditional on the internal operations of the model. For example, a reaching task like SimpleReaches will provide the model with the goal position it is expected to reach to, and the model will ultimately forward this to the controller (neural network) component.

An issue arises when we need to schedule interventions on a task/model that already exists. Interventions may change on a trial-by-trial basis. Any systematic trial-by-trial variations are specified by an AbstractTask. In particular, if the parameters of an AbstractIntervenor are expected to change across trials, then an AbstractTask should provide those changing parameters as part of the input to the model. The model will then need to make sure that these parameters are matched up to the right instance of AbstractIntervenor.

Perhaps there is a way for schedule_intervenor to work with AbstractTask to structure the intervention parameters so that, at each level of the model, AbstractStagedModel.__call__ can be made to send them on to the right component, until they reach the component that contains the instance of AbstractIntervenor they pertain to. I have not figured out how to do this.

My current solution is, when an intervenor is scheduled with schedule_intervenor, to assign it a unique string label among all the intervenors aggregated over all levels of a model PyTree. Then, intervention parameters are included in input as a flat mapping from the unique labels, to parameters. This flat mapping is passed as-is down through the hierarchy of model components; every AbstractStagedModel sees the same mapping, and simply tries to match the unique labels of its own intervenors, with those in the mapping.

This is what ModelInput is for: it's an eqx.Module with two fields, input and intervene: input contains the usual task information which, once it reaches the outermost AbstractStagedModel in the model, is selectively passed on to certain component(s) depending on the definition of model_spec (again, typically it's all sent to the neural network). On the other hand, intervene contains the flat mapping of intervention parameters, and is passed on as-is.

So, in AbstractStagedModel.__call__ we see something like:

feedbax/feedbax/_staged.py

Lines 152 to 160 in 8f080c6

callable_ = stage.callable(self)
subinput = stage.where_input(input.input, state)
# TODO: What's a less hacky way of doing this?
# I was trying to avoid introducing additional parameters to `AbstractStagedModel.__call__`
if isinstance(callable_, AbstractModel):
    callable_input = ModelInput(subinput, input.intervene)
else:
    callable_input = subinput

Here, we:

  1. Select subinput, the subset of the model inputs to pass to the current stage, out of input.input. I haven't thought of a better name for that field; maybe input.task_input.
  2. Check whether the component to be called is an AbstractModel. If it is, it accepts ModelInput and might contain interventions, so we pass it a reconstructed ModelInput with the same intervene value (i.e. the flat mapping), but with only subinput as input. Otherwise, we pass subinput on its own.
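
The routing described above can be sketched in plain Python. This is only an illustration, not the Feedbax implementation: Component, Wrapper, and the labels "curl_0" and "noise_0" are hypothetical, and state is a plain dict rather than an Equinox module.

```python
from dataclasses import dataclass
from typing import Any, Callable, Mapping


@dataclass(frozen=True)
class ModelInput:
    input: Any                    # task information, narrowed at each level
    intervene: Mapping[str, Any]  # flat mapping: unique label -> parameters


class Component:
    """Hypothetical leaf model that owns some labelled intervenors."""

    def __init__(self, intervenor_labels):
        self.intervenor_labels = tuple(intervenor_labels)

    def __call__(self, model_input: ModelInput, state: dict) -> dict:
        # Pick out only the parameters addressed to this component's intervenors.
        matched = {
            label: model_input.intervene[label]
            for label in self.intervenor_labels
            if label in model_input.intervene
        }
        return {**state, "input": model_input.input, "matched": matched}


class Wrapper:
    """Hypothetical parent model with a single stage."""

    def __init__(self, component: Component, where_input: Callable):
        self.component = component
        self.where_input = where_input

    def __call__(self, model_input: ModelInput, state: dict) -> dict:
        # Narrow the task input, but forward the intervention mapping untouched.
        subinput = self.where_input(model_input.input, state)
        return self.component(ModelInput(subinput, model_input.intervene), state)


model = Wrapper(Component(["curl_0"]), where_input=lambda inp, state: inp["goal"])
out = model(
    ModelInput(
        input={"goal": 1.0, "cue": 0.0},
        intervene={"curl_0": {"amplitude": -0.5}, "noise_0": {"std": 0.1}},
    ),
    state={},
)
```

The component receives only the goal as its task input, but still sees the whole intervention mapping and ignores the entry labelled "noise_0" because it does not own an intervenor with that label.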

This seems pretty hacky to me and I'm not sure how it should be done better. I've considered adding another argument to the signature of AbstractModel, but that doesn't seem better. Also, I suppose I don't have to use ModelInput at all, and could just type input as a tuple.

Optax `opt_state` changes shape on first training iteration, leading to recompilation

In feedbax.trainer.TaskTrainer, before starting the training loop, the training and validation methods are executed a single time in order to JIT compile them.

feedbax/feedbax/train.py

Lines 259 to 284 in 147fb42

# Finish the JIT compilation before the first training iteration.
if not jax.config.jax_disable_jit:
    for _ in tqdm(range(1), desc="compile", disable=disable_tqdm):
        if ensembled:
            key_compile = jr.split(key, n_replicates)
        else:
            key_compile = key
        train_step(  # doesn't alter model or opt_state
            task,
            batch_size,
            flat_model,
            treedef_model,
            flat_opt_state,
            treedef_opt_state,
            filter_spec,
            key_compile,
        )
        if not disable_tqdm:
            tqdm.write(f"Training step compiled.", file=sys.stdout)
        evaluate(model, key_compile)
        if not disable_tqdm:
            tqdm.write(f"Validation step compiled.", file=sys.stdout)
else:
    logger.debug("JIT globally disabled, skipping pre-run compilation")

Note that wrapping this in a tqdm loop was a quick way to time the compilation, though I ought to get rid of the trivial progress bar and just log the compilation time.

Having compiled these functions, I'd expected the training loop to warm up very quickly. However, the first step of the training loop is also slow. I suspect this is because train_loop is being recompiled, due to optimizer.update changing the shape of opt_state -- but only on the first iteration.

This could be solved by keeping the result of the JIT compilation/keeping the first two training iterations outside of the main progress bar. However, I'd still like an answer about why optimizer.update changes the shape of opt_state. It kind of makes sense that it would do so if it initializes its state lazily, and changes the state structure when (for example) it first encounters gradients.

Return validation losses in `TaskTrainerHistory`

At the end of a training run, TaskTrainer returns auxiliary information such as losses and intermediate model states in a TaskTrainerHistory object.

The history of the training loss is returned in the loss field, but the validation loss is not returned.

Validation is only performed once every log_step trials, so a little care will need to be taken that the returned arrays are comparable with those for the training loss; it might be appropriate to return arrays that are all NaN except on the log steps. Memory should not be an issue. Usually there won't be more than 5 loss terms or so, and (conservatively) fewer than 100,000 iterations. That's about 2 MB of float32.
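
A minimal sketch of the NaN-padding idea, using plain Python lists in place of arrays. The helper pad_validation_losses is hypothetical, and it assumes validation happens on iterations 0, log_step, 2 * log_step, and so on.

```python
import math


def pad_validation_losses(validation_losses, n_iterations, log_step):
    """Return a list of length `n_iterations` that is NaN except on log steps."""
    padded = [math.nan] * n_iterations
    log_iterations = range(0, n_iterations, log_step)
    for iteration, loss in zip(log_iterations, validation_losses):
        padded[iteration] = loss
    return padded


# Two validation passes over six training iterations, one every three steps.
padded = pad_validation_losses([0.9, 0.5], n_iterations=6, log_step=3)

# Memory estimate from the text: 5 loss terms x 100,000 iterations x 4 bytes.
n_bytes = 5 * 100_000 * 4
```

The padded list has the same length as the training loss history, so the two can be indexed by iteration without any bookkeeping about log_step.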

Associate types of intervenors with types of staged models

Currently, to add a curl force field to a SimpleFeedback model, we need to write something like:

from feedbax.intervene import CurlField, add_intervenor

model_curl = add_intervenor(
    model, 
    CurlField.with_params(amplitude=-0.5),  # negative -> clockwise
    where=lambda m: m.step.mechanics,
)

However, the only part of model: SimpleFeedback to which it makes sense to add a CurlField, is model.step.mechanics. And if we were simulating a Mechanics instance directly instead of wrapping it in SimpleFeedback, it would make sense to add the intervention to model.step.

In principle Feedbax could recognize this, and not require that we specify it.

model_curl = add_intervenor(
    model, 
    CurlField.with_params(amplitude=-0.5),
)

A potential solution: for each type of intervenor, type it by (or assign it with) the subclass of AbstractStagedModel that it makes sense to add it to. When add_intervenor is called, we can automatically figure out where that model type lives in the tree, such as by using equinox.tree_at.

The outcome of this operation is ambiguous, in the case that the model PyTree contains multiple nodes of the same type of AbstractStagedModel, or in the case of general interventions like AddNoise whose associated type should be AbstractStagedModel itself. So there would also need to be some mechanism to determine when there are multiple instances to which the intervention could be added, and perhaps an argument to add_intervenor that would determine whether the intervention should be added to all of them, or an error should be raised if a disambiguating where has not been passed.
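
The lookup and the ambiguity check could look roughly like the following sketch. It walks ordinary Python objects through their attributes rather than a real PyTree, and every class and helper here (Mechanics, Network, SimpleFeedback, Iterator, find_paths, locate_unique) is a hypothetical stand-in, not the Feedbax API. It does not guard against reference cycles.

```python
class Mechanics:
    """Stand-in for the model type a CurlField would be associated with."""


class Network:
    """Stand-in for an unrelated component."""


class SimpleFeedback:
    def __init__(self, net, mechanics):
        self.net = net
        self.mechanics = mechanics


class Iterator:
    def __init__(self, step):
        self.step = step


def find_paths(node, target_type, path="model"):
    """Collect attribute paths to every node that is a `target_type`."""
    paths = [path] if isinstance(node, target_type) else []
    for name, child in getattr(node, "__dict__", {}).items():
        paths.extend(find_paths(child, target_type, f"{path}.{name}"))
    return paths


def locate_unique(model, target_type):
    """Return the single matching path, or raise if there is not exactly one."""
    paths = find_paths(model, target_type)
    if len(paths) != 1:
        raise ValueError(
            f"Expected one {target_type.__name__}, found {len(paths)}: {paths}; "
            "pass an explicit `where` to disambiguate."
        )
    return paths[0]


model = Iterator(SimpleFeedback(Network(), Mechanics()))
path = locate_unique(model, Mechanics)
```

An argument to add_intervenor could switch between raising, as here, and adding the intervention at every path that find_paths returns.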

Typing `ModelStage`

State operations ("stages") in an AbstractStagedModel are defined by the property model_spec which is an OrderedDict[str, ModelStage].

For example, here's one entry from SimpleFeedback.model_spec:

feedbax/feedbax/bodies.py

Lines 195 to 199 in 2ce8b1c

"mechanics_step": ModelStage(
    callable=lambda self: self.mechanics,
    where_input=lambda input, state: state.net.output,
    where_state=lambda state: state.mechanics,
),

The state arguments in these lambdas should be typed as SimpleFeedbackState, so that the type checker recognizes that state.mechanics is a valid reference: SimpleFeedbackState has a field mechanics: MechanicsState.

Currently, ModelStage is a generic of the type variable StateT = TypeVar('StateT', AbstractState, Array), where all of the state PyTrees like SimpleFeedbackState inherit from AbstractState. Here is a slight simplification:

StateT = TypeVar('StateT', AbstractState, Array)

class ModelStage(eqx.Module, Generic[StateT]):
    callable: Callable[[AbstractStagedModel[StateT]], Callable]
    where_input: Callable[[AbstractTaskInputs, StateT], PyTree]
    where_state: Callable[[StateT], PyTree]
    intervenors: Sequence[AbstractIntervenor] = field(default_factory=tuple)

However, nowhere do we subclass ModelStage and give an argument for this type variable.

Throughout Feedbax, Pyright raises errors for only some of the references found in lambdas of model_spec properties. For example, it raises an error for the callable field of the "mechanics_step" stage given at the start of this issue,

feedbax/bodies.py:196:48 - error: Cannot access member "mechanics" for type "AbstractStagedModel[Unknown]"
    Member "mechanics" is unknown (reportAttributeAccessIssue)

but not for its where_input or where_state fields. Checking the Pylance tooltips, the input arguments are typed as AbstractTaskInputs, but the state fields are Any, which is at odds with the StateT annotation in ModelStage.

I suspect it will be necessary (#24) to eliminate AbstractState. Then, the type arguments to final subclasses of AbstractStagedModel will need to be distinct subclasses of equinox.Module. Similarly, the type variable in ModelStage will need to be bound to equinox.Module instead of AbstractState. This doesn't seem problematic as ModelStage will only be used inside an AbstractStagedModel with which it shares a type argument. For example, ModelStage[SimpleFeedbackState] will be used in SimpleFeedback[SimpleFeedbackState].

I've tried explicitly writing "mechanics_step": ModelStage[SimpleFeedbackState](...) in the model_spec entries. Writing type arguments in at the time of instantiation is not a syntax I've ever used before, and it doesn't resolve the type checker's errors. I wouldn't expect this to resolve the error raised for the callable field of ModelStage, since the type argument won't specify that self should be of type SimpleFeedback.

How should this be properly typed?

Intervening on model parameters

It's nice that we can intervene on a model's state PyTree, which in Feedbax is an equinox.Module object that is modified out-of-place by the model itself, which is an instance of AbstractStagedModel. However, sometimes we want to modify the parameters of the model object—such as the weights in a neural network. One such use case is online learning, where we want to update the parameters of the model within a single trial (between time steps).

Why an AbstractIntervenor can't intervene on model parameters

Intervenors added to an AbstractStagedModel cannot modify its fields. In JAX/Equinox we treat instances as immutable and update them out-of-place; this cannot be done from within the instance itself.

One option is to move the AbstractStagedModel fields we want to modify online into the model's respective AbstractState class, so that they "become" state variables. I'm not sure that's wise:

  1. Two different subclasses of AbstractStagedModel can depend on the same subclass of AbstractState. But they might not share parameters, so moving their parameters into AbstractState might lead to a proliferation of AbstractState subclasses.
  2. Model classes that refer to parameters stored all together in a field are less readable—we refer to self.params.x instead of self.x.

On the other hand, keeping model parameters in a separate PyTree is the strategy used for subclasses of AbstractIntervenor, so their instances can be controlled on a trial-by-trial basis by model inputs provided by an AbstractTask. I've accepted the downsides in this case, because AbstractIntervenor subclasses tend to have small implementations and do not have "state" of their own, aside from their parameters.


Currently, Feedbax only allows offline updates to models: after a batch of trials is evaluated in TaskTrainer, we update the model by gradient descent with respect to the loss function. We also allow the user to specify additional functions (in the field model_update_funcs) which perform surgery on the model, given the evaluated states of the model for the current batch. Note that TaskTrainer has no access to the model/states on different time steps, only the state history following evaluation of the entire batch.

Given the current class structure of Feedbax, the logical place to intervene on model parameters is from within AbstractIterator objects, which iterate models over time.

I think this should be relatively straightforward, though it will require at least one significant modification: AbstractIterator currently has a field _step to which an instance of AbstractModel—the model which is iterated—is assigned. This field should not be a field. Instead, step should be passed as an argument to the __call__ method of the subclass of AbstractIterator. This way, the model step can be altered out-of-place on each time step of the trial, then returned in its final form—potentially along with its entire history.
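
A minimal sketch of that idea, written as an ordinary Python loop over immutable objects. An actual implementation would presumably use a JAX loop primitive and Equinox modules instead; GainStep, iterate, and update_step are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class GainStep:
    """Hypothetical single-step model with one parameter."""

    weight: float

    def __call__(self, input, state):
        return state + self.weight * input


def iterate(step, state, inputs, update_step=None):
    """Iterate `step` over `inputs`, optionally replacing it after each time step."""
    states = []
    for x in inputs:
        state = step(x, state)
        states.append(state)
        if update_step is not None:
            # Out-of-place update: returns a new step, the old one is untouched.
            step = update_step(step, state)
    return step, states


def halve_weight(step, state):
    return replace(step, weight=step.weight * 0.5)


step0 = GainStep(weight=1.0)
final_step, states = iterate(step0, 0.0, [1.0, 1.0, 1.0], update_step=halve_weight)
```

Because step is an argument rather than a field, the iterator itself never needs to change: it returns the final step alongside the state history, and the step that was passed in is left as it was.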

I will probably try to implement this soon.

Note that this solution would only allow updates to parameters once per time step. This seems sufficient to me. In principle, turning model parameters into AbstractState (or similar) fields is a more general solution, since we could intervene on those fields at arbitrary points during execution of a single step of the model. However, I doubt that is necessary.

Include muscle activation dynamics in the `AbstractMuscle`, not the `AbstractPlant` instance

See #29 for a review of the structure of AbstractPlant, and how dynamical updates are aggregated. I'll try to summarize the relevant aspects here.

MuscledArm has a field activator, which is an AbstractDynamicalSystem that describes how a muscle's activation changes with its input. For example, we could use a first-order filter to smooth out the input to the muscle, as an approximation of the dynamic of calcium diffusion through a muscle fibre in response to activity at a neuromuscular junction.
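
As an illustration of such a first-order filter, the activation a can be driven toward the muscle input u with a time constant tau, i.e. da/dt = (u - a) / tau. The sketch below is plain Python with a forward Euler loop; the function names, the value of tau, and the step size are assumptions made for the example, not Feedbax's activator.

```python
def activation_vector_field(t, activation, muscle_input, tau=0.05):
    """Time derivative of activation for a first-order low-pass filter."""
    return (muscle_input - activation) / tau


def euler_integrate(activation, muscle_input, dt=0.01, n_steps=100, tau=0.05):
    """Integrate the filter with forward Euler for a constant input."""
    for i in range(n_steps):
        d_activation = activation_vector_field(i * dt, activation, muscle_input, tau)
        activation = activation + dt * d_activation
    return activation


# With a constant input of 1.0, activation rises smoothly from 0.0 toward 1.0.
one_step = euler_integrate(activation=0.0, muscle_input=1.0, n_steps=1)
final_activation = euler_integrate(activation=0.0, muscle_input=1.0)
```

In Feedbax the integration would instead be left to the Diffrax solver, which only needs the vector field.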

The activator component is included in the dynamics_spec of MuscledArm, which is a subclass of AbstractPlant; all the components defined in dynamics_spec are automatically aggregated into a single vector_field method that is used as the ODETerm to a Diffrax solver.

Conceptually, it makes more sense for the muscle activation dynamics to be part of the muscle model -- i.e. AbstractMuscle. However, AbstractMuscle is only a subclass of AbstractStagedModel, and only defines a sequence of kinematic updates, and no vector fields.

To fix this, I think AbstractMuscle will have to subclass AbstractDynamicalSystem as well as AbstractStagedModel, similarly to the inheritance pattern already used by AbstractPlant.

This should reduce confusion between "muscle input" and "muscle activation". According to the above description, it's activation that's passed as the input argument to an AbstractMuscle. In some cases however, we might omit the activation filter, and pass a neural network's output directly through AbstractPlant, to AbstractMuscle.

Pretty printing of inputs and outputs of model stages

Normally lambdas are used for the where_input and where_state fields of ModelStage objects, which are specified inside of the model_spec of an AbstractStagedModel.

Because lambdas do not have nice string representations, when we print an instance's model_spec, we don't learn anything about its stages but their names. For example, here's a printout of model_spec for an instance of SimpleStagedNetwork:

OrderedDict([('hidden',
              ModelStage(
                callable=<function <lambda>>,
                where_input=<function <lambda>>,
                where_state=<function <lambda>>,
                intervenors=None
              )),
             ('readout',
              ModelStage(
                callable=<function <lambda>>,
                where_input=<function <lambda>>,
                where_state=<function <lambda>>,
                intervenors=None
              )),
             ('out_nonlinearity',
              ModelStage(
                callable=<function <lambda>>,
                where_input=<function <lambda>>,
                where_state=<function <lambda>>,
                intervenors=None
              ))])

The function feedbax.pprint_model_spec is a little better, since it provides info about the identity of the callable:

hidden: GRUCell
readout: wrapped: Linear
out_nonlinearity: wrapped: identity_func

Similarly to what's done with WhereDict (#14) it would be nice to parse the lambdas where_input and where_state so we can pretty print the references they contain. In this case, that could look something like

hidden: GRUCell(<lambda>, state.hidden) -> state.hidden
readout: wrapped: Linear(state.hidden) -> state.output
out_nonlinearity: wrapped: identity_func(state.output) -> state.output

In this case I've left in <lambda> for the input to GRUCell, which is ravel_pytree(input)[0]. Of course in a case like this, where the lambdas aren't just references to parts of the input/state, it's less clear how they should be included in the printout.
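
One possible way to recover the references, sketched here under the assumption that the lambdas only perform attribute access: call each lambda with a proxy object that records the attributes requested from it, and fall back to <lambda> when anything else happens. PathRecorder and describe_where are hypothetical helpers, and this is not how WhereDict is implemented.

```python
class PathRecorder:
    """Proxy that records attribute accesses, e.g. `state.net.output`."""

    def __init__(self, path):
        self._path = path

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for "substate" names.
        if name.startswith("__"):
            raise AttributeError(name)
        return PathRecorder(f"{self._path}.{name}")


def describe_where(where, *arg_names):
    """Return the dotted path a `where` function selects, or '<lambda>'."""
    try:
        result = where(*(PathRecorder(name) for name in arg_names))
    except Exception:
        # The lambda did something other than attribute access.
        return "<lambda>"
    return result._path if isinstance(result, PathRecorder) else "<lambda>"


simple = describe_where(lambda input, state: state.hidden, "input", "state")
nested = describe_where(lambda state: state.net.output, "state")
opaque = describe_where(lambda input, _: len(input), "input", "_")
```

A lambda that passes its argument through an arbitrary function and returns it unchanged would still be reported as a plain reference, so this approach is a heuristic rather than a real parser.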

It may be unnecessary to provide this information to the user in this form, since they could just refer to the source of model_spec. One downside of expecting the user to read the source is that some model_spec definitions may use arbitrary logic to decide whether to include certain stages or not, which the user would need to parse. For example, in SimpleStagedNetwork:

feedbax/feedbax/nn.py

Lines 234 to 330 in 2ce8b1c

@property
def model_spec(self) -> OrderedDict[str, ModelStage]:
    """Specifies the network model stages: layers, nonlinearities, and noise.

    Only includes stages for the encoding layer, readout layer, hidden noise, and
    hidden nonlinearity, if the user respectively requests them at the time of
    construction.

    !!! NOTE
        Inspects the instantiated hidden layer to determine if it is a stateful
        network (e.g. an RNN). If not (e.g. Linear), it wraps the layer so that
        it plays well with the state-passing of `AbstractStagedModel`. This assumes
        that stateful layers will take 2 positional arguments, and stateless layers
        only 1.
    """
    if n_positional_args(self.hidden) == 1:
        hidden_module = lambda self: wrap_stateless_callable(self.hidden)
        if isinstance(self.hidden, eqx.nn.Linear):
            logger.warning(
                "Network hidden layer is linear but no hidden "
                "nonlinearity is defined"
            )
    else:
        # #TODO: revert this!
        # def tmp(self):
        #     def wrapper(input, state, *, key):
        #         return self.hidden(input, jnp.zeros_like(state))
        #     return wrapper
        # hidden_module = lambda self: tmp(self)
        hidden_module = lambda self: self.hidden
    if self.encoder is None:
        spec = OrderedDict(
            {
                "hidden": ModelStage(
                    callable=hidden_module,
                    where_input=lambda input, _: ravel_pytree(input)[0],
                    where_state=lambda state: state.hidden,
                ),
            }
        )
    else:
        spec = OrderedDict(
            {
                "encoder": ModelStage(
                    callable=lambda self: lambda input, state, *, key: self.encoder(
                        input
                    ),
                    where_input=lambda input, _: ravel_pytree(input)[0],
                    where_state=lambda state: state.encoding,
                ),
                "hidden": ModelStage(
                    callable=hidden_module,
                    where_input=lambda input, state: state.encoding,
                    where_state=lambda state: state.hidden,
                ),
            }
        )
    if self.hidden_nonlinearity is not None:
        spec |= {
            "hidden_nonlinearity": ModelStage(
                callable=lambda self: wrap_stateless_callable(
                    self.hidden_nonlinearity, pass_key=False
                ),
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.hidden,
            ),
        }
    if self.hidden_noise_std is not None:
        spec |= {
            "hidden_noise": ModelStage(
                callable=lambda self: self._add_hidden_noise,
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.hidden,
            ),
        }
    if self.readout is not None:
        spec |= {
            "readout": ModelStage(
                callable=lambda self: wrap_stateless_callable(self.readout),
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.output,
            ),
            "out_nonlinearity": ModelStage(
                callable=lambda self: wrap_stateless_callable(
                    self.out_nonlinearity, pass_key=False
                ),
                where_input=lambda input, state: state.output,
                where_state=lambda state: state.output,
            ),
        }
    return spec

On the other hand, a function like pprint_model_spec shows exactly the components in a model instance, briefly, in the order they are actually included.

Include any neural network as a component of a staged model

It should be straightforward to allow any Equinox-based neural network to be a component of an AbstractStagedModel, such as SimpleFeedback, and be called during one of the model stages.

We would need to:

  1. if it doesn't take a state/hidden argument (e.g. not an RNN), wrap it to ignore the state argument that AbstractStagedModel.__call__ will try to pass it,
  2. associate it with a generic NetworkState-like PyTree that has a single leaf, which stores the output of the module. Alternatively, keep a single Array in the PyTree of any model of which the network is a component (e.g. SimpleFeedback).
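
The first requirement could be met with a small adapter like the one sketched below. It assumes, as a convention, that stateful callables take two positional arguments and stateless ones take one; n_positional and as_stage_callable are hypothetical helpers written for this example, and linear and recurrent are toy stand-ins for network layers.

```python
import inspect


def n_positional(func):
    """Count the parameters of `func` that can be passed positionally."""
    kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    params = inspect.signature(func).parameters.values()
    return sum(p.kind in kinds for p in params)


def as_stage_callable(func):
    """Give `func` the (input, state, *, key) signature a stage expects."""
    if n_positional(func) >= 2:
        return func  # already accepts a state argument

    def wrapped(input, state, *, key=None):
        # The state and key are accepted but ignored.
        return func(input)

    return wrapped


def linear(x):
    return 2.0 * x


def recurrent(x, h, *, key=None):
    return x + h


stateless_out = as_stage_callable(linear)(3.0, 100.0, key=None)
stateful_out = as_stage_callable(recurrent)(3.0, 100.0, key=None)
```

The wrapped stateless callable returns only its output, which is then what gets stored in the single-leaf state described in the second requirement.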

On the other hand, we could specify the network itself as an AbstractStagedModel where network layers correspond to distinct stages, and where the activities of different layers may be kept as part of the state. In that case, the user can add interventions to these states without needing to redesign the model.

SimpleStagedNetwork is the prototype for a neural network AbstractStagedModel.

Eliminate `AbstractTaskTrialSpec`

Currently, typing in feedbax.task is a mess. A source of this mess is AbstractTaskTrialSpec.

feedbax/feedbax/task.py

Lines 150 to 167 in 1c239e6

class AbstractTaskTrialSpec(Module):
    """Abstract base class for trial specifications provided by a task.

    Attributes:
        inits: A mapping from `lambdas` that select model substates to be
            initialized, to substates to initialize them with.
        inputs: A PyTree of inputs to the model.
        target: A PyTree of target states.
        intervene: A mapping from unique intervenor names, to per-trial
            intervention parameters.
    """
    inits: AbstractVar[WhereDict]
    # inits: OrderedDict[Callable[[AbstractState], PyTree[Array]],
    #                    PyTree[Array]]
    inputs: AbstractVar[AbstractTaskInputs]
    target: AbstractVar[PyTree[Array]]
    intervene: AbstractVar[Mapping[str, Array]]

Once (#10) the target field is specified in a generalized way similarly to inits, the only field that will vary systematically between tasks will be inputs. In that case I suspect it will make more sense to eliminate subclassing of AbstractTaskTrialSpec, and instead define it as a generic final class, something like:

InputsT = TypeVar("InputsT", Module, Array)

class TaskTrialSpec(Module, Generic[InputsT]):
    inits: WhereDict
    inputs: InputsT
    targets: WhereDict
    intervene: Mapping[str, Array]

This would also save developers from explicitly including the intervene field in subclasses of AbstractTaskTrialSpec, because subclassing would be unnecessary; see #20.

I am still not sure how (if at all) to explicitly associate the structure of a given type of inputs, and the type of model that is compatible with a task. In principle, it could involve multiple fields of TaskTrialSpec; see #14.

Typing of scalar derivatives of state variables

feedbax.dynamics.AbstractDynamicalSystem is a generic of StateT. Each subclass must define a vector field that returns the time derivatives of each array in a state PyTree.

def vector_field(
    self, 
    t: float, 
    state: StateT, 
    input: PyTree[Array],  
) -> StateT:
    """Returns the time derivatives of the system's states."""
    ...

The return type StateT is more or less correct: the return value will have the same PyTree structure as state, and -- because the time derivative is scalar -- the same array shapes as well.

This seems fine to me, but perhaps there is a better way to make it explicit when we are passing around derivatives.

Interface to model objects, and `Iterator`

This issue is about converting single model steps to iterated models, and how this affects the model's PyTree structure, and references made to its components.

Generally, models (such as SimpleFeedback) are defined as a single iteration, and then wrapped in an Iterator object -- which is more or less a jax.lax loop.

However:

  • Almost the entire model PyTree is under model.step.*, from the user's perspective.
    • For example, when they pass a where_train when calling a TaskTrainer, they generally need to specify it like lambda model: model.step.net.
    • Similarly, whenever performing model surgery or the like, most references will be to model.step.*.
    • This differs from the structure of the state PyTree. For example, we have model.step.net but states.net. This is because Iterator adds a time dimension to the arrays in states.
  • In certain cases, we might have a model whose top level is not an Iterator, but which we will try to interact with using code that refers to model.step.
    • Currently, all AbstractModels provide a step property, which trivially returns self when the model is not an Iterator. AbstractIterator instead returns self._step, which is the field that the rest of the model PyTree is actually assigned to.
    • Should we be stricter about types, and (say) always assume that TaskTrainer is passed a model wrapped in Iterator?
  • In TaskTrainer._train_step we have to get initial states for the model, for all trials in a batch.
    • We start by vmapping model.init to obtain a default state. Since the input state to an iterated model is the same as to the model step, Iterator.init just returns self.step.init.
    • After _train_step obtains this default state, it modifies parts of it using state initialization data provided for the current batch of training trials (by the AbstractTask object). Then, it is necessary to ask the model to make sure that the state is internally consistent.
      • For example, the AbstractTask will typically give an initial position for the effector (e.g. arm endpoint). From the effector position we need to infer and update the mechanical configuration (e.g. joint angles). This is only necessary prior to the first time step for the trials in the batch, after which the states will be internally consistent by virtue of the model's operations. Thus we have a method AbstractStagedModel.state_consistency_update which is called once in _train_step.
      • So, do we add def _state_consistency_update(self): return self.step.state_consistency_update to Iterator similarly to what we've done with init? Currently, _train_step calls model._step.state_consistency_update.

I have considered modifying TaskTrainer to handle the model iteration over time, so the user does not explicitly instantiate an Iterator, and can refer to model.* instead of model.step.*. This would make sense in light of AbstractTask providing model inputs as trajectories over time, which Iterator indexes from using tree_take -- such that it does not make sense to use a non-iterated model with TaskTrainer. Should this change be adopted?

Jaxtyping and PyTrees with mixed array and non-array leaves

Consider the function feedbax.tree_set:

feedbax/feedbax/_tree.py

Lines 103 to 133 in 147fb42

def tree_set(
    tree: PyTree[Any | Shaped[Array, "batch *?dims"], "T"],
    items: PyTree[Any | Shaped[Array, "*?dims"], "T"],
    idx: int,
) -> PyTree[Any | Shaped[Array, "batch *?dims"], "T"]:
    """Perform an out-of-place update of each array leaf of a PyTree.

    Non-array leaves are simply replaced by their matching leaves in `items`.

    For example, if `tree` is a PyTree of states over time, whose first dimension
    is the time step, and `items` is a PyTree of states for a single time step,
    this function can be used to insert the latter into the former at a given time index.

    Arguments:
        tree: Any PyTree whose array leaves share a first dimension of the same
            length, for example a batch dimension.
        items: Any PyTree with the same structure as `tree`, and whose array
            leaves have the same shape as the corresponding leaves in `tree`,
            but lacking the first dimension.
        idx: The index along the first dimension of the array leaves of `tree`
            into which to insert the array leaves of `items`.

    Returns:
        A PyTree with the same structure as `tree`, where the array leaves of `items` have been inserted as the `idx`-th elements of the corresponding array leaves of `tree`.
    """
    arrays = eqx.filter(tree, eqx.is_array)
    vals_update, other_update = eqx.partition(
        items, jax.tree_map(lambda x: x is not None, arrays)
    )
    arrays_update = jax.tree_map(lambda xs, x: xs.at[idx].set(x), arrays, vals_update)
    return eqx.combine(arrays_update, other_update)

It takes 1) a PyTree of array and non-array leaves, where the array leaves all share a batch dimension, 2) a PyTree of the same structure, but where the array leaves all lack the batch dimension, and 3) an index into the batch dimension. It returns a copy of (1) where the arrays of (2) have been inserted at (3), for all the array leaves.

The Any is included to allow for non-array leaves. The problem is that Any | Array is equivalent to Any, so that the jaxtyping PyTree/Array annotations will never lead to errors.

Is there a way to do array shape checking with jaxtyping, while allowing for non-array leaves?

General design concerns with `AbstractStagedModel` and `AbstractIntervenor`

A major design motivation for Feedbax is the common use case where a researcher wants to intervene on an existing optimal control experiment. In this issue, I describe the approach I've taken to this problem, and my uncertainties about it.

Currently, Feedbax implements the following solution: models are defined by Equinox modules of type AbstractStagedModel. Each type of model is treated as a series of operations performed on a shared PyTree of states. All state operations (AKA stages) are defined in a consistent way, each as a collection of three things: 1) a model component to be called, 2) a function that selects the subset of model inputs/states to pass to the component, and 3) a function that selects the subset of the state that the component returns/updates.

To define a new staged model, we subclass AbstractStagedModel and implement the property model_spec, where those three things are defined for each of the model's stages. AbstractStagedModel implements __call__ itself, to perform the state operations defined in model_spec. For a more in depth description, see the documentation.

  • What kind of PyTree is model_spec?

    Currently, model_spec is defined as a property of type OrderedDict[str, ModelStage]. We use a mapping because it's nice for the stages to have names which can be referred to by the user. However, we cannot use a dict: although its entries have maintained their insertion order since Python 3.7, its keys get sorted during PyTree flatten/unflatten operations. OrderedDict doesn't have the same problem.

    ModelStage is an Equinox module whose fields describe the "three things" that define a stage. Using a module rather than (say) a tuple makes model_spec a little more readable. However, there have been some typing issues with ModelStage: #23.

  • Model state objects

    AbstractStagedModel is generic, and each of its final subclasses has a type argument that's some final subclass of equinox.Module. This is the type of state PyTree operated on by the model. Different staged models may operate on the same type of state object.

    A subclass of AbstractStagedModel may be composed of other types of AbstractStagedModel, in which case the state PyTrees associated with the higher-level model tend to be composites of the state PyTrees associated with the components.

    To subclass AbstractStagedModel we also have to implement an init method which takes a key, and returns a default instance of the model's state PyTree. I refer to this as "default state" and not "initial state" to distinguish it from the state that has been updated (e.g. placing the arm at its starting position) at the beginning of a trial, based on the specifications provided by a task. See the documentation for a description of how these initial states are specified.

Having defined the model's computation as a series of state operations, the user can now insert interventions between the stages of an existing model, without needing to alter its source code. How?

All subclasses of AbstractStagedModel must include (#20) a field intervenors: Mapping[str, Sequence[AbstractIntervenor]], which maps from the names of model stages to one or more instances of AbstractIntervenor. By performing surgery on this field, we can modify an existing model with interventions. AbstractStagedModel.__call__ automatically interleaves the state operations defined in intervenors with those in model_spec.

For more on what a subclass of AbstractIntervenor looks like, see the docs.
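
To make the interleaving concrete, here is a minimal pure-Python sketch of the idea. It is not the Feedbax implementation: it uses frozen dataclasses in place of Equinox modules, the names `State`, `Stage`, and `run_stages` are hypothetical, intervenors are reduced to plain functions on the state, and `where_state` is simplified to a field name rather than a lambda.

```python
from collections import OrderedDict
from dataclasses import dataclass, replace
from typing import Any, Callable, Mapping, Sequence


@dataclass(frozen=True)
class State:
    pos: float
    vel: float


@dataclass(frozen=True)
class Stage:
    component: Callable[[Any, Any], Any]      # (selected input, substate) -> new substate
    where_input: Callable[[Any, State], Any]  # selects what the component receives
    where_state: str                          # simplification: a field name, not a lambda


def run_stages(
    model_spec: "OrderedDict[str, Stage]",
    intervenors: Mapping[str, Sequence[Callable[[State], State]]],
    input: Any,
    state: State,
) -> State:
    for name, stage in model_spec.items():
        # Intervenors registered under a stage's name run just before that stage.
        for intervene in intervenors.get(name, ()):
            state = intervene(state)
        substate = getattr(state, stage.where_state)
        new_substate = stage.component(stage.where_input(input, state), substate)
        # Out-of-place update: build a new State rather than mutating the old one.
        state = replace(state, **{stage.where_state: new_substate})
    return state


DT = 0.5
model_spec = OrderedDict(
    update_vel=Stage(lambda force, vel: vel + DT * force, lambda input, state: input, "vel"),
    update_pos=Stage(lambda vel, pos: pos + DT * vel, lambda input, state: state.vel, "pos"),
)

start = State(pos=0.0, vel=0.0)
plain = run_stages(model_spec, {}, 1.0, start)

# "Surgery" on the intervenor mapping: zero the velocity just before `update_pos`.
freeze = {"update_pos": [lambda state: replace(state, vel=0.0)]}
frozen = run_stages(model_spec, freeze, 1.0, start)
```

The model's stages are untouched in the second call; only the mapping from stage names to intervenors differs.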

What issues might there be with this approach?

  • Writing model_spec instead of an imperative __call__ is probably a little confusing, at first.
  • Any model that we want to intervene on using an AbstractIntervenor, we have to (re)write as an AbstractStagedModel with a model_spec. All the states we expect the user might want to intervene on, must be included as fields in the respective state PyTree.
    • It is possible to use non-staged/vanilla Equinox modules as components, but they will essentially be black boxes that transform one part of their owner's state into another part, without including any of their own internal states in the composite PyTree of states. See #2.
  • It's not obvious how to store the history of the outputs of an intervenor. An extra field intervenors: dict[str, PyTree[Array]] could be added to the state PyTree of the model it belongs to, into which intervenor "states" could be inserted... however this might lead to issues with inconsistent PyTree structure.

Is there some other solution that would allow users to insert interventions into arbitrary points in a model, without needing to modify the model's source at intervention time? Perhaps there is a solution with hooks/callbacks that could work, especially if our models were stateful objects like they might be in PyTorch, and if we didn't need to pass around state PyTrees. But I'm not sure a solution like that is desirable in a JAX library, or what it would look like.

Returning now to the general design philosophy. Consider that in principle, there need only be a single, final class StagedModel that has a single, trivial model stage that on its own does nothing to the state. Any potential subclass of AbstractStagedModel we might want to build, could be replaced by a constructor that returns instances of this hypothetical StagedModel, but with an appropriate sequence of interventions inserted before each instance's single stage. That is, interventions and model stages both define operations on a model's state, and in principle they are interchangeable, though they are (currently) represented differently.

So, which state operations do we include in a model to begin with, and which do we leave to potentially be defined as interventions? That's an important tradeoff our approach leaves us with. I suspect there's no avoiding that problem—no free lunch. The people designing models will always need to rely on their domain expertise not to presume too much, or too little.

Remove `feedbax.channel.ChannelSpec`?

On its instantiation, feedbax.channel.Channel is passed an input_proto -- a PyTree with the same type (structure & array shapes) as the PyTree the channel will queue. It would be inconvenient to ask the user to provide input_proto to construct their model, since it can be inferred from the rest of the model/state structure. What they should provide is a where lambda that picks out the part of the state that will be sent through the channel.

Currently, ChannelSpec is used for specifying a Channel instance, without constructing it. It mostly replicates the field structure of Channel, but instead of storing an input_proto it stores a where lambda.

class ChannelSpec(eqx.Module):
    """Specifies how to build a [`Channel`][feedbax.channel.Channel], with respect to the state PyTree of its owner.

    Attributes:
        where: A function that selects the subtree of feedback states.
        delay: The number of previous inputs to store in the queue.
        noise_std: The standard deviation of the noise to add to the output.
    """
    where: Callable[[AbstractState], PyTree[Array]]
    delay: int = 0
    noise_std: Optional[float] = None

Currently, only SimpleFeedback implements the conversion of ChannelSpec to Channel, by assuming that the where lambda refers specifically to a part of SimpleFeedbackState.

This whole setup seems hacky and I'm trying to find a better way.

A first thought is that Channel could have a where field similarly to ChannelSpec, and no field input_proto. In that case, the user could just construct and pass a Channel instead of a ChannelSpec, when constructing SimpleFeedback. During the execution of its stages, SimpleFeedback would pass its entire state to the Channel, whose where would pick out the substate internally.

However:

  • Passing the entire SimpleFeedbackState to Channel is at odds with the way where_state should be specified in AbstractStagedModel.model_spec: we'd have to do model surgery from within Channel, to replace the ChannelState component of SimpleFeedbackState, and return an entire SimpleFeedbackState. Then it would not be clear from SimpleFeedback.model_spec which substate of SimpleFeedbackState is actually relevant to Channel.
  • Similarly, we would have to define the generic type argument of Channel as SimpleFeedbackState to have the signature of its __call__ method agree with AbstractStagedModel.
  • Channel.init cannot provide a valid initial ChannelState, without access to an actual example of the subtree that's selected by where. In the current situation, the subtree example is provided by the Channel's self.input_proto. We do not want to pass input_proto to init as an argument, since all init methods are designed as providing a default state.

The suitability of `WhereDict`: lambdas as keys

AbstractTask provides data with which to initialize subsets of the model state, at the start of each task trial.

This could be accomplished by providing a PyTree with the same structure as the full model state, with None at all leaves except those to be initialized, and then using eqx.combine to replace the model substates with initial values provided. However, this would require that each type of AbstractTask be associated with a particular type of state PyTree, whereas in principle a task type should be compatible with any model whose associated state contains at least the substates 1) to be initialized and 2) that are targets/part of the loss computation. For example, it shouldn't matter how complex the PyTree of states for the neural network is, when defining a task in terms of initial and target states for the biomechanical effector.

Therefore, in AbstractTask the initial values for substates are specified (example) as a pairing of a lambda that selects the substate to be initialized from the full model state, with a PyTree of data with the same structure as that substate. Then TaskTrainer performs a series of equinox.tree_at surgeries based on this mapping:

feedbax/feedbax/train.py

Lines 516 to 521 in 2ce8b1c

for where_substate, init_substates in trial_specs.inits.items():
    init_states = eqx.tree_at(
        where_substate,
        init_states,
        init_substates,
    )

As far as TaskTrainer is concerned, it would be sufficient to provide these pairings as tuple[Callable, PyTree]. However, in general it seems to make sense for the pairing to be a mapping: first, because there should be at most one initialization provided for each substate. But also, if the user (or a function in feedbax.plot) wants to access the initialization data from the trial specification, it is more convenient to write (say) trial_spec.init['mechanics.effector'] or even trial_spec.init[lambda state: state.mechanics.effector] than to have to figure out which tuple[Callable, PyTree] contains the Callable that refers to the part of the state they're interested in.

Unfortunately, lambdas cannot simply be used as keys in a mapping. If I define a dict with a key lambda state: state.something, and later try to get the value associated with a newly-defined key lambda s: s.something, I'll encounter a KeyError because lambdas are not hashed according to the function they represent, but by their memory address.
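
A small demonstration of the lookup failure, using only built-in Python behaviour. The attribute name `something` and the stored values are placeholders.

```python
# Two lambdas with the same body are still two distinct function objects.
inits = {lambda state: state.something: "initial value"}

try:
    inits[lambda s: s.something]
    found = True
except KeyError:
    found = False

# Lookup only succeeds with the very same function object.
where = lambda state: state.something
same_object_works = {where: 1}[where] == 1
```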

Thus we use WhereDict, which is an OrderedDict that enables limited use of lambdas as keys. In particular, it uses dis to parse the LOAD_ATTR operations in the lambda's bytecode, and constructs an equivalent string representation. For example, lambda x: x.foo.bar is parsed as "foo.bar". This only works when the lambda takes a single argument, and returns a single (nested) attribute access on the argument.

Both a lambda and its string representation can be used as a WhereDict key:

assert my_where_dict['foo.bar'] is my_where_dict[lambda x: x.foo.bar]
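
For readers unfamiliar with dis, the key normalization can be sketched roughly as follows. This is an illustration of the technique and not the actual WhereDict code: `where_to_str` and `WhereDictSketch` are hypothetical names, only `__getitem__` and `__setitem__` are overridden, and nothing validates that the lambda is a pure attribute chain.

```python
import dis
from collections import OrderedDict


def where_to_str(where):
    """Turn `lambda x: x.foo.bar` into "foo.bar"; strings pass through unchanged."""
    if isinstance(where, str):
        return where
    # Collect the attribute names from the LOAD_ATTR instructions, in order.
    return ".".join(
        instr.argval for instr in dis.Bytecode(where) if instr.opname == "LOAD_ATTR"
    )


class WhereDictSketch(OrderedDict):
    """An OrderedDict whose keys are normalized to attribute-path strings."""

    def __getitem__(self, key):
        return super().__getitem__(where_to_str(key))

    def __setitem__(self, key, value):
        super().__setitem__(where_to_str(key), value)


inits = WhereDictSketch()
inits[lambda state: state.mechanics.effector] = "effector init"

via_string = inits["mechanics.effector"]
via_new_lambda = inits[lambda s: s.mechanics.effector]
```

Every access by lambda pays for a bytecode disassembly, which is where the overhead discussed next comes from.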

The downside is the overhead of dis.Bytecode, such that WhereDict is about 100x slower to construct, and 500-1000x slower to access, than OrderedDict. In practice this is not a big deal for our use case, as we only need to do a single construct and a single access on each training batch, leading to an overhead of about 125 us, where a batch normally takes at least 20,000 us. Also, it's unlikely the user will initialize more than a few substates separately, i.e. no more than a few entries per WhereDict.

Is there a better or faster way to do this?

I've considered that it might work to specify the state initialization as a prefix of the model state rather than using a lambda combined with a substate. However, this does not solve the access problem—we'd still need to map to the prefixes with string keys like "mechanics.effector" so that the user could refer to them easily; however 1) those keys would no longer have a lawful relationship to the surgeries performed to assign the initial states, and 2) the user would still have to access the appropriate leaves from the prefix tree.

Multiple feedback channels and `MultiModel`

SimpleFeedback allows for multiple channels of feedback to the neural network, with different delays and noise. For example, one typical configuration is to feed back "proprioceptive" variables (e.g. arm joint configuration, or muscle states) at a short delay, and "visual" variables (e.g. position of the end of the arm) at a longer delay.

In particular, SimpleFeedback:

  1. has a field channels: PyTree[Channel]. At construction time, the user supplies a PyTree[ChannelSpec] or a container of mappings; this is used to construct a PyTree[Channel], which is used to construct a MultiModel. However, see: #3.
  2. has a model stage "update_feedback":

    feedbax/feedbax/bodies.py

    Lines 175 to 183 in b73dfb8

    "update_feedback": Stage(
        callable=lambda self: self._feedback_module,
        where_input=lambda input, state: jax.tree_map(
            lambda spec: spec.where(state.mechanics),
            self._feedback_specs,
            is_leaf=lambda x: isinstance(x, ChannelSpec),
        ),
        where_state=lambda state: state.feedback,
    ),

MultiModel is a subclass of AbstractModel that has a field models: PyTree[AbstractModel], and expects to be passed state and input whose tree structures match that of models. When called, it maps the models, inputs, and states:

feedbax/feedbax/_model.py

Lines 129 to 146 in b73dfb8

def __call__(
    self,
    input: ModelInput,
    state: PyTree[StateT, "T"],
    key: PRNGKeyArray,
) -> StateT:
    # TODO: This is hacky, because I want to pass intervenor stuff through entirely. See `staged`
    return jax.tree_map(
        lambda model, input_, state, key: model(
            ModelInput(input_, input.intervene), state, key
        ),
        self.models,
        input.input,
        state,
        self._get_keys(key),
        is_leaf=lambda x: isinstance(x, AbstractModel),
    )

The PyTree structure of input.input matches models because of the tree_map performed in the definition of where_input for the "update_feedback" stage.

The structure of states matches too, because MultiModel (like all AbstractModel subclasses) provides an init method that returns a PyTree[ChannelState], and this is used to generate any initial state that is passed to the model.

Is there a better way to include a PyTree of similar components in a model, that are all executed as part of a single model stage? With the current approach, intervenors can be added to individual Channel objects, but it may be kind of inconvenient to refer to those objects (e.g. my_simple_feedback.channels.models['vision']).

I'm not sure how vmapping could be used here, as different channels can carry data of different shapes and dtypes.

The use of ModelInput in MultiModel is also not ideal, in particular because I think it makes sense for MultiModel to be a subclass of AbstractModel and not AbstractStagedModel; however, ModelInput is specifically used for carrying intervenor parameters along with other model inputs, and intervenors are associated with instances of AbstractStagedModel, and not with AbstractModel in general. See #12 for a more general discussion of ModelInput.

Improve the realism and generality of models of musculoskeletal geometry

So far, the only muscle models implemented in Feedbax are variants of the Virtual Muscle Model, and their usage is limited to pre-defined musculoskeletal geometries. For example, MuscledArm has constant moment arms that determine how muscle forces produce torques on the joints of a two-link arm. The values of the moment arms can be changed at construction time, however the number of muscles cannot. In a realistic model, moment arms should also vary with muscle length -- currently, Feedbax does not model this.

MotorNet has a more flexible and realistic approach to modeling musculoskeletal geometry, by allowing the user to define arbitrary wrappings, i.e. attachment points of muscles on the skeleton, and by modeling variable moment arms.

I'd like to follow MotorNet's lead, here.

The abstract-final pattern and generics: should `AbstractState` be eliminated?

If we follow the abstract-final pattern strictly, such as by setting strict=True when subclassing equinox.Module, then

  • The base class AbstractState would need to include an AbstractVar for every field that appears in every subclass;
  • Every final subclass of AbstractState would need to implement every one of those fields.

Clearly this doesn't make sense, since different types of state PyTrees usually don't share any fields.

This might also be a reason we should expect to see issues with generic typing of AbstractModel[StateT], where StateT is bound to AbstractState. However, my understanding is that type invariance should be preserved as long as we respect the abstract-final pattern when subclassing whichever AbstractState subclass is ultimately used as the type argument for a final subclass of AbstractModel.

I suspect the solution is:

  1. Replace StateT with a type variable bound to equinox.Module;
  2. Use different base classes (AbstractFooState) that inherit from equinox.Module, for different final subclasses of AbstractModel. Each of these should respect the abstract-final pattern.
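
A rough sketch of what step 1 could look like, written so that it runs without Equinox. `Module` here is a stand-in for equinox.Module, and `PointMassState` and `PointMassModel` are hypothetical names used only for illustration; the point is that the state type needs no shared AbstractState base.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Generic, TypeVar


class Module:
    """Stand-in for `equinox.Module`, so the sketch runs without Equinox."""


# The type variable is bound to the module base class, not to an AbstractState.
StateT = TypeVar("StateT", bound=Module)


class AbstractModel(ABC, Generic[StateT]):
    @abstractmethod
    def __call__(self, input: float, state: StateT) -> StateT: ...

    @abstractmethod
    def init(self) -> StateT: ...


@dataclass(frozen=True)
class PointMassState(Module):
    pos: float
    vel: float


class PointMassModel(AbstractModel[PointMassState]):
    def __call__(self, input: float, state: PointMassState) -> PointMassState:
        # Out-of-place update of a single field.
        return replace(state, vel=state.vel + input)

    def init(self) -> PointMassState:
        return PointMassState(pos=0.0, vel=0.0)


model = PointMassModel()
next_state = model(1.0, model.init())
```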

Support Python>=3.9

Feedbax is built on some features introduced in Python 3.11, such as Self.

Some users may want to use Feedbax along with some packages that do not support 3.11 yet. Ideally we would support all active Python versions, though 3.8 is reaching end-of-life this year and I think we can exclude it.

I intend to keep developing in 3.11, mainly because of the newer typing features. I don't foresee problems with 3.9 support, aside from the effect on typing. As for typing, I think support for 3.9 would involve simplifying some annotations, e.g. replacing with Any or flagging with type: ignore.
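
For Self specifically, there is an option beyond the two listed: PEP 673 describes Self as shorthand for a type variable bound to the enclosing class, and that longer spelling works on 3.9. A minimal sketch, with hypothetical `Muscle` and `NoisyMuscle` classes standing in for the real ones:

```python
from dataclasses import dataclass, replace
from typing import TypeVar

# Pre-3.11 spelling of `Self`: a type variable bound to the class itself.
MuscleT = TypeVar("MuscleT", bound="Muscle")


@dataclass(frozen=True)
class Muscle:
    n_muscles: int = 1

    def change_n_muscles(self: MuscleT, n_muscles: int) -> MuscleT:
        # Returns a modified copy of the same (sub)class as `self`.
        return replace(self, n_muscles=n_muscles)


@dataclass(frozen=True)
class NoisyMuscle(Muscle):
    noise_std: float = 0.0


changed = NoisyMuscle(noise_std=0.1).change_n_muscles(6)
```

A type checker infers `NoisyMuscle` for `changed`, just as it would with `-> Self`.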

Adding intervenors to model ensembles

The current implementations of schedule_intervenors and add_intervenors in feedbax.intervene are insensitive to model ensembles. If the intervenor parameters include arrays, these arrays won't necessarily contain the batch dimension of the model ensemble. Thus, when intervenors are added to an ensembled model PyTree, an error may be raised on the filter_vmap calls used for model ensembling in feedbax.train and feedbax.task.AbstractTask.

A solution that should work immediately is to apply vmap/get_ensemble to acquire an ensemble from a function that returns a single model with the intervenor added. However, this means intervenors must be added at the time of model construction.

The desired behaviour is that we should be able to add an intervention to an existing ensembled model, at any time.

Composition and instantiation of muscle models

The composition of the models in feedbax.mechanics.muscle is the most complicated in the library.

Class diagram of feedbax.mechanics.muscle
classDiagram
  class AbstractActivationFunction {
  }
  class AbstractFLVFunction {
  }
  class AbstractForceFunction {
  }
  class AbstractMuscle {
    activation_func : AbstractVar[AbstractActivationFunction]
    force_func : AbstractVar[AbstractFLVFunction]
    n_muscles : AbstractVar[int]
    noise_func : AbstractVar[Optional[Callable[[Array, Array, Array], Array]]]
    change_n_muscles(n_muscles: int) 'AbstractMuscle'
  }
  class AbstractMuscleState {
    activation : AbstractVar[Array]
    length : AbstractVar[Array]
    tension : AbstractVar[Array]
    velocity : AbstractVar[Array]
  }
  class AbstractVirtualMuscleShortenFactor {
    c_v : Tuple[float, float]
  }
  class ActivationFilter {
    tau_act : float
    tau_deact : float
    init()*
    input_size()
    vector_field(t: None, state: Array, input: Array)
  }
  class HillShortenFactor {
  }
  class LillicrapScottForceLength {
    beta : float
    omega : float
  }
  class VirtualMuscle {
    activation_func
    bounds
    force_func
    n_muscles : int
    noise_func : Optional[Callable[[Array, Array, Array], Array]]
    init() VirtualMuscleState
  }
  class VirtualMuscleActivationFunction {
    a_f : float
    n_f : Tuple[float, float]
  }
  class VirtualMuscleFLVFunction {
    force_length
    force_passive_1
    force_passive_2
    force_velocity
  }
  class VirtualMuscleForceLength {
    beta : float
    omega : float
    rho : float
  }
  class VirtualMuscleForcePassive1 {
    c1 : float
    k1 : float
    l_r1 : float
  }
  class VirtualMuscleForcePassive2 {
    c2 : float
    k2 : float
    l_r2 : float
  }
  class VirtualMuscleForceVelocity {
    a_v : Tuple[float, float, float]
    b_v : float
    shorten_denom_factor_func
    v_max : float
  }
  class VirtualMuscleShortenFactor {
  }
  class VirtualMuscleState {
    activation : Array
    length : Array
    tension : Array
    velocity : Array
  }
  VirtualMuscle ..> VirtualMuscleState
  AbstractMuscle ..> AbstractMuscleState
  HillShortenFactor --|> AbstractVirtualMuscleShortenFactor
  LillicrapScottForceLength --|> AbstractForceFunction
  VirtualMuscle --|> AbstractMuscle
  VirtualMuscleActivationFunction --|> AbstractActivationFunction
  VirtualMuscleFLVFunction --|> AbstractFLVFunction
  VirtualMuscleForceLength --|> AbstractForceFunction
  VirtualMuscleForcePassive1 --|> AbstractForceFunction
  VirtualMuscleForcePassive2 --|> AbstractForceFunction
  VirtualMuscleForceVelocity --|> AbstractForceFunction
  VirtualMuscleShortenFactor --|> AbstractVirtualMuscleShortenFactor
  VirtualMuscleState --|> AbstractMuscleState
  AbstractActivationFunction --* VirtualMuscle : activation_func
  AbstractFLVFunction --* VirtualMuscle : force_func
  AbstractForceFunction --* VirtualMuscleFLVFunction : force_length
  AbstractForceFunction --* VirtualMuscleFLVFunction : force_velocity
  AbstractForceFunction --* VirtualMuscleFLVFunction : force_passive_1
  AbstractForceFunction --* VirtualMuscleFLVFunction : force_passive_2
  AbstractVirtualMuscleShortenFactor --* VirtualMuscleForceVelocity : shorten_denom_factor_func

For example, a muscle model typically has a force-length-velocity (FLV) function which determines how its force output relates to kinematic variables. This function is typically a composite of several different force functions; e.g. force-length, force-velocity, passive force. Sometimes, only one of these functions changes between implementations. Therefore the composition of muscle models is dependency-inverted so that it's easier to swap out specific components.
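
The benefit of this arrangement can be illustrated with a toy sketch. None of the formulas below are the Virtual Muscle equations: the component functions and the combination rule are placeholders, and `FLVFunction` is a hypothetical stand-in for the real composite class. Only the composition pattern is the point.

```python
from dataclasses import dataclass, replace
from typing import Callable

# Each component maps (length, velocity) to a force factor.
ForceFunction = Callable[[float, float], float]


@dataclass(frozen=True)
class FLVFunction:
    force_length: ForceFunction
    force_velocity: ForceFunction
    force_passive: ForceFunction

    def __call__(self, length: float, velocity: float) -> float:
        # Placeholder combination rule, chosen only for illustration.
        active = self.force_length(length, velocity) * self.force_velocity(length, velocity)
        return active + self.force_passive(length, velocity)


flv = FLVFunction(
    force_length=lambda length, velocity: 1.0,
    force_velocity=lambda length, velocity: 1.0 - velocity,
    force_passive=lambda length, velocity: 0.0,
)

# Swap a single component; the others are carried over unchanged.
flv_variant = replace(flv, force_length=lambda length, velocity: 2.0 - length)

baseline = flv(1.5, 0.0)
variant = flv_variant(1.5, 0.0)
```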

Here's an example of construction of a VirtualMuscle model:

def todorov_li_2004_virtualmuscle(
    n_muscles: int = 1,
    noise_func: Optional[Callable] = None,
    params: PyTree[float] = TODOROV_LI_VIRTUALMUSCLE_PARAMS,
):
    """Muscle model from Todorov & Li 2004.

    !!! Note ""
        Simplifies the Brown et al. 1999 Virtual Muscle Model:

        1. Omits the first passive element, PE1.
        2. Uses averages of the fast and slow twitch parameters from Brown 1999.

    Arguments:
        n_muscles: The number of muscles to model.
        noise_func: Generates noise to add to the muscle force.
            Has the signature `noise_func(input, force, key) -> Array`, where
            `input` is the input to the muscle model.
        params: The parameters for the Virtual Muscle Model.
    """
    return VirtualMuscle(
        n_muscles,
        activation_func=VirtualMuscleActivationFunction(**params["activation"]),
        force_func=VirtualMuscleFLVFunction(
            force_length=VirtualMuscleForceLength(**params["force_length"]),
            force_velocity=VirtualMuscleForceVelocity(
                **params["force_velocity"],
                shorten_denom_factor_func=VirtualMuscleShortenFactor(
                    **params["shorten"]
                ),
            ),
            force_passive_1=lambda length, velocity: 0,
            force_passive_2=VirtualMuscleForcePassive2(**params["force_passive_2"]),
        ),
        noise_func=noise_func,
    )

The default parameters are defined as a dict:

"""Virtual Muscle Model parameters used by Todorov & Li, 2004."""
TODOROV_LI_VIRTUALMUSCLE_PARAMS = dict(
    force_length=dict(
        beta=1.93,  # slow/fast avg
        omega=1.03,  # slow/fast avg is 1.035
        rho=1.87,  # slow/fast avg
    ),
    force_velocity=dict(
        a_v=(-3.12, 4.21, -2.67),  # slow/fast avg
        b_v=0.62,  # slow/fast avg
        v_max=-5.72,  # slow/fast avg is -5.725
    ),
    force_passive_2=dict(  # identical for slow/fast
        c2=-0.02,
        k2=-18.7,
        l_r2=0.79,
    ),
    shorten=dict(
        c_v=(1.38, 2.09),  # slow/fast avg is (1.335, 2.085)
    ),
    activation=dict(
        n_f=(2.11, 4.16),  # slow/fast avg (2.11, 4.155),
        a_f=0.56,
    ),
    #! unused
    force_passive_1=dict(
        c1=0.0,
        k1=1.0,
        l_r1=0.0,
    ),
)

Is there a better way to instantiate these models? Note that currently, the muscle components (like VirtualMuscleForceLength and VirtualMuscleForcePassive) do not have default values for their fields.

Training, and vmapping over ensembles of models

Some parts of TaskTrainer.__call__ are not fit for vmapping when training multiple models in parallel:

  • Logging and plotting operations should not be vmapped, but some logging should still be done. Say we're sending some validation plots to tensorboard: do we log one model, or all of them separately? Do we send all of the loss values separately, or just statistics?
  • It also doesn't make sense to vmap over a tqdm progress bar, though we still want to keep the progress bar when training ensembles.

The current solution is to add several if ensembled: blocks (example) at appropriate points in TaskTrainer.__call__, at which we:

  • either add (or don't add) batch dimensions to the arrays meant to store the training history, e.g. losses;
  • either split (or don't split) random keys;
  • either apply (or don't apply) vmap to TaskTrainer._train_step, optimizer.init, etc. prior to using these functions.

It would be nice for __call__ itself to be vmappable, but I'm not sure how this could be achieved. Perhaps we could use something like jax.experimental.host_callback to pass data back to logging functions, but I don't see how this would solve the progress bar issue.

Intervenors, and an issue with the abstract-final pattern

As mentioned in #19, intervention operations are prepended to the existing stages of a model's state operations. The intervention operations are defined by inserting AbstractIntervenor objects into the intervenors: Mapping[str, Sequence[AbstractIntervenor]] field of any AbstractStagedModel.

Currently, every subclass of AbstractStagedModel needs to implement that field. This is annoying: the field is generally just an empty dict that is later filled up with stuff, so it is always implemented in the same way. Developers are expected to repeat themselves.

We could avoid this by implementing intervenors directly in AbstractStagedModel, but this would violate the abstract-final pattern. We haven't been enforcing the pattern by subclassing equinox.Module with strict=True, yet.

Is there a way to avoid implementing intervenors in the base class?

There are a few other places this same issue appears in Feedbax.

  • the intervene field, in all final subclasses of feedbax.task.AbstractTaskTrialSpec;
  • the intervention_specs and intervention_specs_validation fields in all final subclasses of AbstractTask;
  • the label field in all final subclasses of AbstractIntervenor—when we use schedule_intervenor, this is used to assign a unique label to each intervenor among all the intervenors belonging to a model, so that it can be matched up with its trial-by-trial parameters, which are specified as a flat PyTree (#12). The user generally does not have to assign to this field.

Generalized loss terms from trial specifications

Currently:

  • the init field of AbstractTaskTrialSpec specifies 1) lambdas that pick out subtrees of the model state, and 2) replacements for those subtrees, to use to initialize the state on a given trial.
  • for each batch of training trials, TaskTrainer passes the following to the loss function: 1) the evaluated trajectories of model states, 2) the trial specifications. Thus, every AbstractLoss subclass in feedbax.loss works through hardcoded references to parts of states and trial_specs. For example, feedbax.loss.EffectorPositionLoss is a function of the difference states.mechanics.effector.pos - trial_specs.target.pos.
  • the target field of AbstractTaskTrialSpec is a CartesianState that only specifies the state of a single effector.

But what if we want to (say) include losses that penalize the position of two different effectors, given by CartesianState leaves at different locations in the states PyTree? Then trial_specs.target should specify two different CartesianState targets. Should the relationship between the leaves of trial_specs.target and the leaves of states be hardcoded into a subclass of AbstractLoss? This could lead to a proliferation of AbstractLoss classes. (However, see #24.)

Instead, trial_specs.target could be defined similarly to trial_specs.init. In that case, each member of target would provide 1) a lambda that picks out a part of the state, and 2) a target trajectory for that state. A loss term could be automatically constructed from each entry -- something like target.where(states) - target.value.
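
Roughly, the automatic construction might look like the following pure-Python sketch. It assumes a squared-error penalty, plain lists of floats instead of arrays, and a hypothetical `TargetSpec` container and `target_loss` function; none of these names exist in Feedbax.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Callable, Mapping, Sequence


@dataclass(frozen=True)
class TargetSpec:
    where: Callable[[Any], Sequence[float]]  # picks out part of the state
    value: Sequence[float]                   # target for that part


def target_loss(states: Any, targets: Mapping[str, TargetSpec]) -> float:
    """Sum of squared errors over every (where, value) entry."""
    total = 0.0
    for target in targets.values():
        selected = target.where(states)
        total += sum((s - v) ** 2 for s, v in zip(selected, target.value))
    return total


# Two effectors at different locations in the state tree.
states = SimpleNamespace(
    left=SimpleNamespace(pos=[0.0, 1.0]),
    right=SimpleNamespace(pos=[2.0, 2.0]),
)
targets = {
    "left": TargetSpec(where=lambda s: s.left.pos, value=[0.0, 0.0]),
    "right": TargetSpec(where=lambda s: s.right.pos, value=[2.0, 4.0]),
}

loss = target_loss(states, targets)
```

Adding a third penalized effector then means adding one more entry to `targets`, with no new loss class.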

I think some losses should still be defined the same way they already are. But this feature might make it easier to train more complex models, since the user will only need to specify a couple of lambdas when subclassing AbstractTask, instead of needing to write multiple subclasses of AbstractLoss as well.

Does the model need access to the task, at construction time?

Currently, when constructing a model that takes task information as an input, we first need to construct the task the model will be asked to perform.

For example, SimpleFeedback has a neural network whose inputs include the task information (which depends on the task) as well as the sensory feedback, which depends on the user's specification of which state variables are provided as feedback.

Thus a typical model construction looks something like:

def point_mass_nn(
    task,
    *,
    key: PRNGKeyArray,
    dt: float = 0.05, 
    mass: float = 1., 
    hidden_size: int = 50, 
    encoding_size: Optional[int] = None,
    n_steps: int = 100, 
    feedback_delay_steps: int = 0,
    feedback_noise_std: float = 0.0,
):        
    key1, key2 = jr.split(key)
    
    system = PointMass(mass=mass)
    mechanics = Mechanics(DirectForceInput(system), dt)
    
    feedback_spec = dict(
        where=lambda state: (
            state.plant.skeleton.pos,
            state.plant.skeleton.vel,
        ),
        delay=feedback_delay_steps,
        noise_std=feedback_noise_std,
    )
    
    # Determine network input size.
    input_size = SimpleFeedback.get_nn_input_size(
        task, mechanics, feedback_spec=feedback_spec
    )
    
    net = SimpleStagedNetwork(
        input_size,
        hidden_size,
        out_size=system.input_size, 
        encoding_size=encoding_size,
        key=key1,
    )

    body = SimpleFeedback(net, mechanics, feedback_spec=feedback_spec, key=key2)
    
    return Iterator(body, n_steps)

However, maybe we'll want to construct models without knowing beforehand what the task will exactly be.

This is one possible use case for a subclass of AbstractIntervenor: we can add inputs to neural networks as interventions on the state of the network's input layer(s), prior to the forward pass of the network.

Similarly, upon construction of SimpleFeedback, we could modify the effective input size of net to include the sensory feedback variables, either through a method provided by (in this case) SimpleStagedNetwork that returns a modified model, or else by using add_intervenor.

There are a couple of related problems here.

  • If we use add_intervenor to add an intervention to SimpleStagedNetwork, those inputs will not be reflected by its attribute input_size. In principle we could infer the change to input size caused by certain kinds of intervenors, but this would probably be unwise if those input variables do not actually appear in the input to the network module, but are routed in by an intervenor added to a module higher up in the model tree (e.g. SimpleFeedback).
  • Currently, an intervenor can only select its input state as some node(s) from the state PyTree operated on by the model it belongs to. The intervenor cannot take as input any part of its model's input argument. Thus there is no way for us to pass the task information—which shows up as input at the top-level of the model PyTree—into the intervenor and have it affect the state. In principle we could try to fix this by allowing the intervenor to also access the inputs of the model it belongs to, but currently the input to an intervenor is used only to supply it with its intervention parameters.

Clearly there either needs to be a change to how we handle model inputs (#12) that will make it easier to use intervenors for this purpose, or else we should abandon this intervenor use case altogether.

`AbstractLTISystem` is not general enough

The vector_field method of feedbax.dynamics.AbstractLTISystem concatenates the position and velocity arrays so it can use the $\dot{\mathbf{x}}=\mathbf{Ax}+\mathbf{Bu}$ form. Here is a slightly edited version:

    def vector_field(
        self, 
        t: float, 
        state: CartesianState,
        input: Float[Array, "input"]
    ) -> CartesianState:
        """Returns time derivatives of the system's states.        
        """
        force = input + state.force
        state_ = jnp.concatenate([state.pos, state.vel])       
        d_y = self.A @ state_ + self.B @ force

        return CartesianState(pos=d_y[:2], vel=d_y[2:], force=None)

This may be inefficient due to the concatenation. It also does not describe linear time-invariant systems in general, whose state won't necessarily be a CartesianState.

Perhaps the position-velocity concatenation logic, if it needs to be kept, can be moved to PointMass, and the return and state parameter types of AbstractLTISystem could be replaced by Array. However, while I imagine it's possible to avoid the concatenation by splitting up the A and B matrices, I think that would require AbstractLTISystem to have knowledge of the split at construction time...
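A minimal sketch of the splitting idea, assuming a damped point mass whose state is ordered as [pos, vel]. It uses NumPy for brevity; the same expressions should carry over to jax.numpy. The names point_mass_matrices, split_blocks and the damping parameter are hypothetical and not part of Feedbax. Note that the split needs n_dim when the blocks are built, which is the construction-time knowledge mentioned above.

```python
import numpy as np


def point_mass_matrices(mass=1.0, damping=0.0, n_dim=2):
    """Build A and B for a (hypothetical) damped point mass, state = [pos, vel]."""
    zeros, eye = np.zeros((n_dim, n_dim)), np.eye(n_dim)
    A = np.block([[zeros, eye], [zeros, -damping / mass * eye]])
    B = np.concatenate([zeros, eye / mass], axis=0)
    return A, B


def vector_field_concat(A, B, pos, vel, force):
    """Current approach: concatenate, multiply, then slice the result."""
    n_dim = pos.shape[0]
    d_y = A @ np.concatenate([pos, vel]) + B @ force
    return d_y[:n_dim], d_y[n_dim:]


def split_blocks(A, B, n_dim):
    """Slice A and B into position/velocity blocks once, at construction time."""
    return (
        A[:n_dim, :n_dim], A[:n_dim, n_dim:],  # A_pp, A_pv
        A[n_dim:, :n_dim], A[n_dim:, n_dim:],  # A_vp, A_vv
        B[:n_dim], B[n_dim:],                  # B_p, B_v
    )


def vector_field_split(blocks, pos, vel, force):
    """Same linear map as above, without concatenating the state."""
    A_pp, A_pv, A_vp, A_vv, B_p, B_v = blocks
    d_pos = A_pp @ pos + A_pv @ vel + B_p @ force
    d_vel = A_vp @ pos + A_vv @ vel + B_v @ force
    return d_pos, d_vel
```

For a point mass, the A_pp, A_vp and B_p blocks are all zero, so a specialized subclass could drop those products entirely.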

Currently:

  • Position and velocity are stored in separate arrays in CartesianState. It is convenient for the user to be able to refer to them separately. I guess we could make a special object that stores them in a single array and slices them out -- though the force field of CartesianState would need to be kept separate, anyway.
  • The only abstract components of AbstractLTISystem are the fields A, B, and C (the latter is a placeholder, for now). These are abstract because of wanting to respect the abstract-final design pattern, combined with the way inheritance works for PointMass.
  • PointMass inherits from both AbstractLTISystem and AbstractSkeleton[CartesianState], the latter of which inherits from AbstractDynamicalSystem (like AbstractLTISystem does).
    • By inheriting from AbstractLTISystem, PointMass obtains its linear vector field. As mentioned above, we could perhaps override vector_field to wrap the parent method to deal in CartesianState instead of a concatenated state array.
    • By inheriting from AbstractSkeleton, PointMass is forced to provide the usual kinematics methods.

One alternative is to 1) make LTISystem final, and 2) make PointMass inherit only from AbstractSkeleton while possessing an instance of LTISystem as an attribute.
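A rough sketch of what that composition could look like, using plain frozen dataclasses and NumPy as stand-ins for eqx.Module and jax.numpy. All class names ending in Sketch, and the helper make_point_mass, are hypothetical; the AbstractSkeleton parent and its kinematics methods are left out.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class CartesianStateSketch:
    pos: np.ndarray
    vel: np.ndarray
    force: Optional[np.ndarray] = None


@dataclass(frozen=True)
class LTISystemSketch:
    """Final and state-agnostic: deals only in flat arrays."""
    A: np.ndarray
    B: np.ndarray

    def vector_field(self, t, state, input):
        return self.A @ state + self.B @ input


@dataclass(frozen=True)
class PointMassSketch:
    """Owns an LTI system as an attribute instead of inheriting from one."""
    lti: LTISystemSketch
    n_dim: int = 2

    def vector_field(self, t, state, input):
        # The CartesianState <-> array conversion lives here, not in the LTI class.
        force = input if state.force is None else input + state.force
        flat = np.concatenate([state.pos, state.vel])
        d_y = self.lti.vector_field(t, flat, force)
        return CartesianStateSketch(pos=d_y[: self.n_dim], vel=d_y[self.n_dim :])


def make_point_mass(mass: float, n_dim: int = 2) -> PointMassSketch:
    """Build an undamped point mass (an assumed parameterization)."""
    zeros, eye = np.zeros((n_dim, n_dim)), np.eye(n_dim)
    A = np.block([[zeros, eye], [zeros, zeros]])
    B = np.concatenate([zeros, eye / mass], axis=0)
    return PointMassSketch(LTISystemSketch(A, B), n_dim)
```

With this arrangement the A, B and C fields would no longer need to be abstract, since nothing inherits from the LTI class.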

This all seems a bit muddled to me. Advice is welcome!

Automatically caching `AbstractTask.validation_trials`

It makes sense to compute the validation set of task trials only once for each AbstractTask object. This is possible if we use caching.

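from abc import abstractmethod
from functools import cached_property

import equinox as eqx


class AbstractTask(eqx.Module):

    @cached_property
    def validation_trials(self):
        # Computed on first access, then cached on the instance.
        return self._validation_trials

    @property
    @abstractmethod
    def _validation_trials(self):
        ...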

However, I'd rather not have properties/methods be implemented with private names in subclasses, if these are to appear in documentation with docstrings describing their particular implementations.

Potential solutions:

  • Use the private property approach shown above, and manually override the docstrings for validation_trials in subclasses.
  • Use a single abstract property validation_trials that is not cached, and use cached_property in every subclass. Suggest to developers that they use cached_property to avoid unnecessary overhead.
  • Require developers to define __init__ -- currently, all AbstractTask subclasses use the default dataclass init -- and construct the validation set there, assigning it as a regular attribute. This is unappealing to me since it introduces a bunch of boilerplate we otherwise wouldn't have to implement on subclassing, and it may be unclear to users that the only purpose of the __init__ (and its long signature seen in the docs) is to construct the validation set.

What is the best practice, here?

Structure of mechanical models: effector states, and aggregation of dynamic and static updates

This is a (perhaps overly) general issue about the design of AbstractPlant and Mechanics.

In brief, an AbstractPlant is a subclass of

  1. AbstractDynamicalSystem: AbstractPlant implements a general vector_field method, which aggregates the vector_field methods of all the AbstractDynamicalSystem instances referred to by the abstract property AbstractPlant.dynamics_spec. Final subclasses of AbstractPlant must implement this property.
  2. AbstractStagedModel: final subclasses of AbstractPlant must implement a model_spec property that defines the sequence of kinematic/static updates.

On the other hand, a Mechanics instance wraps an AbstractPlant instance in order to discretize it, associate it with a Diffrax solver, and execute all its kinematic and dynamic updates. Currently, it also handles the effector states.

In this issue I discuss this arrangement in more detail. At the end, I raise a couple of potential issues:

  • Should the vector fields be aggregated, and updated through a single Diffrax solver step?
  • Which class should manage the effector states?

AbstractPlant

AbstractPlant objects provide most of the components of a biomechanical model. In particular, they aggregate the differential equations as well as the kinematic updates of the mechanical components.

For example, an AbstractPlant object that models a muscled arm (e.g. feedbax.mechanics.plant.MuscledArm) may bring together:

  1. a differential equation describing the physics of skeletal movement due to torques generated by muscles, and belonging to an AbstractSkeleton object;
  2. another differential equation describing the dynamics of muscle activation, belonging to an AbstractMuscle object;
  3. some kinematic methods belonging to the AbstractPlant object itself, that directly calculate the muscle length and contraction velocity given the current state of the skeleton;
  4. a method of the same AbstractMuscle object as in (2) that determines the forces generated by the muscle, given its current length and velocity. These are converted into torques after being passed back to the AbstractPlant object.

AbstractPlant subclasses must implement the property dynamics_spec, which aggregates all the AbstractDynamicalSystem components that should be included in the dynamics step for the entire plant. For example, MuscledArm includes a first-order filter for the muscle activation, as well as an ODE describing the skeletal dynamics, and both of these are referred to in its dynamics_spec:

    @cached_property
    def dynamics_spec(self) -> Mapping[str, DynamicsComponent[PlantState]]:
        """Specifies the components of the muscled arm dynamics."""
        return dict(
            {
                "muscle_activation": DynamicsComponent(
                    dynamics=self.activator,
                    where_input=lambda input, state: input,
                    where_state=lambda state: state.muscles.activation,
                ),
                #! is this applying the torques twice? since arm will do `input_torque + state.torque`
                "skeleton": DynamicsComponent(
                    dynamics=self.skeleton,
                    where_input=lambda input, state: state.skeleton.torque,
                    where_state=lambda state: state.skeleton,
                ),
            }
        )

The entries in dynamics_spec are automatically aggregated into a single field by the method AbstractPlant.vector_field, which fulfils AbstractPlant's implementation of AbstractDynamicalSystem:

    d_state = jax.tree_map(jnp.zeros_like, state)
    for component in self.dynamics_spec.values():
        d_state = eqx.tree_at(
            component.where_state,
            d_state,
            component.dynamics.vector_field(
                t, component.where_state(state), component.where_input(input, state)
            ),
        )
    return d_state

AbstractPlant also inherits from AbstractStagedModel and therefore defines a model_spec, which composes the various static operations that are performed on the plant state (e.g. calculating muscle length and velocity, or clipping states to bounds).

Mechanics

An instance of AbstractPlant describes continuous dynamics, as do the vector fields defined in dynamics_spec and aggregated in vector_field.

Mechanics associates the continuous dynamics of an AbstractPlant instance with a Diffrax solver and a time step, thereby discretizing them. It includes a property _term which returns diffrax.ODETerm(self.plant.vector_field), which is used in the Diffrax solver call:

https://github.com/mlprt/feedbax/blob/147fb42f8793db55d06c8e432e4627a3679a801a/feedbax/mechanics/mechanics.py#L127C1-L135C10

Mechanics inherits solely from AbstractStagedModel[MechanicsState], and when called it:

  • updates the plant state by 1) calling the AbstractPlant instance (i.e. the kinematic operations in its model_spec), then 2) executing a step of the dynamics solver.
  • syncs the effector state with the configuration state. For example, the configuration state of a typical arm model consists of joint angles and angular velocities, whereas the effector state is typically the Cartesian state of some point on the arm, usually its endpoint. The methods that perform these calculations belong to the skeleton object.

Here is how these operations are defined in Mechanics.model_spec:

    @property
    def model_spec(self) -> OrderedDict[str, ModelStage]:
        """Specifies the stages of the model."""
        return OrderedDict(
            {
                "convert_effector_force": ModelStage(
                    callable=lambda self: self.plant.skeleton.update_state_given_effector_force,
                    where_input=lambda input, state: state.effector.force,
                    where_state=lambda state: state.plant.skeleton,
                ),
                "kinematics_update": ModelStage(
                    # the `plant` module directly implements non-ODE operations
                    callable=lambda self: self.plant,
                    where_input=lambda input, state: input,
                    where_state=lambda state: state.plant,
                ),
                "dynamics_step": ModelStage(
                    callable=lambda self: self.dynamics_step,
                    where_input=lambda input, state: input,
                    where_state=lambda state: state,
                ),
                "get_effector": ModelStage(
                    callable=lambda self: wrap_stateless_callable(
                        self.plant.skeleton.effector, pass_key=False
                    ),
                    where_input=lambda input, state: state.plant.skeleton,
                    where_state=lambda state: state.effector,
                ),
            }
        )
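To make the ordering explicit, here is a toy sketch of how a spec like this could be executed stage by stage. It is not Feedbax's AbstractStagedModel implementation: StageSketch, StateSketch and run_stages are hypothetical, and a plain field name (state_field) stands in for where_state combined with an out-of-place update such as eqx.tree_at.

```python
from collections import OrderedDict
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class StageSketch:
    func: Callable[[Any, Any], Any]         # (stage input, substate) -> new substate
    where_input: Callable[[Any, Any], Any]  # (model input, full state) -> stage input
    state_field: str                        # which field of the state gets replaced


@dataclass(frozen=True)
class StateSketch:
    plant: float
    effector: float


def run_stages(spec, input, state):
    """Apply each stage in order, updating the state out of place."""
    for stage in spec.values():
        substate = getattr(state, stage.state_field)
        new_substate = stage.func(stage.where_input(input, state), substate)
        state = replace(state, **{stage.state_field: new_substate})
    return state


spec = OrderedDict(
    dynamics_step=StageSketch(
        func=lambda u, x: x + 0.1 * u,           # toy "solver step"
        where_input=lambda input, state: input,
        state_field="plant",
    ),
    get_effector=StageSketch(
        func=lambda plant, _: 2.0 * plant,       # toy "forward kinematics"
        where_input=lambda input, state: state.plant,
        state_field="effector",
    ),
)
```

Because each stage sees the state left by the previous one, placing get_effector last means the effector is computed from the plant state after the dynamics step.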

Potential issues

Aggregation of vector fields -> single solver step

Should all the component vector fields be explicitly aggregated in AbstractPlant objects, to be associated with a single Diffrax solver step in Mechanics? It's not generally the case that the fields are coupled, and in principle they could be stepped separately.
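As a small illustration of the uncoupled case, the sketch below takes one explicit Euler step of two independent toy fields, once through an aggregated field and once separately. The fields, their constants and the dict-based state are hypothetical; a real comparison would use Diffrax solvers and the actual plant state.

```python
def euler_step(vector_field, state, dt):
    """One explicit Euler step for a state stored as a dict of floats."""
    d_state = vector_field(state)
    return {name: value + dt * d_state[name] for name, value in state.items()}


def activation_field(state):
    """Toy first-order filter with time constant 0.5 and zero input."""
    return {"activation": -state["activation"] / 0.5}


def skeleton_field(state):
    """Toy undriven point mass in one dimension."""
    return {"pos": state["vel"], "vel": 0.0}


def aggregated_field(state):
    """Aggregate both fields into one, as a single solver term would see them."""
    act = activation_field({"activation": state["activation"]})
    skel = skeleton_field({"pos": state["pos"], "vel": state["vel"]})
    return {**act, **skel}


state = {"activation": 1.0, "pos": 0.0, "vel": 2.0}
dt = 0.1

together = euler_step(aggregated_field, state, dt)
separate = {
    **euler_step(activation_field, {"activation": state["activation"]}, dt),
    **euler_step(skeleton_field, {"pos": state["pos"], "vel": state["vel"]}, dt),
}
```

Here the two results agree, so aggregation is a design choice rather than a numerical requirement. Stepping separately would presumably allow a different solver per component, at the cost of carrying several solver states.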

Sometimes aggregating the dynamics can lead to confusing situations. For example, see the admonition in the docstring for AbstractMuscle.__call__ -- it's a little conceptually weird to have to include the muscle's ActivationFilter (which models something that would really occur inside a muscle fibre) as part of the AbstractPlant, and not the AbstractMuscle. One potential solution is to have AbstractMuscle inherit from both AbstractStagedModel and AbstractDynamicalSystem, just like AbstractPlant does.

Defining and updating the effector

The loss function is generally calculated in effector (or operational) coordinates. For example, a reaching task defines its reach goals in Cartesian coordinates. Similarly, force perturbations may be applied to the effector as linear forces, which must be converted to torques. However, only the configuration state (e.g. joint angles/angular velocities) is part of the ODE describing the skeleton dynamics. It does not make sense to return derivatives for the effector state through the Diffrax solver.
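For reference, the standard way to convert a linear force at the effector into joint torques is the Jacobian transpose, tau = J^T F. The sketch below does this for a planar two-link arm; the function names and link lengths are hypothetical, and I am not claiming this is how update_state_given_effector_force is implemented.

```python
import math


def two_link_jacobian(angles, lengths):
    """Jacobian of the endpoint position with respect to the two joint angles."""
    (q1, q2), (l1, l2) = angles, lengths
    s1, c1 = math.sin(q1), math.cos(q1)
    s12, c12 = math.sin(q1 + q2), math.cos(q1 + q2)
    return [
        [-l1 * s1 - l2 * s12, -l2 * s12],
        [l1 * c1 + l2 * c12, l2 * c12],
    ]


def effector_force_to_torques(angles, lengths, force):
    """Map a linear endpoint force (fx, fy) to joint torques via tau = J^T F."""
    J = two_link_jacobian(angles, lengths)
    fx, fy = force
    return [
        J[0][0] * fx + J[1][0] * fy,
        J[0][1] * fx + J[1][1] * fy,
    ]
```

With the arm fully extended along the x axis, a force along y produces torques equal to the lever arms of each joint, while a force along x produces none.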

That is why the effector state is currently part of MechanicsState and not of AbstractSkeletonState. It may be best to change that, as all the relevant kinematic methods belong to an AbstractSkeleton, and the effector state is rightly just an alternative representation of the skeleton state specifically (assuming we won't be modeling soft tissue). Of course, AbstractSkeletonState could be separated into configuration and effector fields, and dynamics_spec could ensure that only state.configuration is passed to the skeleton's vector_field.
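A minimal sketch of that split, assuming a planar two-link arm with only joint angles in the configuration. SkeletonStateSketch, where_dynamics_state and step are hypothetical, and a plain Euler update stands in for the Diffrax solver step.

```python
import math
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class SkeletonStateSketch:
    configuration: Tuple[float, float]  # joint angles; the only part that is integrated
    effector: Tuple[float, float]       # Cartesian endpoint; derived, never integrated


def forward_kinematics(angles, lengths=(0.3, 0.4)):
    """Endpoint position of a planar two-link arm."""
    (q1, q2), (l1, l2) = angles, lengths
    x = l1 * math.cos(q1) + l2 * math.cos(q1 + q2)
    y = l1 * math.sin(q1) + l2 * math.sin(q1 + q2)
    return (x, y)


def where_dynamics_state(state):
    """Select only the part of the state that has a vector field."""
    return state.configuration


def step(state, angular_velocity, dt):
    """Integrate the configuration, then re-derive the effector from it."""
    q = where_dynamics_state(state)
    new_q = tuple(qi + dt * wi for qi, wi in zip(q, angular_velocity))
    return replace(state, configuration=new_q, effector=forward_kinematics(new_q))
```

The solver never sees state.effector, so no derivatives need to be returned for it, yet it stays alongside the configuration it is derived from.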

Once effector state is removed from MechanicsState, it will only contain two fields: plant and solver. I think it's kind of annoying that the user has to refer to state.mechanics.plant.skeleton instead of state.mechanics.skeleton, and I'd prefer if Mechanics could work as a kind of wrapper that operates on PlantState directly. However, the solver state needs to be passed from each time step to the next. The solver could be moved to PlantState, but that creates a bit of an asymmetry between the model and state hierarchies, given that numerical solution is the responsibility of Mechanics and not AbstractPlant.

Which point on the skeleton is considered the effector may vary between tasks or experiments. So far, it's always been the endpoint of an arm, or (trivially) a point mass. There may also be models with multiple effectors. One way to deal with complexity here is to 1) keep all the Cartesian states for the skeleton, and 2) use generalized target specifications (#10) so the user can define losses for arbitrary parts of those states.
