Closed Chrispresso closed 1 month ago
Based on this tutorial, if you use PyTorch Fabric for distributed training, it will fail during the backward pass when using more than one GPU.
Tested: PyTorch Lightning with multiple GPUs + DINO works; PyTorch Fabric with a single GPU + DINO works; PyTorch Fabric with multiple GPUs + DINO fails.
Repro:
Below is the error output when switching devices from 1 to 2:
Hi @Chrispresso, thank you for reporting the issue! Could you please add the error output to the description?
@philippmwirth just updated the description with the error and traceback. The main error, though, is this:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 11; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
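For context, this is autograd's version-counter check, and the same class of failure can be reproduced in a few lines (a minimal sketch, independent of the actual repro):

import torch

x = torch.ones(3, requires_grad=True)
y = x * 2
z = (y * x).sum()  # the multiply saves y for its backward pass
y.add_(1)          # the in-place update bumps y's version counter
z.backward()       # RuntimeError: ... modified by an inplace operation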
Thanks for providing the full error trace! To me it looks like the error happens in torchvision's resnet:
File "<user>/.venv/lib/python3.11/site-packages/torchvision/models/resnet.py", line 97, in forward
out = self.bn2(out)
File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "<user>/.venv/lib/python3.11/site-packages/torch/nn/modules/batchnorm.py", line 175, in forward
return F.batch_norm(
File "<user>/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2509, in batch_norm
return torch.batch_norm(
Can you try a minimal example using only torchvision resnet (no Lightly prediction or projection heads) and with a dummy loss function instead of DINO?
Not sure if it's related, but it seems there's a small bug in your code: you cancel the last-layer gradients twice.
model.on_after_backward()
# We only cancel gradients of student head.
model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
Accidentally cancelling the gradients twice wasn't the issue. I changed that and still see the problem, which makes sense, since the cancellation only removes the gradient update for the last layer.
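For reference, a hedged sketch of what the cancellation amounts to (paraphrasing the behavior of Lightly's DINOProjectionHead; the real implementation may differ in details):

def cancel_last_layer_gradients(head, current_epoch: int) -> None:
    # While the last layer is still frozen, drop its gradients so the
    # optimizer step leaves it untouched. Calling this twice per step
    # is redundant but should be harmless.
    if current_epoch >= head.freeze_last_layer:
        return
    for param in head.last_layer.parameters():
        param.grad = None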
I've swapped the code to use just resnet18 and don't see a problem with 2 devices. The following code runs without a problem:
from lightning.fabric import Fabric
import torch
import torchvision
from torch import nn
import torch.nn.functional as F

fabric = Fabric(accelerator='cuda', num_nodes=1, devices=2)
fabric.launch()
torch.autograd.set_detect_anomaly(True)

input_dim = 512
model = torchvision.models.resnet18()
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop((96, 96)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: torch.cat([x, x, x], 1)),
])
dataset = torchvision.datasets.VOCDetection(
    "./data",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)
dataloader = fabric.setup_dataloaders(dataloader)

def criterion(yhat: torch.Tensor, _ignore):
    # Dummy loss instead of DINO: regress every output towards ones.
    ones = torch.ones(yhat.shape, device=yhat.device)
    return F.mse_loss(yhat, ones)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)

epochs = 10
print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    for batch in dataloader:
        X = batch[0]
        fabric.barrier()  # synchronize processes before the forward pass
        pred = model(X)
        loss = criterion(pred, X)
        total_loss += loss.detach()
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
If I swap the model to this:
import copy

import lightning as L
import torch
import torchvision
from torch import nn

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad

class DINO(L.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        backbone = nn.Sequential(*list(resnet.children())[:-1])
        input_dim = 512
        # instead of a resnet you can also use a vision transformer backbone as in the
        # original paper (you might have to reduce the batch size in this case):
        backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
        input_dim = backbone.embed_dim
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)
        self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

    def forward(self, x, teacher: bool = False):
        # route to the teacher networks when teacher=True, else to the student
        if teacher:
            y = self.teacher_backbone(x).flatten(start_dim=1)
            z = self.teacher_head(y)
        else:
            y = self.student_backbone(x).flatten(start_dim=1)
            z = self.student_head(y)
        return z

    def on_after_backward(self):
        self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)
which is the same setup with the backbone swapped from resnet18 to dino_vits16, then it works. So it seems the issue is a combination of using resnet18 with Fabric through DINO.
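For completeness, a hedged sketch of the Fabric training step driving this module (reconstructed from the Lightly DINO tutorial pattern; the multi-crop views unpacking and the update_momentum usage are assumptions and may not match the exact repro):

from lightly.models.utils import update_momentum

# model, optimizer, dataloader, and epochs are assumed to come from the
# fabric.setup(...) boilerplate above, with a multi-crop DINO transform
# that yields a list of views per sample.
for epoch in range(epochs):
    for views, _, _ in dataloader:
        # EMA update of the teacher from the student.
        update_momentum(model.student_backbone, model.teacher_backbone, m=0.99)
        update_momentum(model.student_head, model.teacher_head, m=0.99)
        # Several forward passes per step: the two global views go through
        # the teacher, every view goes through the student.
        teacher_out = [model(view, teacher=True) for view in views[:2]]
        student_out = [model(view) for view in views]
        loss = model.criterion(teacher_out, student_out, epoch=epoch)
        fabric.backward(loss)
        model.on_after_backward()  # cancel the frozen last-layer gradients
        optimizer.step()
        optimizer.zero_grad()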
Curious. It could be due to the combination of multiple forward passes, batch norm, and distributed Fabric. Did you already search for similar issues in https://github.com/Lightning-AI/pytorch-lightning/issues?
There wasn't anything there from what I could tell. I did track down this comment. I figured maybe the internal state of batch norm is registered as a buffer (similar to DINOLoss). Took a shot and changed the code below:
from lightning.fabric.strategies.ddp import DDPStrategy

# Extra DDPStrategy kwargs are forwarded to torch's DistributedDataParallel.
fabric = Fabric(
    accelerator='cuda',
    num_nodes=1,
    devices=2,
    strategy=DDPStrategy(broadcast_buffers=False),
)
This now works for training with a resnet18 backbone across multiple devices. The issue seems to be that broadcast_buffers modifies the buffers in place, so you only see the error when you make multiple forward calls before the backward pass. By explicitly using a DDPStrategy with broadcast_buffers=False, I can get around the problem, so this can probably be closed.
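As a sanity check on the in-place hypothesis, a small sketch (using PyTorch's internal _version counter, the same counter the error message reports) shows that batch norm buffers are mutated on every training-mode forward; DDP's buffer broadcast performs a similar in-place copy at the start of each forward:

import torch
from torch import nn

bn = nn.BatchNorm1d(512)
x = torch.randn(16, 512)

print(bn.running_mean._version)  # 0: no in-place updates yet
bn(x)                            # training-mode forward updates running stats in place
print(bn.running_mean._version)  # incremented by the forward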
Thank you for figuring out the workaround! I will close this issue now. Feel free to reopen if anything else comes up.