maho3 / ltu-ili

Robust ML in Astro
https://ltu-ili.readthedocs.io/en/latest/
43 stars 8 forks

Make jupyter interface #74

Closed maho3 closed 1 year ago

maho3 commented 1 year ago

This change adds a tutorial Jupyter notebook demonstrating an IPython interface to ltu-ili. It also includes several "quality-of-life" improvements to the ltu-ili interface to make it easier to configure in Jupyter.

This should be merged only after the SNLE PR #54.

DeaglanBartlett commented 1 year ago

Nice work! A few warnings are raised when I run the tutorial notebook, and there is some potentially undesirable behaviour, so it would be nice to address these:

  1. When doing `posterior, summaries = runner(loader=loader)`, the warning

    /home/bartlett/anaconda3/envs/ili-sbi/lib/python3.10/site-packages/sbi/utils/posterior_ensemble.py:142: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
    self._weights = torch.tensor(weights) / sum(weights)

    is raised. This is an issue with the sbi package, so there is nothing the user can do about it, but it would be nice if you could wrap the offending call with something that suppresses this warning.

  2. When generating samples for a random input in the SNPE example, I get

    /home/bartlett/anaconda3/envs/ili-sbi/lib/python3.10/site-packages/nflows/transforms/lu.py:80: UserWarning: torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangularand will be removed in a future PyTorch release.
    torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
    X = torch.triangular_solve(B, A).solution
    should be replaced with
    X = torch.linalg.solve_triangular(A, B). (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2189.)
    outputs, _ = torch.triangular_solve(

    Please can this be addressed or the warning caught and not printed?

  3. When using ltu-ili's validation metrics, one gets

    /home/bartlett/anaconda3/envs/ili-sbi/lib/python3.10/site-packages/seaborn/_oldcore.py:1498: FutureWarning: is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
    if pd.api.types.is_categorical_dtype(vector):

    and

    /home/bartlett/anaconda3/envs/ili-sbi/lib/python3.10/site-packages/seaborn/_oldcore.py:1119: FutureWarning: use_inf_as_na option is deprecated and will be removed in a future version. Convert inf values to NaN before operating instead.
    with pd.option_context('mode.use_inf_as_na', True):

    and

    /home/bartlett/anaconda3/envs/ili-sbi/lib/python3.10/site-packages/seaborn/axisgrid.py:118: UserWarning: The figure layout has changed to tight
    self._figure.tight_layout(*args, **kwargs)

    so please address these or suppress the warnings.

  4. For SNLE, in the section which generates samples for a random input, we get

    /home/bartlett/anaconda3/envs/ili-sbi/lib/python3.10/site-packages/sbi/inference/posteriors/mcmc_posterior.py:172: UserWarning: `.log_prob()` is deprecated for methods that can only evaluate the
            log-probability up to a normalizing constant. Use `.potential()` instead.
  5. The cell beginning with

    # calculate and plot the rank statistics to describe univariate posterior coverage
    # (note, with MCMC this takes a while)

    takes a very long time to run (~30 minutes). This seems far too long for a tutorial. Is there any way this can be sped up? The same is true for TARP.

  6. When downloading the CAMELS data, you need to already have a directory called "toy" in the same directory as the notebook, but this is not guaranteed. So either a call to mkdir is required before downloading the data, or the text should state that the user needs to ensure this directory exists.

  7. When training the model for CAMELS, there is a device mismatch between the model and the data

    /home/bartlett/anaconda3/envs/ili-sbi/lib/python3.10/site-packages/sbi/utils/user_input_checks.py:435: UserWarning: Mismatch between the device of the data fed to the embedding_net and the device of the embedding_net's weights. Fed data has device 'cpu' vs embedding_net weights have device 'cuda:0'. Automatically switching the embedding_net's device to 'cpu', which could otherwise be done manually using the line `embedding_net.to('cpu')`.
    warnings.warn(

    It would be nice if this were handled correctly without relying on the automatic switching.
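For the points above where the warning comes from a third-party package, the wrap-and-suppress approach could look something like this minimal sketch (the `call_quietly` helper, the `noisy` stand-in, and the message pattern are all illustrative, not part of ltu-ili or sbi):

```python
import warnings

def call_quietly(fn, *args, category=UserWarning, message="", **kwargs):
    # Run fn while suppressing warnings whose message starts with `message`.
    # The filter only applies inside the `with` block, so unrelated warnings
    # elsewhere in the notebook still surface normally.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=message, category=category)
        return fn(*args, **kwargs)

def noisy():
    # Stand-in for an sbi call that emits the tensor-copy UserWarning.
    warnings.warn("To copy construct from a tensor, ...", UserWarning)
    return 42

result = call_quietly(noisy, message="To copy construct")
```

The same pattern would work around the nflows `triangular_solve` deprecation and the seaborn `FutureWarning`s, at the cost of sprinkling context managers through the code.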

CompiledAtBirth commented 1 year ago

I have to double-check point 5 too. Yesterday I could not generate more plots for the slides because TARP was taking forever when I tried to rerun my previous examples. I assumed I had made a mistake somewhere.

maho3 commented 1 year ago

Addressing Deaglan's comments:

  1. Fixed with warning catching.
  2. This deprecation warning is caused by nflows not keeping pace with modern pytorch. It will hopefully be fixed upstream by the nflows people, but we can't address it in the tutorial without either wrapping sbi's DirectPosterior in a new class or wrapping the `samples = posterior.sample((1000,), torch.Tensor(x[ind]).to(device))` line in the tutorial with a warning filter. Both seem like ugly fixes for now, so we may just need to live with this one. Thoughts? I have added warning catching [CITE] for the self-implemented DirectSampler class.
  3. Fixed these with warning catching around the seaborn pairplot. This is a bug in matplotlib and will be fixed in future versions.
  4. Not yet done...
  5. This is a bigger problem which has been switched to Issue #77 for future work.
  6. Fixed this with os.makedirs
  7. Fixed this. It turns out everything is initialized on the CPU and then moved to the device specified in the SN** instantiation, so one need not move the embedding_net to the GPU anyway. :)
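For reference, the point-6 fix amounts to an idempotent directory creation before the download (the "toy" path name is the one the tutorial uses):

```python
import os

# Create the download target if it doesn't already exist; exist_ok makes this
# a no-op when the directory is present, so the cell is safe to re-run.
os.makedirs("toy", exist_ok=True)
```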

Everything also needs a full re-run to make things clean for the public-facing version. To be finished...

DeaglanBartlett commented 1 year ago

Nice! I agree that both options for (2) seem a bit ugly, so we should probably just leave that one as it is.

maho3 commented 1 year ago

Regarding @DeaglanBartlett's comment 4, we might need to live with that one too. Basically, it flags that warning when you call log_prob on an SNLE or SNRE posterior, but it outputs the same thing as potential anyway. However, sbi's potential function for NeuralPosteriorEnsemble has a bug wherein it works on the CPU backend but not on the GPU backend. So we again have the choice of either wrapping the .potential call in a warnings catch or just using log_prob and dealing with those (informative) warnings. I'm in the camp of the latter, but open to suggestions.

If we're okay with this ^ , then I have rerun the whole tutorial notebook and made it nice and clean for public viewing. After a final check, we should be able to merge.

maho3 commented 1 year ago

Actually, I had the thought that instead of suppressing warnings in the source code, maybe we should just 'ignore' them all at the head of the tutorial? That way, people can choose to look at them or not, if they want an idea of what's under the hood. Thoughts?

DeaglanBartlett commented 1 year ago

I agree that if the potential function has that bug, then we should continue with log_prob and live with the warning for now. I also think that suppressing all warnings at the top of the notebook seems sensible to make the example cleaner for the user. So I'm happy for you to rerun the notebook with the warnings suppressed and then I'll merge the changes.

maho3 commented 1 year ago

Done! We can use this PR as a record for when those warnings eventually turn into errors :)