decile-team / cords

Reduce end to end training time from days to hours (or hours to minutes), and energy requirements/costs by an order of magnitude using coresets and data selection.
https://cords.readthedocs.io/en/latest/
MIT License
316 stars 53 forks

Logistic Regression support for Gradmatch #32

Closed nlokeshiisc closed 3 years ago

nlokeshiisc commented 3 years ago

The logistic regression model throws an error during backpropagation. A possible fix is to set freeze=False in the forward function of utils/models/logreg_net.py.

krishnatejakk commented 3 years ago

@nlokeshiisc Hello Lokesh, the freeze option is already set to False. Could you share more details on where exactly you are hitting this error? That would help me analyze it better.

nlokeshiisc commented 3 years ago

@krishnatejakk, this error occurs specifically when running GradMatch with the logistic regression model.

For all the models that are currently supported (e.g., CIFARNet), the forward() function takes a freeze argument. When freeze is True, the code sets requires_grad=False for all layers except the last one.

However, the logistic regression model (logreg_net.py) has only one layer, so the code seems to freeze the entire model when freeze=True is passed to forward().

I hit the error when running the GradMatch strategy, at line 109 of dataselectionstrategy.py. Here are a few lines of code around the error:

out, l1 = self.model(inputs, last=True, freeze=True)
losses = self.loss(out, targets)
loss = losses.sum()
l0_grads = torch.autograd.grad(loss, out)[0]

In the first line, we run forward() with freeze=True. This freezes the entire model, so the output is detached from the autograd graph and the gradient of the loss cannot be computed.
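The failure described above can be reproduced in isolation with a minimal sketch. It assumes that freeze=True evaluates the model's only layer under torch.no_grad(); the class `FrozenLogReg` and the helper `last_layer_grads` are hypothetical names made up for this illustration and are not the actual CORDS code.

```python
import torch
import torch.nn as nn


class FrozenLogReg(nn.Module):
    """Hypothetical single-layer model; not the real logreg_net.py."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x, last=False, freeze=False):
        if freeze:
            # Assumption: freezing disables graph tracking for the
            # only layer, so `out` ends up without a grad_fn.
            with torch.no_grad():
                out = self.linear(x)
        else:
            out = self.linear(x)
        return (out, x) if last else out


def last_layer_grads(model, inputs, targets, freeze):
    """Mirror the call pattern quoted in the issue."""
    loss_fn = nn.CrossEntropyLoss(reduction="none")
    out, _ = model(inputs, last=True, freeze=freeze)
    loss = loss_fn(out, targets).sum()
    return torch.autograd.grad(loss, out)[0]


torch.manual_seed(0)
model = FrozenLogReg(input_dim=4, num_classes=3)
inputs = torch.randn(5, 4)
targets = torch.randint(0, 3, (5,))

# Without freezing, the gradient w.r.t. the logits is available.
grads = last_layer_grads(model, inputs, targets, freeze=False)

# With freezing, neither the loss nor the output is part of a graph,
# so torch.autograd.grad raises a RuntimeError.
try:
    last_layer_grads(model, inputs, targets, freeze=True)
    frozen_error = None
except RuntimeError as err:
    frozen_error = err
```

In a multi-layer model the same pattern is harmless, because the last layer still runs with graph tracking and its output keeps a grad_fn.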

Possible Fix

Maybe we can always force freeze=False for the logistic regression model in logreg_net.py.
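One way to read this suggestion is sketched below: keep the freeze argument in the signature so callers stay unchanged, but ignore it because there is no earlier layer to freeze. The class `LogRegNoFreeze` is a hypothetical name for this illustration, and returning the input as the second value when last=True is also an assumption; this is not necessarily the change that was made in the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LogRegNoFreeze(nn.Module):
    """Hypothetical logistic regression model that ignores `freeze`."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x, last=False, freeze=False):
        # `freeze` is accepted for interface compatibility only;
        # the single layer always runs with graph tracking.
        out = self.linear(x)
        # Assumption: the raw input stands in for the features
        # that multi-layer models return when last=True.
        return (out, x) if last else out


torch.manual_seed(0)
model = LogRegNoFreeze(input_dim=4, num_classes=3)
inputs = torch.randn(5, 4)
targets = torch.randint(0, 3, (5,))
loss_fn = nn.CrossEntropyLoss(reduction="none")

# Same call pattern as the snippet from dataselectionstrategy.py.
out, l1 = model(inputs, last=True, freeze=True)
loss = loss_fn(out, targets).sum()
l0_grads = torch.autograd.grad(loss, out)[0]

# For summed cross-entropy, the gradient w.r.t. the logits is
# softmax(logits) - one_hot(targets); used here as a sanity check.
expected = F.softmax(out.detach(), dim=1) - F.one_hot(targets, num_classes=3).float()
```

With this variant the gradient call succeeds even when the strategy passes freeze=True, and `l0_grads` has one row per sample and one column per class.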

ganramkr commented 3 years ago

@krishnatejakk - I have already discussed this with @durgas16 .

@durgas16 please do close this since what @nlokeshiisc is asking for is easily doable.

krishnatejakk commented 3 years ago

@ganramkr @nlokeshiisc - I have fixed the logistic regression error caused by the freeze option. It should not raise any error now. Thanks for pointing this out :)