sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

hardcoded transfer of data and model to cpu triggers UserWarning #1161

Closed: psteinb closed this issue 6 days ago

psteinb commented 1 month ago

Describe the bug
For SNPE_C, and likely all SNPE posteriors, the train function performs a hardcoded transfer of all data to the "cpu" device in order to build the flow (by calling self._build_neural_net). This triggers a UserWarning from sbi/utils/user_input_checks.py:444:

sbi/utils/user_input_checks.py:444: UserWarning: Mismatch between the device of the data fed to the embedding_net and the device of the embedding_net's weights. Fed data has device 'cpu' vs embedding_net weights have device 'cuda:0'.
Automatically switching the embedding_net's device to 'cpu', which could otherwise be done manually using the line `embedding_net.to('cpu')`.

After reading the code and stepping through the train function in a debugger, I found that this warning is not honored during training.
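To make the mechanism concrete, here is a minimal sketch in plain PyTorch of a check like the one the warning describes: compare the device of the incoming data with the device of the embedding net's first parameter, warn on a mismatch, and move the net to the data's device. The function name `check_embedding_net_device` is hypothetical; this is an illustration of the idea, not sbi's actual code in user_input_checks.py.

```python
import warnings

import torch
from torch import nn


def check_embedding_net_device(embedding_net: nn.Module, data: torch.Tensor) -> nn.Module:
    """Hypothetical re-implementation of the device check described by the warning.

    Not sbi's implementation; assumes the net has at least one parameter.
    """
    # Use the first parameter as a proxy for the device of all weights.
    weights_device = next(embedding_net.parameters()).device
    if data.device != weights_device:
        warnings.warn(
            "Mismatch between the device of the data fed to the embedding_net "
            "and the device of the embedding_net's weights. Fed data has device "
            f"'{data.device}' vs embedding_net weights have device '{weights_device}'. "
            f"Automatically switching the embedding_net's device to '{data.device}'.",
            UserWarning,
        )
        # nn.Module.to moves parameters in place and returns the module itself.
        embedding_net.to(data.device)
    return embedding_net
```

With an embedding net on 'cuda:0' and data that was forced onto 'cpu' beforehand, a check of this kind fires on every build, even though the user placed everything on the GPU on purpose.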

To Reproduce

  1. Python 3.12.3; occurs since sbi 0.22.0
  2. Code as in https://sbi-dev.github.io/sbi/tutorial/05_embedding_net/, but with everything sent to a GPU
  3. Error message: see above

Expected behavior
I'd expect the UserWarning not to be triggered.

Additional context
It is unclear to me why the hardcoded transfer to the CPU device was put there. In 0.22.0, the call that triggers the UserWarning is at sbi/neural_nets/flow.py:341. Perhaps this problem is resolved in main due to the switch to zuko?
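For comparison, a minimal sketch of a device-consistent alternative to a hardcoded `.to("cpu")`: move the data to whatever device the embedding net already lives on before building the network, so a device check never sees a mismatch. The helpers `module_device` and `move_to_module_device` are hypothetical and not part of sbi, and whether sbi's build step can run entirely on the GPU is an assumption not verified here.

```python
import torch
from torch import nn


def module_device(module: nn.Module) -> torch.device:
    """Device of a module's first parameter (assumes the module has parameters)."""
    return next(module.parameters()).device


def move_to_module_device(module: nn.Module, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    """Hypothetical helper: move tensors to the module's device instead of a fixed "cpu"."""
    device = module_device(module)
    # Tensor.to returns the same tensor if it is already on the target device.
    return tuple(t.to(device) for t in tensors)
```

Calling something like `theta, x = move_to_module_device(embedding_net, theta, x)` before building keeps data and embedding net on the same device, whether that is 'cpu' or 'cuda:0'.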