gfxdisp / mdf

Multi-scale discriminator feature-wise loss function
BSD 3-Clause "New" or "Revised" License
102 stars 8 forks

High memory consumption, inference and backpropagation time #8

Closed: delyan-boychev closed this issue 11 months ago

delyan-boychev commented 11 months ago

While testing the MDF loss, I found that it uses much more memory than reported in the paper, even for a single image. What might be the issue?

import torch

from mdfloss import MDFLoss  # loss implementation from this repository

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target = torch.randn(1, 3, 256, 256).to(device).requires_grad_(False)
input = torch.randn(1, 3, 256, 256).to(device).requires_grad_(True)
mdf = MDFLoss("./Ds_Denoising.pth", True)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Memory reserved by a single forward/backward pass
a = torch.cuda.memory_reserved(0)
t = mdf(input, target)
t.backward()
b = torch.cuda.memory_reserved(0)

# Average time over 100 forward/backward passes
start.record()
for i in range(100):
    t = mdf(input, target)
    t.backward()
end.record()

# Wait for everything to finish running
torch.cuda.synchronize()

print(str(start.elapsed_time(end)/100) + "ms")
print(str((b-a)/1e6) + "MB")

The output of the code above is:

236.59349609375ms
1369.440256MB

Torch version: 2.0.1+cu117
Python version: 3.10.6
OS: Ubuntu 22.04.2 LTS x86_64
Kernel: 5.15.0-76-generic
CPU: AMD Ryzen 3 3100 (8) @ 3.600GHz
GPU: NVIDIA GeForce GTX 1050 Ti

aamir-mustafa commented 11 months ago

Hi. Thanks for your interest in our work. As you may know, the pre-trained discriminators are meant to be used only as task-specific feature extractors. At training time, gradients should be computed only for the model you are actually optimizing, so you may want to set requires_grad to False for the parameters of all the pre-trained discriminators.

Please try the following in mdfloss.py:

for param in self.Ds.parameters():
    param.requires_grad = False
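
If you prefer not to edit mdfloss.py, the same freezing can be done from the calling side. This is only a sketch: it assumes the loaded discriminators are exposed through the Ds attribute with the usual nn.Module interface, as in the snippet above.

import torch
from mdfloss import MDFLoss

# Sketch: freeze the pre-trained discriminators without editing mdfloss.py.
# Assumes MDFLoss exposes them via the Ds attribute used above.
mdf = MDFLoss("./Ds_Denoising.pth", True)
for param in mdf.Ds.parameters():
    param.requires_grad = False
mdf.Ds.eval()  # keep any norm/dropout layers in inference mode as well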

Hope this answers your question.

Thanks again.
Best,
Aamir

delyan-boychev commented 11 months ago

I have already set requires_grad to False for the discriminator parameters and wrapped the target (the y instance) in torch.no_grad(). The main point, though, is that the backward pass still has to compute gradients with respect to the input, and therefore with respect to every intermediate layer, by the chain rule. So for this task we have 8 backpropagations at a time, through 8 different models, and 8 deep copies of the input are created. I am interested in how you measured the inference and backpropagation time and the memory overhead.
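
For reference, this is the kind of measurement I have in mind. It is only a sketch that reuses the mdf, input and target objects from my script above and reports peak allocated memory rather than reserved memory:

# Sketch: timing and peak memory of one forward/backward pass, reusing
# mdf, input and target from the script in my first comment.
torch.cuda.reset_peak_memory_stats(device)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
t = mdf(input, target)   # forward through all pre-trained discriminators
t.backward()             # gradients w.r.t. the input only
end.record()
torch.cuda.synchronize() # make sure both events have finished

print(start.elapsed_time(end), "ms")
print(torch.cuda.max_memory_allocated(device) / 1e6, "MB peak allocated")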

Best regards, Delyan