wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License

[FR] Ability to draw cheap AR samples via a subset-condition-predict procedure in `nps.ar_predict` #10

Open tom-andersson opened 1 year ago

tom-andersson commented 1 year ago

In our AR CNP paper at ICLR, we describe some ways to make AR sampling cheaper through a subset-AR sample-condition-predict procedure. For example, from Appendix K:

The AR samples were drawn on a sparse 70x70 grid spanning the entire input space to save compute time. The ConvCNP model was then conditioned on these AR samples and the predictive mean was computed over the dense 280x280 target space.

This is particularly useful in environmental applications with large, dense target sets, where it avoids both a prohibitive computational cost and an out-of-distribution context-set density during AR sampling. The subsetting factor is a lever you can tune to trade off the detail captured by the resulting samples against the computational cost (while avoiding an OOD sample density).

Currently, I implement this functionality manually by wrapping `nps.ar_predict`. However, it would be great to have it supported under the hood in `nps.ar_predict` to remove the burden on the user for accessing this useful procedure. This could be handled with a new `ar_subset` kwarg (terminology open to debate), which, if not None, turns on the subsetting procedure and defines the target-set subsetting factor for drawing the AR samples (which are then added to the context set and conditioned on for the final forward pass). Moreover, if `xt` is a tuple of gridded coordinates, the subsetting should be applied along each dimension (i.e. before flattening the tuple into a tensor of target locations) to ensure that the AR sample locations are spread uniformly in input space.
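For concreteness, a rough sketch of such a wrapper (not the actual implementation in either package) might look as follows. It assumes the `(mean, var, noiseless, noisy)` return convention of `nps.ar_predict`/`nps.predict`, a single context set, and a plain target tensor; the name `ar_subset_predict`, the `subset_factor` argument, and the shape conventions in the comments are illustrative assumptions only.

```python
import torch

import neuralprocesses.torch as nps


def ar_subset_predict(model, contexts, xt_fine, subset_factor=4, num_samples=25):
    """Subset-condition-predict around `nps.ar_predict`.

    Draw AR samples on a coarse subset of `xt_fine`, condition the model on
    them, and predict over the full dense target set.

    Assumes a single context set `contexts = [(xc, yc)]` and a target tensor
    `xt_fine` of shape `(batch, dim_x, n_target)`; gridded (tuple) targets,
    target batching, and `nps.Aggregate` targets are not handled.
    """
    # (1) Subset the dense targets to the coarse AR sampling locations.
    xt_coarse = xt_fine[..., ::subset_factor]

    # (2) Draw (noisy) AR samples at the coarse locations only. The four-tuple
    #     return convention `(mean, var, ft, yt)` is assumed here.
    _, _, _, yt_coarse = nps.ar_predict(
        model, contexts, xt_coarse, num_samples=num_samples
    )

    xc, yc = contexts[0]
    means = []
    for i in range(num_samples):
        # (3) Condition on the i-th AR sample by appending it to the context set.
        contexts_ar = [
            (
                torch.cat((xc, xt_coarse), dim=-1),
                torch.cat((yc, yt_coarse[i]), dim=-1),
            )
        ]
        # (4) A final non-AR predictive pass over the dense targets; only the
        #     mean is kept for each AR sample.
        mean, _, _, _ = nps.predict(model, contexts_ar, xt_fine)
        means.append(mean)

    # Per-sample predictive means over the dense targets,
    # shape `(num_samples, batch, dim_y, n_target)`.
    return torch.stack(means, dim=0)
```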

As always, I could give this a go myself with some hand holding but can't guarantee when :-(. I'd also be a bit worried about handling target batching and `nps.Aggregate` cases.

tom-andersson commented 1 year ago

P.S. It would also be great to expose the N/K cost-reduction approach described in the 'Consistency and the AR design space' paragraph.

wesselb commented 1 year ago

Hey @tom-andersson! Thanks for opening an issue. :) You're totally right. This would be a really good addition.

This might actually be pretty simple to support. In particular, to get the noiseless samples right (`ft` in the code), I believe that `xt` here would just have to be swapped out for `xt_fine` (so not the subset), and that would be all! And then there would need to be some additional logic to get `mean`, `var`, and `yt` right too.
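For reference, under the interface proposed above, usage might look roughly like the following; `ar_subset` is the hypothetical keyword floated in this issue, not an existing argument of `nps.ar_predict`.

```python
# Hypothetical interface sketch: `ar_subset` is the proposed kwarg and does not
# exist in `nps.ar_predict` today. A value of 4 would run the AR rollout on
# every 4th target location (per dimension for gridded targets), condition on
# those samples, and then do one final forward pass over the full dense `xt`.
mean, var, ft, yt = nps.ar_predict(
    model,
    contexts,
    xt,  # dense targets: a tensor, or a tuple of per-dimension grid coordinates
    num_samples=25,
    ar_subset=4,
)
```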

tom-andersson commented 10 months ago

This subset-condition-predict feature is now supported in DeepSensor (my package that wraps neuralprocesses and adds functionality for environmental sciences). It may still be useful to support this directly in neuralprocesses, however!

wesselb commented 10 months ago

@tom-andersson That's an awesome feature in deepsensor! And beautiful documentation too. :)