gfxdisp / mdf

Multi-scale discriminator feature-wise loss function
BSD 3-Clause "New" or "Revised" License

High memory usage problem #9

Closed delyan-boychev closed 11 months ago

delyan-boychev commented 1 year ago
I have already set requires_grad to False on the model parameters and used torch.no_grad() for the target (for instance, y). However, the gradients with respect to the input still have to be computed through every layer by applying the chain rule. So for this task we run 8 backpropagations at a time, one for each of 8 different models, which creates 8 deep copies of the input. I am interested in how you measure the inference and backpropagation time and the memory overhead.
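To make the question concrete, here is a minimal sketch of how such a measurement could be set up in PyTorch. It is not the MDF implementation: the eight small convolutional networks, the L1 feature distance, and the helper names `make_stand_in_discriminators`, `feature_loss`, and `measure` are all hypothetical stand-ins, and the input size is arbitrary. Peak memory is only reported when a CUDA device is available.

```python
import time

import torch
import torch.nn as nn


def make_stand_in_discriminators(n=8):
    # Hypothetical stand-ins for the 8 models; NOT the MDF discriminators.
    nets = [
        nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 8, 3, padding=1),
        )
        for _ in range(n)
    ]
    for net in nets:
        net.eval()
        for p in net.parameters():
            p.requires_grad_(False)  # freeze weights: no parameter gradients
    return nets


def feature_loss(nets, x, y):
    # Assumed loss: mean absolute feature difference, summed over all models.
    loss = x.new_zeros(())
    for net in nets:
        with torch.no_grad():  # the target branch stores no autograd graph
            target_features = net(y)
        loss = loss + (net(x) - target_features).abs().mean()
    return loss


def measure(nets, x, y, device):
    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize(device)  # do not time earlier queued kernels
    start = time.perf_counter()
    loss = feature_loss(nets, x, y)
    loss.backward()  # gradients flow only to x, through all 8 models
    if use_cuda:
        torch.cuda.synchronize(device)  # wait for the GPU before stopping the clock
    elapsed = time.perf_counter() - start
    peak = torch.cuda.max_memory_allocated(device) if use_cuda else None
    return loss.item(), elapsed, peak


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nets = [net.to(device) for net in make_stand_in_discriminators()]
x = torch.rand(1, 3, 32, 32, device=device, requires_grad=True)
y = torch.rand(1, 3, 32, 32, device=device)

loss_value, seconds, peak_bytes = measure(nets, x, y, device)
print(f"loss={loss_value:.4f}  forward+backward={seconds:.4f}s  peak_bytes={peak_bytes}")
```

With this setup the activations of the x branch are kept for every model until backward() runs, which is where most of the memory would be expected to go. The numbers depend heavily on the real architectures and input resolution, so the sketch only shows the measurement pattern.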

Best regards, Delyan

Originally posted by @delyan-boychev in https://github.com/gfxdisp/mdf/issues/8#issuecomment-1676796314