probabilists / lampe

Likelihood-free AMortized Posterior Estimation with PyTorch
https://lampe.readthedocs.io
MIT License
119 stars, 11 forks

Add differentiable coverage probability regularizer #15

Closed: macio232 closed this 3 months ago

macio232 commented 10 months ago

Resolve https://github.com/probabilists/lampe/issues/14

macio232 commented 10 months ago

@francois-rozet please have a look at this one :)

I added all the functionality expected in this PR. However, I imagine you might want some parts implemented or organized differently, so I'm keeping it as a draft. Maybe we should arrange a call to discuss it?

As already discussed, the new dependency torchsort is slightly problematic to install. To develop the PR I prepared a conda environment.yml file, but eventually, this has to be handled in setup.py. This, however, goes beyond my expertise, and I hope you can give me a hand here :)

francois-rozet commented 10 months ago

Hello @macio232, I am looking at this right now. I read the paper more carefully to understand the method.

Is it expected that the formula written in the docstrings does not correspond to Eq. (12) in your paper? I also don't understand where the $\max$ and $j$ are coming from.


In addition, could we drop the need for the sort? The rank $\alpha(p_\phi, \theta_i, x_i)$ is computed as an expectation of differentiable step functions. The index of $\alpha(p_\phi, \theta_i, x_i)$ in the batch is computed with a sort, but such an index can be seen as a rank as well. Basically,

$$\text{index}(\alpha(p_\phi, \theta_i, x_i)) = \sum_j \mathbb{1}[\alpha(p_\phi, \theta_j, x_j) < \alpha(p_\phi, \theta_i, x_i)]$$
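A quick PyTorch check of this identity, assuming the per-sample levels $\alpha(p_\phi, \theta_i, x_i)$ are already collected in a 1-D tensor `alpha`. The function names are hypothetical, not LAMPE API. With a hard comparison, the pairwise count reproduces the position obtained by sorting (ties aside), at the cost of an $N \times N$ comparison matrix.

```python
import torch


def batch_index(alpha: torch.Tensor) -> torch.Tensor:
    """Position of each alpha_i in the sorted batch, computed without a sort.

    index_i = sum_j 1[alpha_j < alpha_i]; O(N^2) memory, but only
    comparisons, which could later be relaxed into smooth steps.
    """
    # less[i, j] is True when alpha_j < alpha_i
    less = alpha[None, :] < alpha[:, None]
    return less.sum(dim=1)


def batch_index_via_sort(alpha: torch.Tensor) -> torch.Tensor:
    """Reference: the same index obtained by sorting (argsort of argsort)."""
    return torch.argsort(torch.argsort(alpha))


alpha = torch.tensor([0.7, 0.1, 0.4, 0.9])
print(batch_index(alpha))           # tensor([2, 0, 1, 3])
print(batch_index_via_sort(alpha))  # tensor([2, 0, 1, 3])
```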

This is actually mentioned in Section 3.2. Did you try this or just went for the sort? Also, did you try with a soft step function (e.g. sigmoid) instead of the STE?

Anyway, if you want to keep the sort as in your paper, what I propose is to only import torchsort in the losses and refer to their repository for installation. Like this torchsort would not be a dependency of LAMPE.
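One way to keep `torchsort` optional is to import it lazily in the loss module and fail with an installation hint only when the sort is actually used. The sketch assumes `torchsort.soft_sort(values, regularization_strength=...)` operating on a 2-D `(batch, n)` tensor, as in the torchsort README; the wrapper name `differentiable_sort` is hypothetical.

```python
import torch

try:
    import torchsort  # optional: deliberately not a LAMPE dependency
except ImportError:  # torchsort is missing or was not built
    torchsort = None


def differentiable_sort(values: torch.Tensor, strength: float = 0.1) -> torch.Tensor:
    """Soft-sort the last dimension of a 2-D tensor with torchsort.

    Raises an ImportError with installation hints if torchsort is not
    available, so the rest of the package still imports without it.
    """
    if torchsort is None:
        raise ImportError(
            "This loss requires 'torchsort' (pip install torchsort); "
            "see the torchsort repository for build requirements."
        )
    return torchsort.soft_sort(values, regularization_strength=strength)
```

Importing the package then never touches the compiled extension; only users of the sort-based losses pay the installation cost.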

francois-rozet commented 10 months ago

I just pushed a commit that moves inference components into different files, which should make it easier to read/add/maintain code. I think you can move your two losses into a new file (e.g. diffcal.py or something similar).

macio232 commented 10 months ago

1.

Is it expected that the formula written in the docstrings does not correspond to Eq. (12) in your paper?

It depends on your preference :) What I put in the docstring describes the idea behind the method (corresponding to Eq. (7) in the paper), while Eq. (12) precisely describes how this idea is implemented. I can modify the docstrings so that they match Eq. (12), but this requires introducing additional symbols and probably some explanation. Do you prefer the short but more abstract formulation, or the very precise but long?

2.

I also don't understand where the $\max$ and $j$ are coming from.

In addition, could we drop the need for the sort?

This is actually mentioned in Section 3.2. Did you try this or just went for the sort?

Short: I tried, and it didn't work well. Long: please see the OpenReview discussion. Even longer: see the Summary at the end of this post.

Also, did you try with a soft step function (e.g. sigmoid) instead of the STE?

Not in the full experiments, only as a PoC, and it didn't work well. But I can imagine that under certain circumstances it could work. In fact, a recent arXiv preprint does exactly this, if I understood correctly.

3.

Anyway, if you want to keep the sort as in your paper, what I propose is to only import torchsort in the losses and refer to their repository for installation. Like this torchsort would not be a dependency of LAMPE.

Sounds good. Where would you see the information about the additional requirement (torchsort) placed? Docstrings? README in the repo?

Another direction is to try replacing torchsort with diffsort, which is easier to set up. But it is based on sorting networks, which means it is slower. And I have no idea about the "quality" of the gradients it gives. At some point, I will try to evaluate it empirically.
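For intuition on where the extra cost of sorting networks comes from, here is a toy relaxed sorting network in plain PyTorch. It is not diffsort's API nor necessarily its exact relaxation: an odd-even transposition network whose compare-and-swap units are replaced by sigmoid-weighted mixtures. Sorting `n` elements takes `n` sequential layers of about `n/2` soft swaps each. `soft_sort_network` and `steepness` are illustrative names.

```python
import torch


def soft_sort_network(x: torch.Tensor, steepness: float = 10.0) -> torch.Tensor:
    """Relaxed odd-even transposition sort along the last dimension.

    Each compare-and-swap (a, b) -> (min, max) becomes a convex mixture
    weighted by p = sigmoid(steepness * (b - a)); as steepness grows,
    the output approaches the hard ascending sort.
    """
    n = x.shape[-1]
    for layer in range(n):  # n layers are enough for a full sort
        start = layer % 2  # alternate even and odd pairs
        a = x[..., start:n - 1:2]
        b = x[..., start + 1:n:2]
        p = torch.sigmoid(steepness * (b - a))  # ~1 if the pair is ordered
        lo = p * a + (1 - p) * b  # soft minimum
        hi = p * b + (1 - p) * a  # soft maximum
        y = x.clone()  # write into a copy so autograd keeps x intact
        y[..., start:n - 1:2] = lo
        y[..., start + 1:n:2] = hi
        x = y
    return x


x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
print(soft_sort_network(x, steepness=100.0))  # close to [1., 2., 3.]
```

Each layer depends on the previous one, so the loop cannot be parallelized across layers, which is one reason such networks tend to be slower than projection-based soft sorts.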

4. Summary

As you see, the idea of using calibration (conservativeness) error as a regularizer can be implemented in many ways. The idea of this PR is to implement what I described in the paper. I hope to see people exploring the alternatives and sharing them in the library. However, there is a question about how would you like to maintain it? a) As a single loss class that is parametrized to handle all the versions? b) Separate class for every version?

Let me briefly discuss version a) using the example of "sorting-based" vs "direct". Once we go for direct, there is the question of the levels at which the loss should be computed. I see at least four options immediately (and combinations of them):

macio232 commented 10 months ago

I just pushed a commit that moves inference components into different files, which should make it easier to read/add/maintain code. I think you can move your two losses into a new file (e.g. diffcal.py or something similar).

Btw, do you have an idea how to call the two loss classes? As a placeholder, I used Cal{NRE/NPE}Loss as done in my original code and the paper but I feel like this is not the best name. In fact, the problem is that I didn't give a clear name for the method in the paper, and now we have to solve it.

The idea of the method is to control the (Expected) Coverage Probability - how about CP{NRE/NPE}Loss/ECP{NRE/NPE}Loss/CovProb{NRE/NPE}Loss or similar?

francois-rozet commented 10 months ago

Do you prefer the short but more abstract formulation, or the very precise but long?

As you prefer. The docstring can be as long (see NRE) or short (see FMPE) as you want, it is just a means to give some context. By the way, with the new file organization, you can have a long explanation in the module docstring and a short one in the classes docstring (see NRE). Do as you think is better.

Where would you see the information about the additional requirement (torchsort) placed? Docstrings?

Preferably in the module docstring. You can even write a little "installation" section like

Installation
------------

blabla `torchsort` blabla

.. code-block:: bash

    $ pip install torchsort

The idea of this PR is to implement what I described in the paper.

I agree, this is probably the most sensible way. Discussion about improvements can come later.

However, there is a question about how would you like to maintain it? a) As a single loss class that is parametrized to handle all the versions? b) Separate class for every version?

It depends on the differences I guess. If a flag that switches between two behaviors is enough, I guess the same class is fine. If the code is completely different, I would say two classes. However, don't make your class significantly more complex to be "future proof". Let's go for something simple for now.

Btw, do you have an idea how to call the two loss classes?

Maybe DCP for differentiable coverage probability?

macio232 commented 9 months ago

@francois-rozet I applied all the modifications that you suggested.

  1. I tried building the documentation locally to check if it renders correctly, but I get the following error message

        ❯ sphinx-build . html
        Running Sphinx v7.2.6

        Configuration error:
        There is a programmable error in your configuration file:

        Traceback (most recent call last):
          File "/usr/lib/python3.11/site-packages/sphinx/config.py", line 358, in eval_config_file
            exec(code, namespace)  # NoQA: S102
            ^^^^^^^^^^^^^^^^^^^^^
          File "<...>/lampe/docs/conf.py", line 6, in <module>
            import lampe
        ModuleNotFoundError: No module named 'lampe'

     despite being in an environment where `lampe` is installed in dev mode. Tests run without any problems in the same environment.
  2. Tutorials
     a) Would you like to see tutorials in the same or separate PR?
     b) They would be a 1 to 1 copy of the NRE/NPE tutorials. Should I copy all the step-by-step descriptions or just copy the code and refer the reader to the original tutorial?
francois-rozet commented 9 months ago

I tried building the documentation locally to check if it renders correctly, but I get the following error message

Did you install the docs dependencies? Your env has Sphinx 7.2.6, which is not what we currently use.

pip install -r docs/requirements.txt

Also, what is the Python version in your env? I usually stick to 3.9, because there are still issues with 3.11.

Would you like to see tutorials in the same or separate PR?

In this PR.

They would be a 1 to 1 copy of the NRE/NPE tutorials. Should I copy all the step-by-step descriptions or just copy the code and refer the reader to the original tutorial?

I think a single tutorial (either NRE or NPE + DCP) is enough. For the explanations, the level of the FMPE tutorial should be enough.

francois-rozet commented 9 months ago

By the way, for the tests, I think it will be necessary to skip testing DCP stuff if torchsort is not installed.

P.S. I will review next week.

macio232 commented 9 months ago

Did you install the docs dependencies?

I thought so, but that was not the case. Now it works fine.

By the way, for the tests, I think it will be necessary to skip testing DCP stuff if torchsort is not installed.

Indeed. The tests don't use backprop so we could fall back to torch's default sorting. What do you think?

francois-rozet commented 9 months ago

The tests should try to back-prop through the loss. I think using pytest.mark.skipif would be more appropriate.
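A sketch of such a test, assuming pytest and `importlib.util.find_spec` to detect torchsort. Since the DCP losses' constructor arguments are not shown in this thread, the body back-propagates through `torchsort.soft_rank` as a stand-in for the loss. `requires_torchsort` and `test_soft_rank_backprop` are hypothetical names; `pytest.importorskip("torchsort")` is an alternative that skips the whole module at once.

```python
import importlib.util

import pytest
import torch

# True if torchsort can be imported in this environment
HAS_TORCHSORT = importlib.util.find_spec("torchsort") is not None

# Reusable marker: skip (rather than fail) when torchsort is absent
requires_torchsort = pytest.mark.skipif(
    not HAS_TORCHSORT, reason="torchsort is not installed"
)


@requires_torchsort
def test_soft_rank_backprop():
    import torchsort

    x = torch.randn(4, 8, requires_grad=True)
    ranks = torchsort.soft_rank(x, regularization_strength=0.1)
    # Back-propagate through the differentiable rank, as the loss would
    ranks.pow(2).sum().backward()
    assert x.grad is not None
    assert torch.isfinite(x.grad).all()
```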

macio232 commented 5 months ago

P.S. I will review next week.

Long week it has been :) Do you want me to prepare the notebook(s) before your review?

francois-rozet commented 5 months ago

Ah sorry @macio232, my comment was not clear, I was waiting for you to finish the PR before I would review it. Do you want me to do a first pass before you add tests and tutorials?

francois-rozet commented 3 months ago

Hello @macio232, I looked at your PR and it seems good to me. Thank you for including the installation instructions.

I have rebased the PR to the master branch. Some tests were failing and I fixed them, but there is also a bug on the master branch due to the release of NumPy 2.0... I'll fix that bug then merge your PR.

francois-rozet commented 3 months ago

I tried the DCP losses by replacing the normal losses (e.g. NPELoss) with the DCP ones (e.g. DCPNPELoss) and they seem to work as expected :+1: They are obviously more expensive, but that is not really surprising.

I will merge this PR now.