mlprt / feedbax

Optimal feedback control + interventions, in JAX.
https://docs.lprt.ca/feedbax
Apache License 2.0

Eliminate `AbstractTaskTrialSpec` #35

Closed. mlprt closed this issue 5 months ago.

mlprt commented 6 months ago

Currently, the typing in `feedbax.task` is a mess. One source of this mess is `AbstractTaskTrialSpec`.

https://github.com/mlprt/feedbax/blob/1c239e63e6ac029503050cc2cfa13c1d94e7e084/feedbax/task.py#L150-L167

Once the `target` field is specified in a generalized way, similar to `inits` (#10), the only field that will vary systematically between tasks will be `inputs`. At that point I suspect it will make more sense to stop subclassing `AbstractTaskTrialSpec` and instead define a single generic, final class, something like:

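```python
from typing import Generic, Mapping, TypeVar
from equinox import Module
from jax import Array

WhereDict = dict  # stand-in for feedbax's own WhereDict type

InputsT = TypeVar("InputsT", Module, Array)  # name must match the Generic parameter

class TaskTrialSpec(Module, Generic[InputsT]):
    inits: WhereDict
    inputs: InputsT
    targets: WhereDict
    intervene: Mapping[str, Array]
```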

This would also save developers from having to explicitly declare the `intervene` field in every subclass of `AbstractTaskTrialSpec`, since subclassing would no longer be necessary; see #20.

I am still not sure how (if at all) to explicitly associate the structure of a given type of inputs with the type of model that is compatible with a task. In principle, this could involve multiple fields of `TaskTrialSpec`; see #14.

mlprt commented 5 months ago

AbstractTaskTrialSpec has been replaced by TaskTrialSpec as of 4f735f434ededd122ddeee6957fd911a9e8870c7.