sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Circular import errors #1301

Closed: tvwenger closed this issue 11 hours ago

tvwenger commented 6 days ago

Describe the bug

This is a duplicate of https://github.com/sbi-dev/sbi/issues/1158, which was supposedly fixed by https://github.com/sbi-dev/sbi/pull/1179.

When importing posterior_nn (or any other network factory) from sbi.neural_nets, a circular import error is raised:

ImportError: cannot import name 'classifier_nn' from partially initialized module 'sbi.neural_nets.factory' (most likely due to a circular import) (/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/factory.py)

To Reproduce

Python 3.12.7, sbi 0.23.2

from sbi.neural_nets import posterior_nn

Traceback:

```
ImportError                               Traceback (most recent call last)
Cell In[1], line 1
----> 1 from sbi.neural_nets import posterior_nn

File ~/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/__init__.py:1
----> 1 from sbi.neural_nets.factory import (
      2     classifier_nn,
      3     flowmatching_nn,
      4     likelihood_nn,
      5     posterior_nn,
      6     posterior_score_nn,
      7 )
     10 def __getattr__(name):
     11     if name in ["CNNEmbedding", "FCEmbedding", "PermutationInvariantEmbedding"]:

File ~/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/factory.py:9
      5 from typing import Any, Callable, Optional, Union
      7 from torch import nn
----> 9 from sbi.neural_nets.net_builders.classifier import (
     10     build_linear_classifier,
     11     build_mlp_classifier,
     12     build_resnet_classifier,
     13 )
     14 from sbi.neural_nets.net_builders.flow import (
     15     build_made,
     16     build_maf,
   (...)
     27     build_zuko_unaf,
     28 )
     29 from sbi.neural_nets.net_builders.flowmatching_nets import (
     30     build_mlp_flowmatcher,
     31     build_resnet_flowmatcher,
     32 )

File ~/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/net_builders/__init__.py:1
----> 1 from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator
      2 from sbi.neural_nets.net_builders.classifier import (
      3     build_linear_classifier,
      4     build_mlp_classifier,
      5     build_resnet_classifier,
      6 )
      7 from sbi.neural_nets.net_builders.flow import (
      8     build_made,
      9     build_maf,
   (...)
     20     build_zuko_unaf,
     21 )

File ~/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/net_builders/categorial.py:9
      6 from torch import Tensor, nn, unique
      8 from sbi.neural_nets.estimators import CategoricalMassEstimator, CategoricalNet
----> 9 from sbi.utils.nn_utils import get_numel
     10 from sbi.utils.sbiutils import (
     11     standardizing_net,
     12     z_score_parser,
     13 )
     14 from sbi.utils.user_input_checks import check_data_device

File ~/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/utils/__init__.py:72
     63 from sbi.utils.user_input_checks import (
     64     check_estimator_arg,
     65     check_prior,
   (...)
     69     validate_theta_and_x,
     70 )
     71 from sbi.utils.user_input_checks_utils import MultipleIndependent
---> 72 from sbi.utils.get_nn_models import posterior_nn, likelihood_nn, classifier_nn

File ~/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/utils/get_nn_models.py:9
      5 from warnings import warn
      7 from torch import nn
----> 9 from sbi.neural_nets.factory import classifier_nn as classifier_nn_moved_to_neural_nets
     10 from sbi.neural_nets.factory import likelihood_nn as likelihood_nn_moved_to_neural_nets
     11 from sbi.neural_nets.factory import posterior_nn as posterior_nn_moved_to_neural_nets

ImportError: cannot import name 'classifier_nn' from partially initialized module 'sbi.neural_nets.factory' (most likely due to a circular import) (/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/factory.py)
```

Expected behavior

No error.

Additional context

Although the issue should have been fixed by the associated PR, it still occurs. As described in https://github.com/sbi-dev/sbi/issues/1158, a workaround is to import sbi.utils first:

import sbi.utils
from sbi.neural_nets import posterior_nn

This works fine.
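To illustrate why the import order matters, here is a minimal sketch using a hypothetical toy package, toypkg (not part of sbi), whose modules mirror the chain in the traceback: neural_nets → factory → utils → get_nn_models → neural_nets.factory. Each scenario runs in a fresh interpreter so that sys.modules starts out empty, as it would in a new Python session.

```python
import os
import subprocess
import sys
import tempfile
from pathlib import Path

# Hypothetical toy package ("toypkg", not sbi) reproducing the cycle in the
# traceback: neural_nets -> factory -> utils -> get_nn_models -> factory.
FILES = {
    "toypkg/__init__.py": "",
    "toypkg/neural_nets/__init__.py": (
        "from toypkg.neural_nets.factory import posterior_nn\n"
    ),
    "toypkg/neural_nets/factory.py": (
        "from toypkg.utils.nn_utils import get_numel\n"
        "\n"
        "def posterior_nn():\n"
        "    return 'posterior_nn'\n"
    ),
    "toypkg/utils/__init__.py": (
        "from toypkg.utils.get_nn_models import posterior_nn\n"
    ),
    "toypkg/utils/nn_utils.py": "def get_numel():\n    return 0\n",
    "toypkg/utils/get_nn_models.py": (
        "from toypkg.neural_nets.factory import posterior_nn as _moved\n"
        "\n"
        "def posterior_nn():\n"
        "    # Deprecated alias forwarding to the new location.\n"
        "    return _moved()\n"
    ),
}


def run_fresh(root: Path, code: str) -> subprocess.CompletedProcess:
    """Run `code` in a new interpreter, so sys.modules starts out empty."""
    env = {**os.environ, "PYTHONPATH": str(root)}
    return subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, env=env
    )


def demo() -> tuple:
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for rel_path, source in FILES.items():
            path = root / rel_path
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(source)
        # neural_nets first: get_nn_models asks factory for posterior_nn
        # while factory is still executing its own first import line.
        direct = run_fresh(root, "from toypkg.neural_nets import posterior_nn")
        # utils first: factory finishes loading before get_nn_models reads it.
        workaround = run_fresh(
            root,
            "import toypkg.utils\n"
            "from toypkg.neural_nets import posterior_nn\n"
            "print(posterior_nn())",
        )
    return direct, workaround


if __name__ == "__main__":
    direct, workaround = demo()
    print(direct.stderr.strip().splitlines()[-1])  # the ImportError line
    print(workaround.stdout.strip())  # posterior_nn
```

Note that importing toypkg.utils first does not remove the cycle; it only enters it at a point where no module reads a name from another module that is still half-initialized.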

My specific use case, however, involves taking the likelihood density estimator from sbi and plugging it into a pymc model for inference. pymc uses cloudpickle to distribute the model across multiple processes for parallel sampling. I don't know exactly what's going on under the hood, but pymc ultimately fails to sample the model: when a worker process unpickles the step method, it imports sbi.neural_nets from scratch and hits the same circular import:

Traceback (most recent call last):
  File "/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/pymc/sampling/parallel.py", line 123, in _unpickle_step_method
    self._step_method = cloudpickle.loads(self._step_method)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/__init__.py", line 1, in <module>
    from sbi.neural_nets.factory import (
  File "/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/factory.py", line 9, in <module>
    from sbi.neural_nets.net_builders.classifier import (
  File "/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/net_builders/__init__.py", line 1, in <module>
    from sbi.neural_nets.net_builders.categorial import build_categoricalmassestimator
  File "/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/net_builders/categorial.py", line 9, in <module>
    from sbi.utils.nn_utils import get_numel
  File "/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/utils/__init__.py", line 72, in <module>
    from sbi.utils.get_nn_models import posterior_nn, likelihood_nn, classifier_nn
  File "/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/utils/get_nn_models.py", line 9, in <module>
    from sbi.neural_nets.factory import classifier_nn as classifier_nn_moved_to_neural_nets
ImportError: cannot import name 'classifier_nn' from partially initialized module 'sbi.neural_nets.factory' (most likely due to a circular import) (/home/twenger/miniconda3/envs/galstruct/lib/python3.12/site-packages/sbi/neural_nets/factory.py)
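A likely explanation for the worker failure, sketched below with the standard pickle module and the same hypothetical toy package (toypkg, not sbi): pickle stores a class by reference (module name plus class name), so unpickling in a fresh process imports that module, which first runs the package's __init__.py. Even if the main process used the import-order workaround, the worker does not. We assume, without having checked pymc's internals, that cloudpickle behaves the same way for classes defined in importable modules, which matches the cloudpickle.loads frame in the traceback.

```python
import os
import subprocess
import sys
import tempfile
from pathlib import Path

# Hypothetical toy package (not sbi) with the same cycle as in the traceback,
# plus a class standing in for a trained density estimator.
FILES = {
    "toypkg/__init__.py": "",
    "toypkg/neural_nets/__init__.py": (
        "from toypkg.neural_nets.factory import posterior_nn\n"
    ),
    "toypkg/neural_nets/factory.py": (
        "from toypkg.utils.nn_utils import get_numel\n"
        "\n"
        "def posterior_nn():\n"
        "    return 'posterior_nn'\n"
    ),
    "toypkg/neural_nets/estimator.py": "class Estimator:\n    pass\n",
    "toypkg/utils/__init__.py": (
        "from toypkg.utils.get_nn_models import posterior_nn\n"
    ),
    "toypkg/utils/nn_utils.py": "def get_numel():\n    return 0\n",
    "toypkg/utils/get_nn_models.py": (
        "from toypkg.neural_nets.factory import posterior_nn as _moved\n"
        "\n"
        "def posterior_nn():\n"
        "    return _moved()\n"
    ),
}

# Main process: applies the import-order workaround, then pickles an estimator.
MAIN = (
    "import pickle, sys\n"
    "import toypkg.utils\n"
    "from toypkg.neural_nets.estimator import Estimator\n"
    "sys.stdout.buffer.write(pickle.dumps(Estimator()))\n"
)
# Worker process: only sees the bytes. The pickle refers to the class by
# module and name, so loads() imports toypkg.neural_nets.estimator, which
# first runs toypkg/neural_nets/__init__.py and enters the cycle.
WORKER = "import pickle, sys\npickle.loads(sys.stdin.buffer.read())\n"


def run_fresh(
    root: Path, code: str, stdin: bytes = b""
) -> subprocess.CompletedProcess:
    """Run `code` in a new interpreter, feeding `stdin` as raw bytes."""
    env = {**os.environ, "PYTHONPATH": str(root)}
    return subprocess.run(
        [sys.executable, "-c", code], input=stdin, capture_output=True, env=env
    )


def demo() -> tuple:
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for rel_path, source in FILES.items():
            path = root / rel_path
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(source)
        main = run_fresh(root, MAIN)
        worker = run_fresh(root, WORKER, stdin=main.stdout)
    return main, worker


if __name__ == "__main__":
    main, worker = demo()
    print("main exit code:", main.returncode)  # main exit code: 0
    print(worker.stderr.decode().strip().splitlines()[-1])
```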

I could probably find a hack to make this work, but the circular import should be fixed at its source anyway, which would also fix my specific problem.
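One common way to break a cycle like this, sketched here on a hypothetical toy package (toypkg, not sbi) and not necessarily what the eventual sbi fix does, is to move the import of the new factory inside the deprecated alias. Loading the utils package then no longer imports neural_nets.factory at module level, so both import orders succeed. The module-level __getattr__ pattern (PEP 562) that sbi/neural_nets/__init__.py already uses for the embedding nets, as visible in the traceback, would be another option for the deprecated aliases in sbi.utils.

```python
import os
import subprocess
import sys
import tempfile
from pathlib import Path

# Hypothetical toy package (not sbi). Same layout as the cycle in the
# traceback, except that get_nn_models imports from neural_nets.factory
# inside the deprecated alias instead of at module level.
FILES = {
    "toypkg/__init__.py": "",
    "toypkg/neural_nets/__init__.py": (
        "from toypkg.neural_nets.factory import posterior_nn\n"
    ),
    "toypkg/neural_nets/factory.py": (
        "from toypkg.utils.nn_utils import get_numel\n"
        "\n"
        "def posterior_nn():\n"
        "    return 'posterior_nn'\n"
    ),
    "toypkg/utils/__init__.py": (
        "from toypkg.utils.get_nn_models import posterior_nn\n"
    ),
    "toypkg/utils/nn_utils.py": "def get_numel():\n    return 0\n",
    "toypkg/utils/get_nn_models.py": (
        "def posterior_nn():\n"
        "    # Deferred import: resolved when the alias is called, by which\n"
        "    # time neural_nets.factory has finished loading.\n"
        "    from toypkg.neural_nets.factory import posterior_nn as _moved\n"
        "    return _moved()\n"
    ),
}


def run_fresh(root: Path, code: str) -> subprocess.CompletedProcess:
    """Run `code` in a new interpreter, so sys.modules starts out empty."""
    env = {**os.environ, "PYTHONPATH": str(root)}
    return subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, env=env
    )


def demo() -> tuple:
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for rel_path, source in FILES.items():
            path = root / rel_path
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(source)
        # New import path, without importing utils first.
        new_path = run_fresh(
            root,
            "from toypkg.neural_nets import posterior_nn\nprint(posterior_nn())",
        )
        # Old (deprecated) import path still works too.
        old_path = run_fresh(
            root,
            "from toypkg.utils import posterior_nn\nprint(posterior_nn())",
        )
    return new_path, old_path


if __name__ == "__main__":
    new_path, old_path = demo()
    print(new_path.stdout.strip())  # posterior_nn
    print(old_path.stdout.strip())  # posterior_nn
```

The trade-off is that a broken import inside the alias only surfaces when the alias is called, not when sbi.utils is imported.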

janfb commented 5 days ago

Thanks a lot for reporting this @tvwenger , and for proposing a fix as well 🚀