ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)
https://rl4.co
MIT License

[BUG] Considerable decrease in policy performance after PyTorch 2.5.0 update #228

Open alopezrivera opened 1 month ago

alopezrivera commented 1 month ago

Describe the bug

I have observed a considerable decrease in policy performance after the recent PyTorch 2.5.0 update. The decrease in performance replicates when training with A2C, REINFORCE and PPO.

Before: brown. After: purple. Same environment, same model, same random seeds. [image]

To Reproduce

Install RL4CO and other dependencies using the following Conda environment.yaml:

name: rl
channels:
  - conda-forge
  - defaults
dependencies:
  - pip
  - python=3.12.7
  - pip:
      - rl4co
      # data analysis
      - polars
      - pandas
      # data visualization
      - matplotlib
      - seaborn
      # logging
      - tensorboard

Previous result when creating the environment

Approximately 3 days ago this would've installed the following dependencies:

INSTALLED VERSIONS
-------------------------------------
            rl4co : 0.5.0
            torch : 2.4.1+cu121
        lightning : 2.4.0
          torchrl : 0.5.0
       tensordict : 0.5.0
            numpy : 2.1.2
pytorch_geometric : Not installed
       hydra-core : 1.3.2
        omegaconf : 2.3.0
       matplotlib : 3.9.2
           Python : 3.12.7
         Platform : Linux-5.15.0-78-generic-x86_64-with-glibc2.35
 Lightning device : cuda

This environment can be replicated with the following environment.yaml:

name: rl
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.12.7
  - pip
  - pip:
    - -r requirements.txt

where requirements.txt must be stored in the same directory as environment.yaml and contain:

setuptools==75.1.0
wheel==0.44.0
pip==24.2
pytz==2024.2
mpmath==1.3.0
antlr4-python3-runtime==4.9.3
urllib3==2.2.3
tzdata==2024.2
typing_extensions==4.12.2
tqdm==4.66.5
tensorboard-data-server==0.7.2
sympy==1.13.3
smmap==5.0.1
six==1.16.0
setproctitle==1.3.3
PyYAML==6.0.2
python-dotenv==1.0.1
pyparsing==3.2.0
Pygments==2.18.0
psutil==6.0.0
protobuf==5.28.2
propcache==0.2.0
polars==1.9.0
platformdirs==4.3.6
pillow==11.0.0
packaging==24.1
orjson==3.10.7
nvidia-nvtx-cu12==12.1.105
nvidia-nvjitlink-cu12==12.6.77
nvidia-nccl-cu12==2.20.5
nvidia-curand-cu12==10.3.2.106
nvidia-cufft-cu12==11.0.2.54
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cublas-cu12==12.1.3.1
numpy==2.1.2
networkx==3.4.1
multidict==6.1.0
mdurl==0.1.2
MarkupSafe==3.0.1
Markdown==3.7
kiwisolver==1.4.7
idna==3.10
grpcio==1.67.0
fsspec==2024.9.0
frozenlist==1.4.1
fonttools==4.54.1
filelock==3.16.1
einops==0.8.0
cycler==0.12.1
colorlog==6.8.2
cloudpickle==3.1.0
click==8.1.7
charset-normalizer==3.4.0
certifi==2024.8.30
attrs==24.2.0
aiohappyeyeballs==2.4.3
absl-py==2.1.0
yarl==1.15.4
Werkzeug==3.0.4
triton==3.0.0
sentry-sdk==2.17.0
scipy==1.14.1
requests==2.32.3
python-dateutil==2.9.0.post0
pyrootutils==1.0.4
omegaconf==2.3.0
nvidia-cusparse-cu12==12.1.0.106
nvidia-cudnn-cu12==9.1.0.70
markdown-it-py==3.0.0
lightning-utilities==0.11.8
Jinja2==3.1.4
gitdb==4.0.11
docker-pycreds==0.4.0
contourpy==1.3.0
aiosignal==1.3.1
tensorboard==2.18.0
robust-downloader==0.0.2
rich==13.9.2
pandas==2.2.3
nvidia-cusolver-cu12==11.4.5.107
matplotlib==3.9.2
hydra-core==1.3.2
GitPython==3.1.43
aiohttp==3.10.10
wandb==0.18.3
torch==2.4.1
seaborn==0.13.2
hydra-colorlog==1.2.0
torchmetrics==1.4.3
tensordict==0.5.0
torchrl==0.5.0
pytorch-lightning==2.4.0
lightning==2.4.0
rl4co==0.5.0
pyDOE3==1.0.4
statsmodels==0.14.4

Current result when creating the environment

As of today it installs the following dependencies, including PyTorch 2.5.0:

INSTALLED VERSIONS
-------------------------------------
            rl4co : 0.5.0
            torch : 2.5.0+cu124
        lightning : 2.4.0
          torchrl : 0.5.0
       tensordict : 0.5.0
            numpy : 1.26.4
pytorch_geometric : Not installed
       hydra-core : 1.3.2
        omegaconf : 2.3.0
       matplotlib : 3.9.2
           Python : 3.12.7
         Platform : Linux-6.5.0-35-generic-x86_64-with-glibc2.35
 Lightning device : cuda

Detailed list of dependencies

The following is a detailed comparison of all dependencies in the environment created 3 days ago and in the current one, including those whose version did not change. I believe PyTorch 2.5.0 is the main culprit here.

Library | Version in File 1 (3 days ago) | Version in File 2 (current)
PyYAML 6.0.2 6.0.2
Pygments 2.18.0 (missing)
absl-py 2.1.0 2.1.0
aiohappyeyeballs 2.4.3 2.4.3
aiohttp 3.10.10 3.10.10
aiosignal 1.3.1 1.3.1
antlr4-python3-runtime 4.9.3 4.9.3
attrs 24.2.0 24.2.0
certifi 2024.8.30 2024.8.30
charset-normalizer 3.4.0 3.4.0
click 8.1.7 8.1.7
cloudpickle 3.1.0 3.1.0
colorlog 6.8.2 6.8.2
contourpy 1.3.0 1.3.0
cycler 0.12.1 0.12.1
docker-pycreds 0.4.0 0.4.0
einops 0.8.0 0.8.0
filelock 3.16.1 3.16.1
fonttools 4.54.1 4.54.1
frozenlist 1.4.1 1.4.1
fsspec 2024.9.0 2024.10.0
gitdb 4.0.11 4.0.11
GitPython 3.1.43 3.1.43
grpcio 1.67.0 1.67.0
hydra-colorlog 1.2.0 1.2.0
hydra-core 1.3.2 1.3.2
idna 3.10 3.10
Jinja2 3.1.4 3.1.4
kiwisolver 1.4.7 1.4.7
lightning 2.4.0 2.4.0
lightning-utilities 0.11.8 0.11.8
Markdown 3.7 3.7
markdown-it-py 3.0.0 3.0.0
MarkupSafe 3.0.1 3.0.2
matplotlib 3.9.2 3.9.2
mdurl 0.1.2 0.1.2
mpmath 1.3.0 1.3.0
multidict 6.1.0 6.1.0
networkx 3.4.1 3.4.1
numpy 2.1.2 1.26.4
nvidia-cublas-cu12 12.1.3.1 12.4.5.8
nvidia-cuda-cupti-cu12 12.1.105 12.4.127
nvidia-cuda-nvrtc-cu12 12.1.105 12.4.127
nvidia-cuda-runtime-cu12 12.1.105 12.4.127
nvidia-cudnn-cu12 9.1.0.70 9.1.0.70
nvidia-cufft-cu12 11.0.2.54 11.2.1.3
nvidia-curand-cu12 10.3.2.106 10.3.5.147
nvidia-cusolver-cu12 11.4.5.107 11.6.1.9
nvidia-cusparse-cu12 12.1.0.106 12.3.1.170
nvidia-nccl-cu12 2.20.5 2.21.5
nvidia-nvjitlink-cu12 12.6.77 12.4.127
nvidia-nvtx-cu12 12.1.105 12.4.127
omegaconf 2.3.0 2.3.0
orjson 3.10.7 3.10.9
packaging 24.1 24.1
pandas 2.2.3 2.2.3
patsy (missing) 0.5.6
pillow 11.0.0 11.0.0
platformdirs 4.3.6 4.3.6
polars 1.9.0 1.10.0
propcache 0.2.0 0.2.0
protobuf 5.28.2 5.28.2
psutil 6.0.0 6.1.0
pyDOE3 1.0.4 1.0.4
pyparsing 3.2.0 3.2.0
pyrootutils 1.0.4 1.0.4
python-dateutil 2.9.0.post0 2.9.0.post0
python-dotenv 1.0.1 1.0.1
pytorch-lightning 2.4.0 2.4.0
pytz 2024.2 2024.2
requests 2.32.3 2.32.3
rich 13.9.2 13.9.2
rl4co 0.5.0 0.5.0
robust-downloader 0.0.2 0.0.2
scipy 1.14.1 1.14.1
seaborn 0.13.2 0.13.2
sentry-sdk 2.17.0 2.17.0
setproctitle 1.3.3 1.3.3
setuptools 75.1.0 75.1.0
six 1.16.0 1.16.0
smmap 5.0.1 5.0.1
statsmodels 0.14.4 0.14.4
sympy 1.13.3 1.13.1
tensorboard 2.18.0 2.18.0
tensorboard-data-server 0.7.2 0.7.2
tensordict 0.5.0 0.5.0
torch 2.4.1 2.5.0
torchmetrics 1.4.3 1.5.0
torchrl 0.5.0 0.5.0
tqdm 4.66.5 4.66.5
triton 3.0.0 3.1.0
typing_extensions 4.12.2 4.12.2
tzdata 2024.2 2024.2
urllib3 2.2.3 2.2.3
wandb 0.18.3 0.18.5
Werkzeug 3.0.4 3.0.4
wheel 0.44.0 0.44.0
yarl 1.15.4 1.15.5
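Since the table above lists every package, including the unchanged ones, a small script can help isolate the rows that actually differ. The following is a minimal sketch, assuming both environments are available as `pip freeze` style `name==version` lines; the helper names `parse_pins` and `changed_pins` are hypothetical, and the sample data is only a handful of rows copied from the table.

```python
def parse_pins(text):
    """Parse 'name==version' lines (pip freeze style) into a dict."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, version = line.split("==", 1)
        pins[name.lower()] = version
    return pins


def changed_pins(old, new):
    """Return {name: (old_version, new_version)} for packages that differ.

    A package present on only one side is reported with None on the other.
    """
    diff = {}
    for name in sorted(set(old) | set(new)):
        before, after = old.get(name), new.get(name)
        if before != after:
            diff[name] = (before, after)
    return diff


# A few rows copied from the comparison table above.
old_env = parse_pins("""
torch==2.4.1
numpy==2.1.2
triton==3.0.0
einops==0.8.0
Pygments==2.18.0
""")
new_env = parse_pins("""
torch==2.5.0
numpy==1.26.4
triton==3.1.0
einops==0.8.0
patsy==0.5.6
""")

for name, (before, after) in changed_pins(old_env, new_env).items():
    print(f"{name}: {before} -> {after}")
```

Unchanged packages such as einops are dropped from the output, which makes it easier to see at a glance that torch, triton, numpy and the CUDA libraries all moved together.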

System info

NVIDIA L40 system pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22

Reason and Possible fixes

No idea as to the reason. A temporary fix could be to lock the PyTorch version required by RL4CO to PyTorch 2.4.1.


fedebotu commented 1 month ago

Hi @alopezrivera, thanks for reporting this bug, as it may be critical. Here are a few follow-up questions to narrow down the problem before checking on our end:

  1. Which environment and model did you use?
  2. Does the bug also appear if you train in FP32 instead? It might be due to new precision settings in PyTorch 2.5.

fedebotu commented 1 month ago

Hey @alopezrivera, I could not reproduce your results on my end. I am running the quickstart notebook with both the previous version and the newer one, and get the same result.

Output of from rl4co.utils.utils import show_versions; show_versions() :

INSTALLED VERSIONS
-------------------------------------
            rl4co : 0.5.0
            torch : 2.5.0+cu124
        lightning : 2.2.1
          torchrl : 2024.10.21
       tensordict : 2024.10.21
            numpy : 1.26.4
pytorch_geometric : Not installed
       hydra-core : 1.3.2
        omegaconf : 2.3.0
       matplotlib : 3.8.3
           Python : 3.11.10
         Platform : Linux-6.8.0-47-generic-x86_64-with-glibc2.39
 Lightning device : cuda

PS: note that due to several recent updates to PyTorch, there may be some incompatibilities with torchrl and tensordict. The latter was updated a few hours ago and is causing some issues (see #229) because TorchRL has not been upgraded to the latest version yet (I believe torchrl is going to be updated soon, as you can see here).

I installed the nightly versions with the following command:

pip3 install torchrl-nightly

fedebotu commented 1 month ago

@alopezrivera the most recent TorchRL and Tensordict versions have been released, and I still cannot reproduce the bug. When you are available, please let us know how to reproduce the result!

alopezrivera commented 1 month ago

In that case it is highly likely that this had to do with either numerical precision settings or some operation conducted inside my space vehicle routing environment. I'll look into those two possibilities asap and try to construct a minimal environment that reproduces the issue.

fedebotu commented 1 month ago

Good. You may also try testing it on another system. For example, if the problem is reproducible in Google Colab, then it is likely an important issue in RL4CO. Otherwise, it might be outside of our control.

alopezrivera commented 1 month ago

So far I was able to reproduce this behavior on the 4 systems below (all on RunPod), which makes me think the environment is the culprit here

fedebotu commented 1 month ago

[!IMPORTANT] See possible fixes at the end of this message!

I tried manually updating all dependencies, and it turns out that, on a 3090, the bug may be reproducible. In my case, the loss explodes during epoch 2 on this notebook :thinking:

Here is my current env (note that you may install the dev version of rl4co from source)

INSTALLED VERSIONS
-------------------------------------
            rl4co : 0.5.1dev0
            torch : 2.5.0+cu124
        lightning : 2.4.0
          torchrl : 0.6.0
       tensordict : 0.6.0
            numpy : 1.26.4
pytorch_geometric : 2.6.1
       hydra-core : 1.3.2
        omegaconf : 2.3.0
       matplotlib : 3.9.2
           Python : 3.11.10
         Platform : Linux-6.8.0-47-generic-x86_64-with-glibc2.39
 Lightning device : cuda

Possible fixes

So far I have managed to fix the bug (temporarily) in the following ways:

1. Increase precision

If precision is increased to "32", the bug no longer appears (i.e., pass precision="32" to the RL4COTrainer). However, this is suboptimal since it increases training time.

2. Install a previous PyTorch version <2.4.0

If you install, say, with pip install torch==2.3.0, then the training works as expected.

3. Use another GPU

Also suboptimal, as it seems that newer GPUs are affected. Training in Google Colab did not reproduce the problem. I strongly suspect that it is due to how SDPA is implemented on certain devices under certain precision settings.


The above are still temporary fixes, and I am not 100% sure why the problem occurs, since the logic in RL4CO did not change. The cause should therefore lie outside of RL4CO, but we may still need to adapt. Most likely it is a precision error; why that is remains the open question. Precision is handled by PyTorch Lightning, so one option is to check for updates on their side. Another option is to dig into the SDPA and check whether the problem persists after switching to the manual implementation or the FlashAttention repo.

CC: @cbhua


UPDATE 1

If the sdpa_kernel is changed, then the error does not appear. For example, you can try setting the context for the model fit as follows:

from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(backends=[SDPBackend.MATH]):
    trainer.fit(model)

This "simple trick" seems to work well for me (though only at times, which is pretty weird), suggesting that looking at SDPA and precision settings was the right direction.


UPDATE 2 (narrowed down the issue and possible fix!)

I think we finally found the main culprit. The issue appears to be in the scaled_dot_product_attention of PyTorch under certain conditions:

  1. With certain sdpa_kernels
  2. At fp16 precision (which is the default in RL4CO)
  3. When an attn_mask is provided
  4. PyTorch version >= 2.4.0
  5. Not all devices are affected
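As a quick way to check condition 4 on a given machine, the sketch below compares the installed torch version against 2.4.0. It is only a convenience check under the assumption that the version range reported in this thread is accurate; the helper names `parse_release` and `in_reported_range` are hypothetical, and the other conditions (kernel, precision, mask, device) are not tested here.

```python
import re
from importlib import metadata


def parse_release(version):
    """Extract (major, minor, patch) from strings such as '2.5.0+cu124'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def in_reported_range(version, first_affected=(2, 4, 0)):
    """True if `version` is at or above the first release reported as affected."""
    return parse_release(version) >= first_affected


try:
    installed = metadata.version("torch")
except metadata.PackageNotFoundError:
    installed = None  # torch is not installed in this environment

if installed is not None and in_reported_range(installed):
    print(f"torch {installed} is in the version range reported in this thread; "
          "consider precision='32' or a different SDPA implementation.")
```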

The easiest fix for affected devices is to simply change the sdpa_fn at this line:

-  self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention
+  self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention_simple

i.e. replacing the SDPA implementation with our own. This appears to solve the issue without changing too much code!

Minor note: for single-query attention (as in AM), this appears to be slightly faster, while for multi-query attention (such as multi-start decoding, e.g., in POMO), it seems to be slightly slower.
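For readers who have not seen a hand-written SDPA, the following is a minimal sketch of what such a replacement can look like. It is not the code of RL4CO's scaled_dot_product_attention_simple; the name `manual_sdpa` is hypothetical, and the boolean mask convention (True means the position takes part in attention) is assumed to match torch.nn.functional.scaled_dot_product_attention. The comparison runs on CPU in fp32, where the two are expected to agree.

```python
import math

import torch
import torch.nn.functional as F


def manual_sdpa(q, k, v, attn_mask=None):
    """Hypothetical reference SDPA: softmax(q @ k^T / sqrt(d) + mask) @ v."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # True means "attend"; masked-out positions get -inf before softmax.
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            # A float mask is added to the scores directly.
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
bs, n, d = 4, 10, 16
q, k, v = (torch.rand(bs, n, d) for _ in range(3))
attn_mask = torch.rand(bs, n, n) < 0.5
# Keep the diagonal unmasked so that no row is fully masked.
attn_mask |= torch.eye(n, dtype=torch.bool)

out_manual = manual_sdpa(q, k, v, attn_mask)
out_builtin = F.scaled_dot_product_attention(q, k, v, attn_mask)
print(torch.allclose(out_manual, out_builtin, atol=1e-5))
```

Writing the computation out this way keeps every step in ordinary tensor ops, so the result does not depend on which fused attention kernel PyTorch selects.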

Reasons why this happens

There can be a few. What I suspect is that there is an incorrect conversion in PyTorch between fp32 and fp16, for instance in the -inf case of the mask (which at times produces nan for me). It might be related to this PR and this release blog, but I don't have the time/knowledge to go through the C code to check it. If this is the case, we might want to file an issue directly with PyTorch.
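To illustrate how a -inf mask can produce nan at all, here is a small pure-Python sketch of a masked softmax. It shows one well-known mechanism (a row in which every position is masked) and is not a confirmed diagnosis of the behavior in this issue; the function name `masked_softmax` is hypothetical.

```python
import math


def masked_softmax(scores, mask):
    """Softmax over `scores`, where mask[i] == False sets score i to -inf."""
    masked = [s if keep else float("-inf") for s, keep in zip(scores, mask)]
    row_max = max(masked)
    # Subtracting the row maximum is the usual stabilization step.
    # If every entry is -inf, this computes (-inf) - (-inf), which is nan.
    exps = [math.exp(m - row_max) for m in masked]
    total = sum(exps)
    return [e / total for e in exps]


partly_masked = masked_softmax([0.2, 1.5, -0.3], [True, False, True])
fully_masked = masked_softmax([0.2, 1.5, -0.3], [False, False, False])

print(partly_masked)  # finite weights, zero at the masked position
print(fully_masked)   # every entry is nan
```

Whether a given SDPA kernel returns nan or some other value for such rows is an implementation detail, which is one reason why two implementations can disagree when a mask is provided.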

Reproducing the problem

The snippet below should reproduce the issue in your case. Normally, the two implementations should produce the same output:

import torch
from torch.nn.functional import scaled_dot_product_attention
from rl4co.models.nn.attention import scaled_dot_product_attention_simple

# Make some random data
bs, n, d = 32, 100, 128
q = torch.rand(bs, n, d)
k = torch.rand(bs, n, d)
v = torch.rand(bs, n, d)
attn_mask = torch.rand(bs, n, n) < 0.1

# to float16
q = q.half().cuda()
k = k.half().cuda()
v = v.half().cuda()
attn_mask = attn_mask.cuda()

# Run the two implementations
with torch.amp.autocast("cuda"):
    out_pytorch = scaled_dot_product_attention(q, k, v, attn_mask)
    out_ours = scaled_dot_product_attention_simple(q, k, v, attn_mask)

# If the two outputs are not close, print the maximum difference
if not torch.allclose(out_ours, out_pytorch, atol=1e-3):
    raise ValueError(f"Outputs are not close. Max diff: {torch.max(torch.abs(out_ours - out_pytorch))}")

fedebotu commented 4 weeks ago

@alopezrivera in the new RL4CO version, you may replace the decoder's scaled dot product attention (SDPA) implementation by passing "simple" as the sdpa_fn_decoder hyperparameter:

policy = AttentionModelPolicy(env_name=env.name, sdpa_fn_decoder="simple")