ROCm / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

Is ROCm apex.amp deprecated & behavior mismatch vs NVIDIA APEX #118

fxmarty opened this issue 11 months ago (status: Open)

fxmarty commented 11 months ago

Hi, I am wondering whether ROCm apex.amp is deprecated. NVIDIA APEX has deprecation warnings that are not present in this repo: https://github.com/NVIDIA/apex/pull/1506/files

Moreover, I noticed that the following code

import torch
import torch.nn as nn
from apex import amp

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 4)

    def forward(self, attn_probs, value_states):
        attn_output = torch.bmm(attn_probs, value_states)
        return attn_output

from torch.optim import AdamW

model = MyModule().to("cuda")
optimizer = AdamW(model.parameters())

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

attn_probs = torch.rand(4, 16, 16).to("cuda")
value_states = torch.rand(4, 16, 2).to(torch.float16).to("cuda")
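# Note the dtype mismatch: attn_probs is float32 while value_states is float16.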

attn_output = model(attn_probs, value_states)

runs fine with NVIDIA APEX but fails on ROCm APEX with the following traceback:

Traceback (most recent call last):
  File "run_bmm.py", line 26, in <module>
    attn_output = model(attn_probs, value_states)
  File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "run_bmm.py", line 11, in forward
    attn_output = torch.bmm(attn_probs, value_states)
RuntimeError: expected scalar type Half but found Float

However, using torch.cuda.amp.autocast instead works fine on both ROCm and CUDA devices (with torch 2.0.1).
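For reference, a minimal sketch of the autocast variant, using the same module and inputs as the reproduction above:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 4)

    def forward(self, attn_probs, value_states):
        return torch.bmm(attn_probs, value_states)

model = MyModule().to("cuda")

attn_probs = torch.rand(4, 16, 16).to("cuda")
value_states = torch.rand(4, 16, 2).to(torch.float16).to("cuda")

# Under autocast, torch.bmm runs in half precision, so the mixed
# float32/float16 inputs are cast to a common dtype and no error is raised.
with torch.cuda.amp.autocast():
    attn_output = model(attn_probs, value_states)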

Thank you!

pruthvistony commented 11 months ago

@fxmarty, I believe the problem could be due to a missing fix in the Adam optimizer handling in ROCm apex. Checking on it; will get back.