nshepperd / jax-guided-diffusion

JAX version of CLIP-guided diffusion scripts
89 stars, 24 forks

Preserving this for historical reasons #5

Open andybak opened 10 months ago

andybak commented 10 months ago

I have a personal interest, as some of my favorite images were generated with this. More generally, though, I think it's important to preserve these models, as they were key stages in the development of generative AI.

The Colab currently fails (https://colab.research.google.com/drive/1Z5kK1WXTkYoMAVN6FqkQg0Fa4bE5BnxG?usp=sharing) with the error "jax is not a supported wheel on this platform".

@SoftologyPro was also unable to fix this for Visions of Chaos (https://softology.pro/tutorials/tensorflow/tensorflow.htm), which is an incredibly valuable repository for AI models in general.
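pip prints "... is not a supported wheel on this platform" when the Python or platform tags encoded in a wheel's filename don't match the interpreter doing the install. For example, the jaxlib wheel used in the setup commands later in this thread is tagged `cp310` / `win_amd64`, so it only installs on CPython 3.10 on 64-bit Windows. The sketch below compares those tags by hand; the function names are hypothetical, it assumes CPython, and it deliberately skips cases (manylinux and macOS aliases, generic `py3` tags) that pip's own tag matching handles.

```python
import sys
import sysconfig


def parse_wheel_tags(filename):
    """Return (python_tag, abi_tag, platform_tag) from a wheel filename.

    Wheel names follow {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl,
    so the tags are always the last three dash-separated fields.
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) < 5:
        raise ValueError(f"malformed wheel filename: {filename}")
    return parts[-3], parts[-2], parts[-1]


def interpreter_tags():
    """Best-guess (python_tag, platform_tag) for this interpreter, assuming CPython."""
    py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    # e.g. "win-amd64" -> "win_amd64", "linux-x86_64" -> "linux_x86_64"
    plat_tag = sysconfig.get_platform().replace("-", "_").replace(".", "_")
    return py_tag, plat_tag


def wheel_matches(filename, py_tag=None, plat_tag=None):
    """Rough check of a wheel's Python and platform tags (the ABI tag is ignored).

    Compound tags like "py2.py3" are split on "."; manylinux/macOS aliases and
    generic "py3" tags are NOT handled, unlike pip's real tag matching.
    """
    default_py, default_plat = interpreter_tags()
    py_tag = py_tag or default_py
    plat_tag = plat_tag or default_plat
    wheel_py, _abi, wheel_plat = parse_wheel_tags(filename)
    py_ok = py_tag in wheel_py.split(".")
    plats = wheel_plat.split(".")
    plat_ok = "any" in plats or plat_tag in plats
    return py_ok and plat_ok


wheel = "jaxlib-0.3.2+cuda11.cudnn82-cp310-none-win_amd64.whl"
print(parse_wheel_tags(wheel))                        # ('cp310', 'none', 'win_amd64')
print(wheel_matches(wheel, "cp310", "win_amd64"))     # True
print(wheel_matches(wheel, "cp311", "win_amd64"))     # False
print(wheel_matches(wheel, "cp310", "linux_x86_64"))  # False
print(interpreter_tags())  # tags for the interpreter running this script
```

Running the last line inside the failing environment shows which Python version and platform a replacement wheel would need to target.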

SoftologyPro commented 10 months ago

I did manage to get it running in Visions of Chaos again, but something has changed since the original release that makes it run very slowly: approximately 40 minutes per image on a 4090. The GPU is maxed out and CPU usage is minimal. There are no warnings or errors pointing to a possible cause, but I assume some Python package upgrade outside our control introduced the problem (https://softologyblog.wordpress.com/2023/10/10/a-plea-to-all-python-developers/).

This also shows that even if you think you have a self-contained notebook with your script inside it, updates outside your control can make the script fail or stop running. Short of trying random package versions with pip, I don't know how we can track down the cause of the slowdown. The two "Huemin" variants of JAX CLIP Guided Diffusion that I have now run very slowly as well.

SoftologyPro commented 10 months ago

These are the pip commands I use to set up the Python environment the scripts run in. This is the same environment that ran at a decent speed when these JAX scripts were originally released.

python -m pip install --upgrade pip
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts wheel==0.38.4
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts pillow==9.2.0
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts opencv-python==4.5.5.62
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts pandas==1.5.0
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts regex==2021.3.17
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts einops==0.4.1
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts ipython==7.23.0
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts requests==2.25.1
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts ftfy==6.0.1
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts braceexpand==0.1.7
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts jaxlib-0.3.2+cuda11.cudnn82-cp310-none-win_amd64.whl
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts jax==0.3.2
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts cbor2==5.4.2.post1
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts dm-haiku==0.0.5
pip uninstall -y pytorch-lightning
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts pytorch-lightning==2.1.0
pip uninstall -y torchmetrics
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts torchmetrics==0.7.2
pip uninstall -y torch
pip uninstall -y torch
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts librosa==0.8.1
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts numpy==1.23.4
pip uninstall -y setuptools
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts setuptools==59.5.0
pip uninstall -y pyparsing
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts pyparsing==2.4.7
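Because every install above uses `--ignore-installed --force-reinstall --no-warn-conflicts`, a later command can reinstall a dependency at a different version than an earlier pin (torchvision, for instance, depends on pillow and numpy), and pip won't report the resulting conflict. Before comparing speeds between environments, it may help to confirm what is actually installed. A minimal sketch using the standard library's `importlib.metadata`; `PINS` is a hand-picked subset of the versions listed above and `check_pins` is a hypothetical helper name.

```python
from importlib import metadata

# Hand-picked subset of the pins from the commands above (extend as needed).
# The jaxlib entry assumes the installed version string matches the wheel filename.
PINS = {
    "jax": "0.3.2",
    "jaxlib": "0.3.2+cuda11.cudnn82",
    "dm-haiku": "0.0.5",
    "numpy": "1.23.4",
    "pillow": "9.2.0",
    "torch": "1.13.1+cu116",
}


def check_pins(pins, get_version=metadata.version):
    """Return {package: (expected, installed)} for every pin that doesn't match.

    `installed` is None when the package isn't installed at all.
    `get_version` can be swapped out, e.g. for testing.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINS)
    for name, (expected, installed) in problems.items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
    if not problems:
        print("All pinned packages match.")
```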

As a test, I removed all version pins except for jax, jaxlib and haiku, and moved torch to v2.0.1. There was no difference in speed, so the latest packages do not seem to help.

python -m pip install --upgrade pip
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts wheel
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts pillow
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts opencv-python
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts pandas
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts regex
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts einops
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts ipython
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts requests
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts ftfy
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts braceexpand
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts jaxlib-0.3.2+cuda11.cudnn82-cp310-none-win_amd64.whl
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts jax==0.3.2
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts cbor2
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts dm-haiku==0.0.5
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts tqdm
pip install --no-cache-dir --ignore-installed --force-reinstall --no-warn-conflicts torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118

An updated requirements.txt listing package versions that have been tested to run at full speed would help.
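From an environment that is known to run at full speed, `pip freeze > requirements.txt` records the exact installed versions. A rough Python equivalent, handy for saving the versions from inside a notebook next to the generated images, is sketched below using only `importlib.metadata`. `freeze_lines` and `write_requirements` are hypothetical names, and unlike `pip freeze` this lists every package as a plain `name==version` pin, so a locally installed wheel such as the jaxlib one appears as a version rather than a file reference.

```python
from importlib import metadata


def freeze_lines(dists=None):
    """Return sorted "name==version" lines for the installed distributions.

    Names are lower-cased. If the same name is found more than once on the
    path (possible when --ignore-installed leaves an older copy behind),
    the last one seen wins.
    """
    if dists is None:
        dists = metadata.distributions()
    pins = {}
    for dist in dists:
        name = dist.metadata.get("Name")
        if name:  # skip broken installs with no Name field
            pins[name.lower()] = dist.version
    return [f"{name}=={version}" for name, version in sorted(pins.items())]


def write_requirements(path="requirements.txt"):
    """Write the current environment's pins to `path`, one per line."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.write("\n".join(freeze_lines()) + "\n")
```

Saving this file alongside a batch of images that rendered at normal speed gives a concrete baseline to diff against once an environment slows down.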

SoftologyPro commented 10 months ago

Updating jax and jaxlib to 0.3.7 (I had a wheel for that version) does not help; it produces new errors:

\lib\site-packages\jax\_src\prng.py", line 477, in threefry_2x32
    except core.InconclusiveDimensionOperation as e:
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: DenseElementsAttr could not be constructed from the given buffer. This may mean that the Python buffer layout does not match that MLIR expected layout and is a bug.