alopezrivera opened this issue 1 month ago
Hi @alopezrivera, thanks for reporting this bug, as it may be critical. Here are a few follow-up questions to narrow down the problem before we check on our end: does the issue go away if you train with FP-32 instead? This might be due to new precision settings in PyTorch 2.5.

Hey @alopezrivera, I could not reproduce your results on my end. I am running the quickstart notebook with both the previous version and the newer one, and I get the same result.
Output of `from rl4co.utils.utils import show_versions; show_versions()`:
```
INSTALLED VERSIONS
-------------------------------------
rl4co             : 0.5.0
torch             : 2.5.0+cu124
lightning         : 2.2.1
torchrl           : 2024.10.21
tensordict        : 2024.10.21
numpy             : 1.26.4
pytorch_geometric : Not installed
hydra-core        : 1.3.2
omegaconf         : 2.3.0
matplotlib        : 3.8.3
Python            : 3.11.10
Platform          : Linux-6.8.0-47-generic-x86_64-with-glibc2.39
Lightning device  : cuda
```
PS: note that due to several recent updates to PyTorch, there may be some incompatibilities with `torchrl` and `tensordict`. The latter was updated a few hours ago and is causing some issues (see #229) because TorchRL has not been upgraded to the latest version yet (I believe `torchrl` is going to be updated soon, as you can see here).
I installed the nightly versions with the following command:

```bash
pip3 install torchrl-nightly
```
@alopezrivera the most recent TorchRL and TensorDict versions have now been released, and I still cannot reproduce the bug. When you get a chance, please let us know how to reproduce the result!
In that case it is highly likely that this has to do with either numerical precision settings or some operation inside my space vehicle routing environment. I'll look into those two possibilities ASAP and try to construct a minimal environment that reproduces the issue.
Good. You may also try testing it on another system: for example, if the problem is reproducible in Google Colab, then it is likely an important issue in RL4CO; otherwise it might be outside of our control.
So far I was able to reproduce this behavior on the 4 systems below (all on RunPod), which makes me think the environment is the culprit here:

- pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22
- pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22
- pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22
- pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22
> [!IMPORTANT]
> See possible fixes at the end of this message!
I tried manually updating all dependencies, and it turns out that the bug may be reproducible on a 3090: in my case, the loss explodes during epoch 2 of this notebook :thinking:
Here is my current env (note that you may install the dev version of `rl4co` from source):
```
INSTALLED VERSIONS
-------------------------------------
rl4co             : 0.5.1dev0
torch             : 2.5.0+cu124
lightning         : 2.4.0
torchrl           : 0.6.0
tensordict        : 0.6.0
numpy             : 1.26.4
pytorch_geometric : 2.6.1
hydra-core        : 1.3.2
omegaconf         : 2.3.0
matplotlib        : 3.9.2
Python            : 3.11.10
Platform          : Linux-6.8.0-47-generic-x86_64-with-glibc2.39
Lightning device  : cuda
```
So far I have managed to fix the bug (temporarily) in the following ways:
1. Increase precision
If precision is set to "32", then there is no more bug (i.e., pass `precision="32"` to the `RL4COTrainer`; see the sketch after this list). However, this is suboptimal since it increases training time.
2. Install a previous PyTorch version (<2.4.0)
If you install an earlier version, say with `pip install torch==2.3.0`, then training works as expected.
3. Use another GPU
Also suboptimal, and it seems newer GPUs are the ones affected; training in Google Colab did not show this behavior. I strongly suspect the cause is precision issues combined with the SDPA implementation that gets selected on certain devices and precision settings.
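For reference, here is a minimal sketch of fix 1 in a quickstart-style setup (the environment, problem size, and trainer arguments are placeholders, not from the affected notebook):

```python
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.utils.trainer import RL4COTrainer

# Quickstart-style setup (placeholder problem size / epochs)
env = TSPEnv(generator_params=dict(num_loc=20))
model = AttentionModel(env, baseline="rollout")

# Fix 1: force full fp32 precision instead of the default mixed precision.
# Slower, but sidesteps the fp16 SDPA issue described above.
trainer = RL4COTrainer(max_epochs=3, precision="32", devices=1)
trainer.fit(model)
```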
The above are still temporary fixes, and I am not 100% sure why they are needed, since the logic did not change - so the problem should be outside of RL4CO, but we may need to adapt somehow. Most likely it is a precision error; why that is, that's the question. Precision is handled by PyTorch Lightning, so one option is to watch for updates on their side. Another option is to dig deep into SDPA and see whether the problem persists when switching to the manual implementation / the FlashAttention repo.
CC: @cbhua
UPDATE 1
If the `sdpa_kernel` is changed, then the error does not appear. For example, you can try setting the context for the model fit as:

```python
from torch.nn.attention import sdpa_kernel, SDPBackend

# Force the non-fused MATH backend for the whole fit
with sdpa_kernel(backends=[SDPBackend.MATH]):
    trainer.fit(model)
```

This "simple trick" seems to work well for me (though only at times - pretty weird), indicating that the SDPA / precision direction was indeed correct.
UPDATE 2 (narrowed down the issue and possible fix!)
I think we finally found the main culprit. The issue appears to be in PyTorch's `scaled_dot_product_attention` under certain conditions:

- certain `sdpa_kernel`s
- `fp16` precision (which is the default in RL4CO)
- an `attn_mask` is provided

The easiest fix for affected devices is to simply change the `sdpa_fn` at this line:

```diff
- self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention
+ self.sdpa_fn = sdpa_fn if sdpa_fn is not None else scaled_dot_product_attention_simple
```

i.e., replacing the SDPA implementation with our own. This appears to solve the issue without changing too much code!
Minor note: for single-query attention (as in AM), this appears to speed up performance a little, while for multi-query attention (such as multi-start, e.g. in POMO), it seems to be slightly slower.
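For context, a manual SDPA along the following lines is the kind of replacement meant here (a generic sketch of the technique, not necessarily the exact code of `scaled_dot_product_attention_simple` in rl4co):

```python
import math
import torch


def sdpa_manual(q, k, v, attn_mask=None):
    """Plain scaled dot-product attention: softmax(q k^T / sqrt(d)) v,
    computed with explicit ops instead of a fused PyTorch kernel."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        # Same boolean convention as torch's SDPA: True = attend, False = mask out
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)
```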
There can be a few reasons. What I suspect is that there is some wrong conversion in PyTorch between fp32 -> fp16 and vice versa, for instance for the `-inf` case of the mask (which reports `nan` at times for me). It might be related to this PR and this release blog, but I don't have the time/knowledge to go through the C++ code to check. If this is the case, we might want to file an issue directly with PyTorch.
The snippet below should reproduce the issue in your case. Normally, the two implementations should be exactly the same:
```python
import torch
from torch.nn.functional import scaled_dot_product_attention

from rl4co.models.nn.attention import scaled_dot_product_attention_simple

# Make some random data
bs, n, d = 32, 100, 128
q = torch.rand(bs, n, d)
k = torch.rand(bs, n, d)
v = torch.rand(bs, n, d)
attn_mask = torch.rand(bs, n, n) < 0.1  # boolean mask: True = attend

# To float16 on GPU
q = q.half().cuda()
k = k.half().cuda()
v = v.half().cuda()
attn_mask = attn_mask.cuda()

# Run the two implementations under autocast
with torch.amp.autocast("cuda"):
    out_pytorch = scaled_dot_product_attention(q, k, v, attn_mask)
    out_ours = scaled_dot_product_attention_simple(q, k, v, attn_mask)

# If the two outputs are not close, report the maximum difference
if not torch.allclose(out_ours, out_pytorch, atol=1e-3):
    raise ValueError(f"Outputs are not close. Max diff: {torch.max(torch.abs(out_ours - out_pytorch))}")
```
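As a side note on the suspected -inf / nan mechanism: one classic way an all -inf row turns into nan is the max-subtraction step of a numerically stable softmax, since (-inf) - (-inf) = nan. Whether this is exactly what happens inside the fused fp16 kernel is an assumption on my part; the snippet below only illustrates the arithmetic:

```python
import torch

# A fully masked row: every attention score is -inf
row = torch.full((4,), float("-inf"), dtype=torch.float16)

# The stable-softmax trick subtracts the row max, and (-inf) - (-inf) = nan
print(row - row.max())             # tensor([nan, nan, nan, nan], dtype=torch.float16)

# softmax over an all -inf row is therefore nan as well, and propagates onward
print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan], dtype=torch.float16)
```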
@alopezrivera in the new RL4CO version, you may replace the decoder's scaled dot product attention (SDPA) implementation by passing "simple" as a hyperparameter:

```python
policy = AttentionModelPolicy(env_name=env.name, sdpa_fn_decoder="simple")
```
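A brief usage sketch (the environment and problem size are placeholders, assuming the usual quickstart-style API):

```python
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel, AttentionModelPolicy

# Plug the "simple"-SDPA policy into a model as usual
env = TSPEnv(generator_params=dict(num_loc=20))
policy = AttentionModelPolicy(env_name=env.name, sdpa_fn_decoder="simple")
model = AttentionModel(env, policy=policy, baseline="rollout")
```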
Describe the bug
I have observed a considerable decrease in policy performance after the recent PyTorch 2.5.0 update. The decrease in performance replicates when training with A2C, REINFORCE and PPO.
(Plot: before in brown, after in purple. Same environment model, same random seeds.)
To Reproduce
Install RL4CO and other dependencies using the following Conda `environment.yaml`:

Previous result when creating the environment
Approximately 3 days ago this would've installed the following dependencies:

This environment can be replicated with the following `environment.yaml`, where `requirements.txt` must be stored in the same directory as `environment.yaml` and contain:

Current result when creating the environment
As of today it installs the following dependencies, including PyTorch 2.5.0:

Detailed list of dependencies
The following is a detailed list of all the dependencies that differ between the environment created 3 days ago and the current one. I believe PyTorch 2.5.0 is the main culprit here.
System info
NVIDIA L40 system running the `pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22` image
Reason and Possible fixes
No idea as to the reason. A temporary fix could be to lock the PyTorch version required by RL4CO to PyTorch 2.4.1.