alan-turing-institute / deepsensor

A Python package for tackling diverse environmental prediction tasks with NPs.
https://alan-turing-institute.github.io/deepsensor/
MIT License
72 stars, 15 forks

Fix failing tests due to PT -> TF type conversion (plum issue?) #115

Closed: tom-andersson closed this issue 2 months ago

tom-andersson commented 4 months ago

We have failing tests in test_active_learning.py where, somewhere within a neuralprocesses call stack, a PyTorch tensor meets a TensorFlow tensor and the two cannot be promoted to a common type (`TypeError: No promotion rule for torch.Tensor and tensorflow.python.framework.ops.EagerTensor.`).

I believe these tests are failing because of recent (~2 weeks ago) changes in plum, which now raise an error that was not raised before. There have been some changes to plum's promotion handling around the time the tests started failing, and none in deepsensor. I'm not 100% sure though, and will have to investigate further.
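To make the error message concrete, here is a toy model of how a multiple-dispatch library can resolve mixed argument types through promotion rules. It is a hypothetical sketch, not plum's actual implementation: `TorchTensor`, `TFTensor`, `add` and the rule registry are stand-ins, and `promote`/`add_promotion_rule` only borrow plum's function names. When no rule links the two types, promotion fails with a `TypeError` shaped like the one in the failing tests.

```python
# Toy model of type promotion in a multiple-dispatch library (hypothetical;
# NOT plum's real implementation). The tensor classes are stand-ins.


class TorchTensor:
    """Stand-in for torch.Tensor."""

    def __init__(self, value):
        self.value = value


class TFTensor:
    """Stand-in for TensorFlow's EagerTensor."""

    def __init__(self, value):
        self.value = value


# Maps an unordered pair of types to the common type both should convert to.
_PROMOTION_RULES = {}


def add_promotion_rule(type1, type2, target):
    """Register `target` as the common type for `type1` and `type2`."""
    _PROMOTION_RULES[frozenset((type1, type2))] = target


def promote(a, b):
    """Convert `a` and `b` to a common type, or fail if no rule exists."""
    ta, tb = type(a), type(b)
    if ta is tb:
        return a, b
    target = _PROMOTION_RULES.get(frozenset((ta, tb)))
    if target is None:
        raise TypeError(f"No promotion rule for {ta.__name__} and {tb.__name__}.")

    def convert(x):
        # Re-wrap the raw value in the target type if it is not already one.
        return x if isinstance(x, target) else target(x.value)

    return convert(a), convert(b)


def add(a, b):
    """Generic addition that first promotes both operands to a common type."""
    a, b = promote(a, b)
    return type(a)(a.value + b.value)


# Same backend: works. Mixed backends with no rule: TypeError.
print(add(TFTensor(1.0), TFTensor(2.0)).value)  # 3.0
try:
    add(TorchTensor(1.0), TFTensor(2.0))
except TypeError as err:
    print(err)  # No promotion rule for TorchTensor and TFTensor.
```

Registering a rule, e.g. `add_promotion_rule(TorchTensor, TFTensor, TFTensor)`, would make the mixed call succeed. In a situation like this thread's, though, the missing rule is arguably the right outcome: the underlying problem is that two backends ended up in one computation.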

cc @wesselb @davidwilby

wesselb commented 4 months ago

Hmmm, @tom-andersson, I've not seen any strangeness on my side. Would you be able to create an MWE (minimal working example) for the error? Happy to look more closely!

It seems that TF and PyTorch tensors are getting mixed, possibly because the model is created with one type and run with the other.

tom-andersson commented 4 months ago

Thanks @wesselb! To clarify, I didn't mean plum was at fault; my hypothesis was rather that plum is now catching a potential bug, as you suggest. However, the failing test module (test_active_learning.py) only uses TensorFlow (it imports deepsensor.tensorflow, which sets everything to be TF objects), and there are no references to torch objects in the tests. Also, tests are passing when I run them locally at HEAD.

@davidwilby any thoughts on what's going on here?

wesselb commented 4 months ago

> Also, tests are passing when I run them locally at HEAD.

Hmmm, this is really weird. Could it possibly be to do with the Python version and/or the order in which the tests are run? Do they also pass locally when you install a fresh env and run the entire test suite as a whole?

davidwilby commented 3 months ago

Sorry that I haven't been able to look at this until now.

I have now been able to replicate this behaviour locally with Python 3.8, neuralprocesses 0.2.6, plum-dispatch 2.4.1, tensorflow 2.13.1 and torch 2.1.2. (The behaviour is the same with neuralprocesses 0.2.2 and plum-dispatch 2.2.2.)

Full package versions here

```
Package Version
---------------------------- --------------
absl-py 2.1.0
affine 2.4.0
aiohttp 3.9.5
aiosignal 1.3.1
algebra 1.2.1
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asciitree 0.3.3
asttokens 2.4.1
astunparse 1.6.3
async-lru 2.0.4
async-timeout 4.0.3
attrs 23.2.0
Babel 2.15.0
backcall 0.2.0
backends 1.6.5
backends-matrix 1.3.0
beartype 0.18.5
beautifulsoup4 4.12.3
black 24.4.2
bleach 6.1.0
cachetools 5.3.3
certifi 2024.6.2
cffi 1.16.0
cftime 1.6.4
chardet 5.2.0
charset-normalizer 3.3.2
click 8.1.7
click-plugins 1.1.1
cligj 0.7.2
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.2
contourpy 1.1.1
coverage 7.5.3
coveralls 4.0.1
cycler 0.12.1
dask 2023.5.0
debugpy 1.8.1
decorator 5.1.1
deepsensor 0.3.6
defusedxml 0.7.1
distlib 0.3.8
distributed 2023.5.0
dm-tree 0.1.8
docopt 0.6.2
exceptiongroup 1.2.1
executing 2.0.1
fasteners 0.19
fastjsonschema 2.19.1
fdm 0.4.1
filelock 3.15.1
flatbuffers 24.3.25
fonttools 4.53.0
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2024.6.0
gast 0.4.0
gcsfs 2024.6.0
google-api-core 2.19.0
google-auth 2.30.0
google-auth-oauthlib 1.0.0
google-cloud-core 2.4.1
google-cloud-storage 2.17.0
google-crc32c 1.5.0
google-pasta 0.2.0
google-resumable-media 2.7.1
googleapis-common-protos 1.63.1
grpcio 1.64.1
h11 0.14.0
h5py 3.11.0
httpcore 1.0.5
httpx 0.27.0
idna 3.7
importlib_metadata 7.1.0
importlib_resources 6.4.0
iniconfig 2.0.0
ipykernel 6.29.4
ipython 8.12.3
ipywidgets 8.1.3
isoduration 20.11.0
jedi 0.19.1
Jinja2 3.1.4
joblib 1.4.2
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.22.0
jsonschema-specifications 2023.12.1
jupyter 1.0.0
jupyter_client 8.6.2
jupyter-console 6.6.3
jupyter_core 5.7.2
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter_server 2.14.1
jupyter_server_terminals 0.5.3
jupyterlab 4.2.2
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.2
jupyterlab_widgets 3.0.11
keras 2.13.1
kiwisolver 1.4.5
libclang 18.1.1
locket 1.0.0
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.7.5
matplotlib-inline 0.1.7
mdurl 0.1.2
mistune 3.0.2
mlkernels 0.4.0
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
mypy-extensions 1.0.0
nbclient 0.10.0
nbconvert 7.16.4
nbformat 5.10.4
nest-asyncio 1.6.0
netCDF4 1.6.5
networkx 3.1
NeuralProcesses 0.2.6
notebook 7.2.1
notebook_shim 0.2.4
numcodecs 0.12.1
numpy 1.24.3
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.5.40
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
opt-einsum 3.3.0
overrides 7.7.0
packaging 24.1
pandas 2.0.3
pandocfilters 1.5.1
parameterized 0.9.0
parso 0.8.4
partd 1.4.1
pathspec 0.12.1
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.3.0
pip 24.0
pkgutil_resolve_name 1.3.10
platformdirs 4.2.2
pluggy 1.5.0
plum-dispatch 2.4.1
pooch 1.8.2
prometheus_client 0.20.0
prompt_toolkit 3.0.47
proto-plus 1.23.0
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 16.1.0
pyasn1 0.6.0
pyasn1_modules 0.4.0
pycparser 2.22
Pygments 2.18.0
pyparsing 3.1.2
pyproj 3.5.0
pyproject-api 1.6.1
pyshp 2.3.1
pytest 8.2.2
pytest-cov 5.0.0
python-dateutil 2.9.0.post0
python-json-logger 2.0.7
python-slugify 8.0.4
pytz 2024.1
PyYAML 6.0.1
pyzmq 26.0.3
qtconsole 5.5.2
QtPy 2.4.1
rasterio 1.3.10
referencing 0.35.1
requests 2.32.3
requests-oauthlib 2.0.0
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.1
rioxarray 0.13.4
rpds-py 0.18.1
rsa 4.9
scikit-learn 1.3.2
scipy 1.10.1
seaborn 0.13.2
Send2Trash 1.8.3
setuptools 69.5.1
shapely 2.0.4
six 1.16.0
sniffio 1.3.1
snuggs 1.4.7
sortedcontainers 2.4.0
soupsieve 2.5
stack-data 0.6.3
stheno 1.4.1
sympy 1.12.1
tblib 3.0.0
tensorboard 2.13.0
tensorboard-data-server 0.7.2
tensorflow 2.13.1
tensorflow-estimator 2.13.0
tensorflow-io-gcs-filesystem 0.34.0
tensorflow-probability 0.21.0
termcolor 2.4.0
terminado 0.18.1
text-unidecode 1.3
threadpoolctl 3.5.0
tinycss2 1.3.0
tomli 2.0.1
toolz 0.12.1
torch 2.1.2
tornado 6.4.1
tox 4.15.1
tox-gh-actions 3.2.0
tqdm 4.66.4
traitlets 5.14.3
triton 2.1.0
types-python-dateutil 2.9.0.20240316
typing_extensions 4.5.0
tzdata 2024.1
uri-template 1.3.0
urllib3 2.2.1
varz 0.8.1
virtualenv 20.26.2
wbml 0.4.1
wcwidth 0.2.13
webcolors 24.6.0
webencodings 0.5.1
websocket-client 1.8.0
Werkzeug 3.0.3
wheel 0.43.0
widgetsnbextension 4.0.11
wrapt 1.16.0
xarray 2023.1.0
yarl 1.9.4
zarr 2.16.1
zict 3.0.0
zipp 3.19.2
```

All tests pass with Python 3.11, neuralprocesses 0.2.2, plum-dispatch 2.2.2, tensorflow 2.15.0.post1 and torch 2.2.0.

Full package versions here

```
Not replicated. python 3.11
Package Version Editable project location
----------------------------- --------------- --------------------------
absl-py 2.0.0
accessible-pygments 0.0.4
affine 2.4.0
aiohttp 3.9.3
aiosignal 1.3.1
alabaster 0.7.16
algebra 1.2.1
anyio 4.4.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asciitree 0.3.3
asttokens 2.4.1
astunparse 1.6.3
async-lru 2.0.4
attrs 23.2.0
Babel 2.14.0
backends 1.6.0
backends-matrix 1.3.0
beartype 0.17.0
beautifulsoup4 4.12.2
black 24.1.1
bleach 6.1.0
cachetools 5.3.2
certifi 2023.11.17
cffi 1.16.0
cfgv 3.4.0
cftime 1.6.3
chardet 5.2.0
charset-normalizer 3.3.2
click 8.1.7
click-plugins 1.1.1
cligj 0.7.2
cloudpickle 3.0.0
colorama 0.4.6
comm 0.2.1
contourpy 1.2.0
coverage 6.5.0
coveralls 3.3.1
cycler 0.12.1
dask 2024.1.1
debugpy 1.8.0
decorator 5.1.1
deepsensor 0.3.6
defusedxml 0.7.1
distlib 0.3.8
distributed 2024.1.1
dm-tree 0.1.8
docopt 0.6.2
docutils 0.18.1
executing 2.0.1
fasteners 0.19
fastjsonschema 2.19.1
fdm 0.4.1
filelock 3.13.1
flatbuffers 23.5.26
fonttools 4.47.2
fqdn 1.5.1
frozenlist 1.4.1
fsspec 2023.12.2
gast 0.5.4
gcsfs 2023.12.2.post1
google-api-core 2.16.1
google-auth 2.26.2
google-auth-oauthlib 1.2.0
google-cloud-core 2.4.1
google-cloud-storage 2.14.0
google-crc32c 1.5.0
google-pasta 0.2.0
google-resumable-media 2.7.0
googleapis-common-protos 1.62.0
greenlet 3.0.3
grpcio 1.60.0
h11 0.14.0
h5py 3.10.0
httpcore 1.0.5
httpx 0.27.0
identify 2.5.36
idna 3.6
imagesize 1.4.1
importlib-metadata 7.0.1
iniconfig 2.0.0
ipykernel 6.28.0
ipython 8.20.0
ipywidgets 8.1.3
isoduration 20.11.0
jedi 0.19.1
Jinja2 3.1.3
joblib 1.3.2
json5 0.9.25
jsonpointer 3.0.0
jsonschema 4.20.0
jsonschema-specifications 2023.12.1
jupyter 1.0.0
jupyter-book 0.15.1
jupyter-cache 0.6.1
jupyter_client 8.6.0
jupyter-console 6.6.3
jupyter_core 5.7.1
jupyter-events 0.10.0
jupyter-lsp 2.2.5
jupyter_server 2.14.1
jupyter_server_terminals 0.5.3
jupyterlab 4.2.2
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.2
jupyterlab_widgets 3.0.11
keras 2.15.0
kiwisolver 1.4.5
latexcodec 2.0.1
libclang 16.0.6
linkify-it-py 2.0.2
locket 1.0.0
Markdown 3.5.2
markdown-it-py 2.2.0
MarkupSafe 2.1.3
matplotlib 3.8.2
matplotlib-inline 0.1.6
mdit-py-plugins 0.3.5
mdurl 0.1.2
mistune 3.0.2
ml-dtypes 0.2.0
mlkernels 0.4.0
mpmath 1.3.0
msgpack 1.0.7
multidict 6.0.5
mypy 1.9.0
mypy-extensions 1.0.0
myst-nb 0.17.2
myst-parser 0.18.1
nbclient 0.7.4
nbconvert 7.16.4
nbformat 5.9.2
nest-asyncio 1.5.8
netCDF4 1.6.5
networkx 3.2.1
neuralprocesses 0.2.2
nodeenv 1.8.0
notebook 7.2.1
notebook_shim 0.2.4
numcodecs 0.12.1
numpy 1.26.3
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
opt-einsum 3.3.0
overrides 7.7.0
packaging 23.2
pandas 2.2.0
pandocfilters 1.5.1
parameterized 0.9.0
parso 0.8.3
partd 1.4.1
pathspec 0.12.1
pexpect 4.9.0
pillow 10.2.0
pip 24.0
platformdirs 4.1.0
pluggy 1.4.0
plum-dispatch 2.2.2
pooch 1.8.0
pre-commit 3.7.0
prometheus_client 0.20.0
prompt-toolkit 3.0.43
protobuf 4.23.4
psutil 5.9.7
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.0
pyasn1 0.5.1
pyasn1-modules 0.3.0
pybtex 0.24.0
pybtex-docutils 1.0.3
pycparser 2.22
pydata-sphinx-theme 0.15.1
Pygments 2.17.2
pyparsing 3.1.1
pyproj 3.6.1
pyproject-api 1.6.1
pyshp 2.3.1
pytest 8.0.0
pytest-cov 4.1.0
python-dateutil 2.8.2
python-json-logger 2.0.7
python-slugify 8.0.3
pytz 2024.1
PyYAML 6.0.1
pyzmq 25.1.2
qtconsole 5.5.2
QtPy 2.4.1
rasterio 1.3.9
referencing 0.32.1
requests 2.31.0
requests-oauthlib 1.3.1
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rich 13.7.0
rioxarray 0.15.1
rpds-py 0.16.2
rsa 4.9
ruff 0.4.2
scikit-learn 1.4.0
scipy 1.12.0
seaborn 0.13.2
Send2Trash 1.8.3
setuptools 69.5.1
shapely 2.0.4
six 1.16.0
sniffio 1.3.1
snowballstemmer 2.2.0
snuggs 1.4.7
sortedcontainers 2.4.0
soupsieve 2.5
Sphinx 5.0.2
sphinx-book-theme 1.0.1
sphinx-comments 0.0.3
sphinx-copybutton 0.5.2
sphinx_design 0.3.0
sphinx_external_toc 0.3.1
sphinx-jupyterbook-latex 0.5.2
sphinx-multitoc-numbering 0.1.3
sphinx-thebe 0.2.1
sphinx-togglebutton 0.3.2
sphinxcontrib-applehelp 1.0.7
sphinxcontrib-bibtex 2.5.0
sphinxcontrib-devhelp 1.0.5
sphinxcontrib-htmlhelp 2.0.4
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 1.0.6
sphinxcontrib-serializinghtml 1.1.9
SQLAlchemy 2.0.25
stack-data 0.6.3
stheno 1.4.1
sympy 1.12
tabulate 0.9.0
tblib 3.0.0
tensorboard 2.15.1
tensorboard-data-server 0.7.2
tensorflow 2.15.0.post1
tensorflow-estimator 2.15.0
tensorflow-io-gcs-filesystem 0.35.0
tensorflow-probability 0.23.0
termcolor 2.4.0
terminado 0.18.1
text-unidecode 1.3
threadpoolctl 3.2.0
tinycss2 1.3.0
toolz 0.12.1
torch 2.2.0
tornado 6.4
tox 4.12.1
tox-gh-actions 3.2.0
tqdm 4.66.1
traitlets 5.14.1
triton 2.2.0
types-python-dateutil 2.9.0.20240316
typing_extensions 4.9.0
tzdata 2023.4
uc-micro-py 1.0.2
uri-template 1.3.0
urllib3 2.1.0
varz 0.8.1
virtualenv 20.25.0
wbml 0.4.1
wcwidth 0.2.13
webcolors 24.6.0
webencodings 0.5.1
websocket-client 1.8.0
Werkzeug 3.0.1
wheel 0.43.0
widgetsnbextension 4.0.11
wrapt 1.14.1
xarray 2024.1.1
yarl 1.9.4
zarr 2.16.1
zict 3.0.0
zipp 3.17.0
```

Tests also pass in that same Python 3.11 environment when only neuralprocesses is changed to 0.2.6 and plum-dispatch to 2.4.1.

(NB: I notice that recent TensorFlow releases no longer support Python 3.8; 2.13 is the last version that does.)

@tom-andersson, the CI run you linked above actually uses Python 3.10.12 (not 3.8; related to #116), tensorflow==2.16.1, NeuralProcesses==0.2.5 and plum-dispatch==2.3.6, and I can replicate that failure locally as well.

Haven't quite got to the bottom of this issue but just putting this here for now.

davidwilby commented 3 months ago

After experimenting with this a bit more, I can replicate it locally: the test_active_learning tests only fail when run in the same test session as other tests, not when that module is run on its own.
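One way a test module can pass on its own but fail in a full session is import-time global state. By default, pytest imports (collects) every test module before running any tests, so a module-level import with side effects in one file can change the environment that another file's tests run in. The sketch below simulates that. It is a hypothetical illustration: `active_backend`, `import_backend`, `Model` and the module table are stand-ins, not deepsensor's API, and the "No promotion rule" message is only mimicked.

```python
# Hypothetical simulation of test pollution through import-time global state.
# None of these names are deepsensor's real API.

active_backend = {"name": None}


def import_backend(name):
    """Stand-in for `import <pkg>.<backend>`: importing sets a global backend."""
    active_backend["name"] = name


class Model:
    def __init__(self):
        # The model captures whichever backend is active when it is built.
        self.backend = active_backend["name"]

    def predict(self, tensor_backend):
        if tensor_backend != self.backend:
            raise TypeError(f"No promotion rule for {self.backend} and {tensor_backend}.")
        return "ok"


def active_learning_test():
    # Inputs are built as TensorFlow tensors, assuming TF is the active backend.
    return Model().predict("tensorflow")


def model_test():
    return Model().predict("torch")


# Each entry: (backend set by the module-level import, the module's test).
TEST_MODULES = {
    "test_active_learning": ("tensorflow", active_learning_test),
    "test_model": ("torch", model_test),
}


def run_session(module_names):
    """Mimic pytest: import (collect) every module first, then run the tests."""
    for name in module_names:
        import_backend(TEST_MODULES[name][0])
    results = {}
    for name in module_names:
        _, test = TEST_MODULES[name]
        try:
            results[name] = test()
        except TypeError as err:
            results[name] = f"FAILED: {err}"
    return results


print(run_session(["test_active_learning"]))
# {'test_active_learning': 'ok'}
print(run_session(["test_active_learning", "test_model"]))
# {'test_active_learning': 'FAILED: No promotion rule for torch and tensorflow.', 'test_model': 'ok'}
```

Reversing the module order in the second session makes `test_model` fail instead, so in this toy setup which test breaks depends only on which modules share the session and the order they are collected in.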

davidwilby commented 3 months ago

Running different combinations of test modules in the same session produces a lot of odd behaviour, not just in test_active_learning but elsewhere as well. This leads me to believe there is some undesirable state leaking (pollution) between the different test modules; in this case, it results in the mix of torch and tensorflow types.

#117 should fix this, if CI behaves the same way as my local runs.

I've no idea why this has only recently become a problem, however.

tom-andersson commented 2 months ago

@davidwilby fixed this in #117 :-)

Perhaps `import deepsensor.torch` in test_model was clashing with `import deepsensor.tensorflow` in test_active_learning, causing backends to complain about mixed types; using setUpClass in the test classes now fixes this. I don't fully understand why test leakage would start becoming a problem seemingly "out of nowhere", though. Oh well, it's fixed now.
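The general idea behind a setUpClass-based fix can be sketched with the standard library's `unittest`. This is a hypothetical illustration, not the code from #117: `set_backend` and `active_backend` stand in for whatever selects the backend. Because `setUpClass` runs immediately before each class's tests rather than at collection time, each class re-establishes its backend regardless of what other modules did when they were imported.

```python
import unittest

# Hypothetical global backend setting; not deepsensor's real API.
active_backend = {"name": None}


def set_backend(name):
    """Stand-in for whatever selects the active backend."""
    active_backend["name"] = name


class TestActiveLearning(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs right before this class's tests, after all modules are imported,
        # so a backend chosen elsewhere at import time cannot leak in.
        set_backend("tensorflow")

    def test_uses_tensorflow(self):
        self.assertEqual(active_backend["name"], "tensorflow")


class TestModel(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        set_backend("torch")

    def test_uses_torch(self):
        self.assertEqual(active_backend["name"], "torch")


def run_suite():
    """Run both classes in one session after simulating a stale backend."""
    set_backend("torch")  # e.g. left over from another module's import
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite()
    for case in (TestActiveLearning, TestModel):
        suite.addTests(loader.loadTestsFromTestCase(case))
    result = unittest.TestResult()
    suite.run(result)
    return result


if __name__ == "__main__":
    outcome = run_suite()
    print(outcome.testsRun, outcome.wasSuccessful())  # 2 True
```

The sketch calls an explicit setter rather than repeating an `import` inside `setUpClass`: Python caches modules in `sys.modules`, so importing an already-imported module a second time does not re-run its top-level side effects.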