stanford-iprl-lab / torchfilter

Bayesian filters in PyTorch
https://stanford-iprl-lab.github.io/torchfilter
MIT License

Allowing `forward_loop()` to return more information #5

Open · aravindbattaje opened this issue 3 years ago

aravindbattaje commented 3 years ago

Currently, forward_loop() (https://github.com/stanford-iprl-lab/torchfilter/blob/81f41224df38dbd0c45512c07854251285533134/torchfilter/base/_filter.py) returns only state_predictions for the given time sequences of controls and observations. This makes it hard, for example, to use the state covariances after forward_loop() has been called.

I suggest that torchfilter.base.Filter enforce returning two values from all forward calls:

  1. State predictions (as it does already)
  2. Additional state information, as a list or a list of tuples

Item 2 could be (possibly a sequence of) state covariances for Kalman filters, or perhaps even third-order moments or something similar for particle filters.
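To make the proposal concrete, here is a minimal pure-Python sketch of a filter whose forward_loop() returns two values: the per-timestep state predictions and the matching per-timestep covariances. ScalarKalmanFilter and its parameters (a, b, h, q, r) are hypothetical and not part of torchfilter; scalar floats stand in for batched tensors to keep it self-contained.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class ScalarKalmanFilter:
    """Hypothetical 1-D linear Kalman filter (not torchfilter's API).

    Dynamics:    x_t = a * x_{t-1} + b * u_t + w,  w ~ N(0, q)
    Observation: z_t = h * x_t + v,                v ~ N(0, r)
    """

    a: float = 1.0
    b: float = 1.0
    h: float = 1.0
    q: float = 0.1
    r: float = 0.5
    belief_mean: float = 0.0
    belief_covariance: float = 1.0

    def forward(self, observation: float, control: float) -> float:
        # Predict step: propagate mean and covariance through the dynamics
        mean = self.a * self.belief_mean + self.b * control
        cov = self.a * self.a * self.belief_covariance + self.q
        # Update step: correct with the observation via the Kalman gain
        innovation_cov = self.h * self.h * cov + self.r
        gain = cov * self.h / innovation_cov
        self.belief_mean = mean + gain * (observation - self.h * mean)
        self.belief_covariance = (1.0 - gain * self.h) * cov
        return self.belief_mean

    def forward_loop(
        self, observations: Sequence[float], controls: Sequence[float]
    ) -> Tuple[List[float], List[float]]:
        # Proposed shape: (state_predictions, additional info), one entry per step
        predictions: List[float] = []
        covariances: List[float] = []
        for z, u in zip(observations, controls):
            predictions.append(self.forward(z, u))
            covariances.append(self.belief_covariance)
        return predictions, covariances


kf = ScalarKalmanFilter()
preds, covs = kf.forward_loop(observations=[1.0] * 5, controls=[0.0] * 5)
```

Returning a tuple keeps the first value backward compatible in spirit (it is still the prediction sequence), while the second slot can hold whatever extra belief information a given filter type naturally tracks.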

For now, I have hacked the filters I'm using to return state covariances, but I think a clean solution would be great.

brentyi commented 3 years ago

This seems like a reasonable suggestion -- thanks for the note @aravindbattaje!

My guess is that a hook might also work for extracting intermediate filter parameters. Roughly:

# Create list of belief covariances; initialize with the first one
covariances = [kalman_filter.belief_covariance]

# Create hook for extracting covariance after each call to forward()
hook = kalman_filter.register_forward_hook(
    lambda self, inputs, outputs: covariances.append(self.belief_covariance)
)

# Forward loop over observations/controls
kalman_filter.forward_loop(...)

# Remove hook when done
hook.remove()

# Do something with covariances
print(covariances)

Any thoughts? Perhaps this is too complex for something as simple as grabbing posterior covariances?
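One caveat worth checking with the hook approach above: PyTorch forward hooks fire only when a module is invoked through __call__ (i.e. module(...)), not when module.forward(...) is called directly. The hook-based approach therefore relies on forward_loop() calling the filter as self(...), which should be confirmed against the torchfilter source. The ToyFilter module below is hypothetical and only demonstrates this PyTorch behavior.

```python
import torch
import torch.nn as nn


class ToyFilter(nn.Module):
    """Hypothetical stand-in for a filter that stores its belief as an attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.belief_covariance = torch.ones(1)

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        # Pretend update: halve the covariance on every step
        self.belief_covariance = self.belief_covariance * 0.5
        return observation


toy = ToyFilter()
covariances = [toy.belief_covariance]
handle = toy.register_forward_hook(
    lambda module, inputs, outputs: covariances.append(module.belief_covariance)
)

x = torch.zeros(1)
toy(x)          # goes through __call__, so the hook fires
toy.forward(x)  # calls forward() directly, so the hook does NOT fire
handle.remove()
toy(x)          # hook removed, nothing is appended

print(len(covariances))  # 2
```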

For adding a second return argument, were you thinking of having forward_loop return a list of outputs from forward?

I've been toying with the idea of a proper v1 of this library that reduces the number of dictionaries being passed around and better supports things like manifolds/Lie groups, so I'm very open to suggestions that would break the current API.

aravindbattaje commented 3 years ago

Whoops, sorry for the delayed response. This somehow skipped my notifications.

Hooks can indeed be very useful for introspecting or extracting data not normally associated with the filter, but, as you said, I think it might be better to also return covariances (or even other moments), since they are central to what a filter does.

> For adding a second return argument, were you thinking of having forward_loop return a list of outputs from forward?

Yes, I think that would be the most straightforward way. That's how I hacked it, so that training (at all stages) can be done with an NLL loss.
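With per-timestep means and variances returned from forward_loop(), a Gaussian negative log-likelihood can be computed over every timestep rather than only on point predictions. A minimal scalar sketch, assuming 1-D states; gaussian_nll is a hypothetical helper, and batched multivariate states would need the multivariate form (log-determinant plus Mahalanobis term).

```python
import math
from typing import Sequence


def gaussian_nll(
    means: Sequence[float], variances: Sequence[float], targets: Sequence[float]
) -> float:
    """Average NLL of scalar targets under N(mean, variance), one term per timestep."""
    if not (len(means) == len(variances) == len(targets)):
        raise ValueError("means, variances and targets must have equal length")
    total = 0.0
    for mu, var, x in zip(means, variances, targets):
        # -log N(x; mu, var) = 0.5 * (log(2*pi*var) + (x - mu)^2 / var)
        total += 0.5 * (math.log(2.0 * math.pi * var) + (x - mu) ** 2 / var)
    return total / len(means)
```

In PyTorch itself, torch.nn.GaussianNLLLoss covers the diagonal-variance case; note that by default it omits the constant log(2*pi) term unless full=True is passed.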

I really like the idea of a v1, and I'd be happy to contribute to the next version too. In addition to Lie groups, we are working with multiple connected filters and multiple observation models per filter. Some API changes may be useful to facilitate that.
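For multiple observation models per filter, one common approach when the observation noises are independent is to apply the Kalman update once per model in sequence, which yields the same posterior as a single joint update. A scalar sketch with hypothetical helper names, not tied to torchfilter's API; each measurement is a (z, h, r) tuple of observation, observation gain, and noise variance.

```python
from typing import List, Tuple

Measurement = Tuple[float, float, float]  # (z, h, r)


def kalman_update(mean: float, cov: float, z: float, h: float, r: float) -> Tuple[float, float]:
    # Standard scalar Kalman measurement update for one observation model
    s = h * h * cov + r
    k = cov * h / s
    return mean + k * (z - h * mean), (1.0 - k * h) * cov


def sequential_update(mean: float, cov: float, measurements: List[Measurement]) -> Tuple[float, float]:
    # Apply each observation model's update one after another
    for z, h, r in measurements:
        mean, cov = kalman_update(mean, cov, z, h, r)
    return mean, cov


def information_fusion(mean: float, cov: float, measurements: List[Measurement]) -> Tuple[float, float]:
    # Joint update in information form: precisions and information vectors add
    precision = 1.0 / cov + sum(h * h / r for _, h, r in measurements)
    info = mean / cov + sum(h * z / r for z, h, r in measurements)
    post_cov = 1.0 / precision
    return post_cov * info, post_cov
```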