sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Circular import in `from sbi.neural_nets import posterior_nn` #1249

Open: jroulet opened this issue 2 months ago

jroulet commented 2 months ago

Describe the bug: Hello, I get an `ImportError` when trying to import `posterior_nn` from `sbi.neural_nets`.

To reproduce (sbi version 0.23.1):

>>> from sbi.neural_nets import posterior_nn
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../sbi/neural_nets/__init__.py", line 1, in <module>
    from sbi.neural_nets.factory import (
  File ".../sbi/neural_nets/factory.py", line 9, in <module>
    from sbi.neural_nets.net_builders.classifier import (
  File ".../sbi/neural_nets/net_builders/__init__.py", line 1, in <module>
    from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator
  File ".../sbi/neural_nets/net_builders/categorial.py", line 9, in <module>
    from sbi.utils.nn_utils import get_numel
  File ".../sbi/utils/__init__.py", line 72, in <module>
    from sbi.utils.get_nn_models import posterior_nn, likelihood_nn, classifier_nn
  File ".../sbi/utils/get_nn_models.py", line 9, in <module>
    from sbi.neural_nets.factory import classifier_nn as classifier_nn_moved_to_neural_nets
ImportError: cannot import name 'classifier_nn' from partially initialized module 'sbi.neural_nets.factory' (most likely due to a circular import) (.../sbi/neural_nets/factory.py)

Expected behavior: no error.

Additional context: I can work around this by importing `sbi.utils` first.
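The cycle in the traceback can be reproduced without sbi. The sketch below builds a throwaway package in a temporary directory; the package name `toypkg` and all of its modules and functions are hypothetical stand-ins that only mirror the chain `neural_nets` → `factory` → `utils` → `get_nn_models` → `factory` shown above. Each import order runs in a fresh interpreter so that one attempt cannot leave modules cached for the other.

```python
"""Hypothetical, sbi-free reproduction of the import cycle in the traceback.

`toypkg` is an invented stand-in; its layout only mirrors the chain
neural_nets -> factory -> utils -> get_nn_models -> factory.
"""
import os
import subprocess
import sys
import tempfile
from pathlib import Path

FILES = {
    "toypkg/__init__.py": "",
    # Like sbi/neural_nets/__init__.py: re-export from the factory module.
    "toypkg/neural_nets/__init__.py": (
        "from toypkg.neural_nets.factory import posterior_nn\n"
    ),
    # Importing a submodule of toypkg.utils first runs toypkg/utils/__init__.py.
    "toypkg/neural_nets/factory.py": (
        "from toypkg.utils.nn_utils import get_numel\n"
        "def posterior_nn():\n"
        "    return 'posterior'\n"
    ),
    # Eager backwards-compatibility re-export, like `from sbi.utils import posterior_nn`.
    "toypkg/utils/__init__.py": (
        "from toypkg.utils.get_nn_models import posterior_nn\n"
    ),
    "toypkg/utils/nn_utils.py": "def get_numel(x):\n    return len(x)\n",
    # Closes the cycle: asks factory.py for a name it may not have defined yet.
    "toypkg/utils/get_nn_models.py": (
        "from toypkg.neural_nets.factory import posterior_nn\n"
    ),
}


def run(code: str, root: Path) -> subprocess.CompletedProcess:
    """Run `code` in a fresh interpreter with `root` on the import path."""
    env = {**os.environ, "PYTHONPATH": str(root)}
    return subprocess.run(
        [sys.executable, "-c", code], env=env, capture_output=True, text=True
    )


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel_path, source in FILES.items():
        path = root / rel_path
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(source)

    # Fails: get_nn_models.py runs while factory.py is only partially initialized.
    direct = run("from toypkg.neural_nets import posterior_nn", root)
    # Works: factory.py finishes before get_nn_models.py reads posterior_nn from it.
    utils_first = run(
        "import toypkg.utils\n"
        "from toypkg.neural_nets import posterior_nn\n"
        "print(posterior_nn())",
        root,
    )

print("direct import ok:", direct.returncode == 0)
print("partially initialized:", "partially initialized module" in direct.stderr)
print("utils-first import ok:", utils_first.returncode == 0)
```

Importing `toypkg.utils` first lets `factory.py` finish executing (it only needs the `nn_utils` submodule, which Python can load from the partially initialized `toypkg.utils` package) before `get_nn_models.py` asks it for `posterior_nn`. That is presumably the same reason importing `sbi.utils` first avoids the error in sbi.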

michaeldeistler commented 2 months ago

Hi there, thanks for reporting this! We are aware of this issue and, technically, it is already fixed. However, you still get the error because we had to keep imports in `sbi.utils` purely for backwards compatibility (so that `from sbi.utils import posterior_nn` still works). We will remove that import in the next major release, and the bug will go away then.
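For comparison, one general Python technique for keeping an old import path such as `from sbi.utils import posterior_nn` working without importing it when the package loads is a module-level `__getattr__` (PEP 562). This is not what sbi does; the sketch below reuses a hypothetical `toypkg` layout in which only `toypkg/utils/__init__.py` differs from an eager re-export: the alias is resolved on first attribute access instead of at import time, so neither import path triggers the cycle.

```python
"""Hypothetical sketch: a lazy backwards-compatible alias via PEP 562.

`toypkg` is an invented stand-in for sbi, not sbi's actual code.
"""
import os
import subprocess
import sys
import tempfile
from pathlib import Path

FILES = {
    "toypkg/__init__.py": "",
    "toypkg/neural_nets/__init__.py": (
        "from toypkg.neural_nets.factory import posterior_nn\n"
    ),
    "toypkg/neural_nets/factory.py": (
        "from toypkg.utils.nn_utils import get_numel\n"
        "def posterior_nn():\n"
        "    return 'posterior'\n"
    ),
    # No eager import here: the old name is resolved only when it is requested.
    "toypkg/utils/__init__.py": (
        "def __getattr__(name):\n"
        "    if name == 'posterior_nn':\n"
        "        from toypkg.neural_nets.factory import posterior_nn\n"
        "        return posterior_nn\n"
        "    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')\n"
    ),
    "toypkg/utils/nn_utils.py": "def get_numel(x):\n    return len(x)\n",
}

SNIPPETS = {
    "new path only": "from toypkg.neural_nets import posterior_nn",
    "old path only": "from toypkg.utils import posterior_nn",
    "same object": (
        "from toypkg.neural_nets import posterior_nn as new\n"
        "from toypkg.utils import posterior_nn as old\n"
        "assert old is new"
    ),
}

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for rel_path, source in FILES.items():
        path = root / rel_path
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(source)

    env = {**os.environ, "PYTHONPATH": str(root)}
    # Each snippet runs in a fresh interpreter, so no module is pre-cached.
    results = {
        label: subprocess.run(
            [sys.executable, "-c", code], env=env, capture_output=True, text=True
        ).returncode
        for label, code in SNIPPETS.items()
    }

for label, returncode in results.items():
    print(f"{label}: {'ok' if returncode == 0 else 'failed'}")
```

The trade-off is that the alias is invisible to static tools until it is accessed; removing it in a major release, as planned, avoids that extra indirection altogether.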

As you already suggest, the simplest workaround is to import the inference class (or `sbi.utils`) first, e.g.:

from sbi.inference import NPE
from sbi.neural_nets import posterior_nn