EmptyJackson / policy-guided-diffusion

Official implementation of the RLC 2024 paper "Policy-Guided Diffusion"
MIT License

module 'scipy.linalg' has no attribute 'tril' #4

Closed (closed by daihuiao 5 months ago)

daihuiao commented 5 months ago

I think I ran into an environment installation problem. Even when using the Docker environment, I still get an error. The error "module 'scipy.linalg' has no attribute 'tril'" also appeared in a manually configured conda environment. The detailed log is as follows:

❯ ./run_docker.sh 0 python3.9 train_diffusion.py --log --wandb_project diff --wandb_team flair --dataset_name walker2d-medium-v2
cat: ./docker/wandb_key: No such file or directory
Already up to date.
Launching container pgd_0 on GPU 0

==========
== CUDA ==
==========

CUDA Version 12.1.0

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License. By pulling and using the container, you accept the terms and conditions of this license: https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.


DEPRECATION NOTICE!


THIS IMAGE IS DEPRECATED and is scheduled for DELETION. https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md

/home/duser/.local/lib/python3.9/site-packages/chex/_src/pytypes.py:54: DeprecationWarning: jax.random.KeyArray is deprecated. Use jax.Array for annotations, and jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of typed prng keys.
  PRNGKey = jax.random.KeyArray
Traceback (most recent call last):
  File "/home/duser/policy-guided-diffusion/train_diffusion.py", line 11, in <module>
    from diffusion.diffusion import (
  File "/home/duser/policy-guided-diffusion/diffusion/diffusion.py", line 1, in <module>
    import optax
  File "/home/duser/.local/lib/python3.9/site-packages/optax/__init__.py", line 18, in <module>
    from optax._src.alias import adabelief
  File "/home/duser/.local/lib/python3.9/site-packages/optax/_src/alias.py", line 25, in <module>
    from optax._src import factorized
  File "/home/duser/.local/lib/python3.9/site-packages/optax/_src/factorized.py", line 27, in <module>
    from optax._src import utils
  File "/home/duser/.local/lib/python3.9/site-packages/optax/_src/utils.py", line 22, in <module>
    import jax.scipy.stats.norm as multivariate_normal
  File "/home/duser/.local/lib/python3.9/site-packages/jax/scipy/stats/__init__.py", line 40, in <module>
    from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
  File "/home/duser/.local/lib/python3.9/site-packages/jax/_src/scipy/stats/kde.py", line 26, in <module>
    from jax.scipy import linalg, special
  File "/home/duser/.local/lib/python3.9/site-packages/jax/scipy/linalg.py", line 18, in <module>
    from jax._src.scipy.linalg import (
  File "/home/duser/.local/lib/python3.9/site-packages/jax/_src/scipy/linalg.py", line 408, in <module>
    @_wraps(scipy.linalg.tril)
AttributeError: module 'scipy.linalg' has no attribute 'tril'

EmptyJackson commented 5 months ago

Thanks for raising this! It's a common JAX issue caused by a SciPy update. I've pinned the version in the requirements, so it should be resolved now!
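
For anyone hitting the same traceback, a quick way to confirm the mismatch before reinstalling is to check whether the installed SciPy still exposes the attribute that the installed JAX tries to wrap. The snippet below is only a rough diagnostic sketch, not part of this repository: the helper names (missing_attributes, installed_version, diagnose) are made up for illustration, and it checks only `tril` because that is the attribute named in the log above.

```python
import importlib
from importlib import metadata


def missing_attributes(module, names):
    """Return the entries of `names` that `module` does not define."""
    return [name for name in names if not hasattr(module, name)]


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def diagnose(names=("tril",)):
    """Print the installed versions and whether scipy.linalg lacks `names`."""
    try:
        linalg = importlib.import_module("scipy.linalg")
    except ImportError:
        print("scipy is not installed in this environment")
        return
    print("scipy version:", installed_version("scipy"))
    print("jax version:", installed_version("jax"))
    missing = missing_attributes(linalg, names)
    if missing:
        # Same condition that raises the AttributeError in the traceback.
        print("scipy.linalg is missing:", ", ".join(missing))
    else:
        print("scipy.linalg provides:", ", ".join(names))


if __name__ == "__main__":
    diagnose()
```

If the script reports `tril` as missing, the installed SciPy is newer than the installed JAX expects, and reinstalling from the updated requirements (or otherwise aligning the two versions) should clear the import error. The check uses hasattr rather than a hard-coded version number, so it makes no assumption about which SciPy release dropped the function.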