blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Enable filtering of AdaptationInfo #674

Closed: andrewdipper closed this 4 months ago

andrewdipper commented 4 months ago

This PR allows AdaptationInfo to be filtered during the window_adaptation process. Always storing all of the auxiliary information can lead to substantial excess memory usage.
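To make the memory argument concrete, here is a minimal pure-Python sketch. It assumes, as the description implies, that AdaptationInfo bundles the sampler state, the sampler info and the adaptation state recorded at each step; the NamedTuple types and field names below (HMCState, WindowState, position, step_size, etc.) are hypothetical stand-ins, not BlackJAX's actual types. Filtering replaces every field outside a chosen key set with None, so the stored record only carries what the user asked for.

```python
from typing import AbstractSet, NamedTuple


class AdaptationInfo(NamedTuple):
    # Hypothetical container for the three parts of each adaptation record.
    state: tuple
    info: tuple
    adaptation_state: tuple


class HMCState(NamedTuple):  # hypothetical sampler state
    position: list
    logdensity: float


class HMCInfo(NamedTuple):  # hypothetical per-step sampler info
    acceptance_rate: float
    num_integration_steps: int


class WindowState(NamedTuple):  # hypothetical adaptation state
    step_size: float
    inverse_mass_matrix: list


def filter_namedtuple(tup: tuple, keep: AbstractSet[str]) -> tuple:
    """Return a copy of `tup` with every field not in `keep` set to None."""
    return type(tup)(*(getattr(tup, name) if name in keep else None for name in tup._fields))


record = AdaptationInfo(
    HMCState(position=[0.1] * 1000, logdensity=-3.2),
    HMCInfo(acceptance_rate=0.81, num_integration_steps=7),
    WindowState(step_size=0.05, inverse_mass_matrix=[1.0] * 1000),
)

# Keep only the acceptance rate and the step size; drop the large arrays.
filtered = AdaptationInfo(
    filter_namedtuple(record.state, set()),
    filter_namedtuple(record.info, {"acceptance_rate"}),
    filter_namedtuple(record.adaptation_state, {"step_size"}),
)
print(filtered)
```

In a real JAX setting the dropped fields would be large arrays stacked across all adaptation steps, which is where the memory savings come from.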

In response to: https://github.com/blackjax-devs/blackjax/issues/667

I put some initial changes here, but it seems like this change also applies to the other adaptation algorithms. Perhaps it's worth moving return_all_adapt_info and get_filter_adapt_info_fn to adaptation/base.py and enabling filtering for the other adaptation algorithms. But I thought I'd get some perspectives before going down that path.
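A hedged sketch of what moving the two helpers into a shared module could look like: return_all_adapt_info keeps everything (the default), get_filter_adapt_info_fn builds a function that keeps only named fields, and each adaptation loop simply calls whichever function it is given. The helper names come from this thread, but their signatures here are assumptions; run_adaptation, the adaptation_info_fn parameter name and the Toy* types are hypothetical.

```python
from typing import AbstractSet, Callable, List, NamedTuple


class AdaptationInfo(NamedTuple):
    state: tuple
    info: tuple
    adaptation_state: tuple


class ToyState(NamedTuple):  # hypothetical sampler state
    position: float
    logdensity: float


class ToyInfo(NamedTuple):  # hypothetical per-step sampler info
    acceptance_rate: float


class ToyAdaptState(NamedTuple):  # hypothetical adaptation state
    step_size: float


def return_all_adapt_info(state, info, adaptation_state) -> AdaptationInfo:
    """Default behaviour: keep all auxiliary information."""
    return AdaptationInfo(state, info, adaptation_state)


def get_filter_adapt_info_fn(
    state_keys: AbstractSet[str] = frozenset(),
    info_keys: AbstractSet[str] = frozenset(),
    adapt_state_keys: AbstractSet[str] = frozenset(),
) -> Callable:
    """Build an info function that keeps only the named fields (others become None)."""

    def keep(tup, keys):
        return type(tup)(*(getattr(tup, f) if f in keys else None for f in tup._fields))

    def filter_fn(state, info, adaptation_state) -> AdaptationInfo:
        return AdaptationInfo(
            keep(state, state_keys),
            keep(info, info_keys),
            keep(adaptation_state, adapt_state_keys),
        )

    return filter_fn


def run_adaptation(
    num_steps: int, adaptation_info_fn: Callable = return_all_adapt_info
) -> List[AdaptationInfo]:
    """Hypothetical adaptation loop: it only calls `adaptation_info_fn`,
    so the filtering logic lives in one shared place."""
    trace = []
    step_size = 1.0
    for i in range(num_steps):
        state = ToyState(position=float(i), logdensity=-0.5 * i * i)
        info = ToyInfo(acceptance_rate=0.8)
        step_size *= 0.9  # stand-in for a real step-size update
        trace.append(adaptation_info_fn(state, info, ToyAdaptState(step_size)))
    return trace


full = run_adaptation(3)
slim = run_adaptation(3, get_filter_adapt_info_fn(adapt_state_keys={"step_size"}))
print(full[-1])
print(slim[-1])
```

Passing the info function as a parameter, rather than adding per-algorithm flags, is what would let every adaptation algorithm share the same two helpers.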

Let me know any thoughts / changes

junpenglao commented 4 months ago

> I put some initial changes here, but it seems like this change also applies to the other adaptation algorithms. Perhaps it's worth moving return_all_adapt_info and get_filter_adapt_info_fn to adaptation/base.py and enabling filtering for the other adaptation algorithms. But I thought I'd get some perspectives before going down that path.

+1, we should apply this to other adaptation algorithms.

andrewdipper commented 4 months ago

I added AdaptationInfo filtering to the other adaptation algorithms. The second test, which is quite similar to the first, checks that the filtering works generally across the other adaptation algorithms.
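The PR's actual tests are not reproduced here. As a hedged illustration of the idea behind the second test, the sketch below checks that a single filter factory works unchanged on two hypothetical algorithms whose state, info and adaptation-state types differ. All type names, field names and check_filtering are made up for the example, and get_filter_adapt_info_fn is a hypothetical re-implementation, not BlackJAX's code.

```python
from typing import AbstractSet, NamedTuple


class AdaptationInfo(NamedTuple):
    state: tuple
    info: tuple
    adaptation_state: tuple


def get_filter_adapt_info_fn(
    state_keys: AbstractSet[str] = frozenset(),
    info_keys: AbstractSet[str] = frozenset(),
    adapt_state_keys: AbstractSet[str] = frozenset(),
):
    """Hypothetical filter factory: fields outside the key sets become None."""

    def keep(tup, keys):
        return type(tup)(*(getattr(tup, f) if f in keys else None for f in tup._fields))

    def filter_fn(state, info, adaptation_state):
        return AdaptationInfo(
            keep(state, state_keys),
            keep(info, info_keys),
            keep(adaptation_state, adapt_state_keys),
        )

    return filter_fn


# Two hypothetical algorithms whose state, info and adaptation-state types differ.
class StateA(NamedTuple):
    position: float
    logdensity: float


class InfoA(NamedTuple):
    acceptance_rate: float


class AdaptStateA(NamedTuple):
    step_size: float
    inverse_mass_matrix: tuple


class StateB(NamedTuple):
    position: float
    momentum: float


class InfoB(NamedTuple):
    energy_change: float
    kinetic_change: float


class AdaptStateB(NamedTuple):
    trajectory_length: float
    step_size: float


def check_filtering(state, info, adaptation_state) -> AdaptationInfo:
    """Keep the first field of each component; every other field must be None."""
    parts = (state, info, adaptation_state)
    keys = tuple(part._fields[0] for part in parts)
    out = get_filter_adapt_info_fn({keys[0]}, {keys[1]}, {keys[2]})(*parts)
    for original, filtered, key in zip(parts, out, keys):
        assert type(filtered) is type(original)  # container types are preserved
        for field in original._fields:
            expected = getattr(original, field) if field == key else None
            assert getattr(filtered, field) == expected
    return out


out_a = check_filtering(StateA(1.0, -0.5), InfoA(0.9), AdaptStateA(0.1, (1.0, 1.0)))
out_b = check_filtering(StateB(2.0, 0.3), InfoB(0.01, 0.02), AdaptStateB(3.0, 0.2))
print("filtering works for both algorithm types")
```

Because the filter only relies on each component being a NamedTuple with named fields, the same check passes for both algorithms without any algorithm-specific code.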