Open Drxdx opened 1 year ago
Hi @Drxdx could you share the code that you're trying to run?
from torchvision import models
from opacus.validators import ModuleValidator

model = models.resnet18(num_classes=10)
print(type(model))
errors = ModuleValidator.validate(model, strict=False)
print(errors)
For a simple example, I use resnet18 for DP training. Because resnet18 contains BN layers, it cannot be used with differential privacy as is. If we use the ModuleValidator.fix(model, strict=False) function, it changes the BN layers to GN layers, so we can use it. But when I use a DARTS model, the fix() function has no effect.
In this case, we have a function GradSampleModule.is_supported(m) that returns True for Conv and False for BN with ResNet18. But with DARTS, both Conv and BN return False.
This problem has been bothering me for a long time. I hope you can help me.
Ok, this is actually interesting. I've investigated this a bit, and it seems like this problem could appear for any deserialized model.
The problem is that when the model is loaded by DartsSpace.load_searched_model('darts-v2', ...), the class object of its batch norms is different from the class object you get normally. This is confusing, so here's an example:
> darts_v2_model = DartsSpace.load_searched_model('darts-v2', pretrained=True, download=True)
> type(darts_v2_model.stages[0][0].preprocessor.pre0[2])
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
> bn = torch.nn.modules.batchnorm.BatchNorm2d(2)
> type(bn)
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
> type(darts_v2_model.stages[0][0].preprocessor.pre0[2]) == type(bn)
False
> id(type(darts_v2_model.stages[0][0].preprocessor.pre0[2]))
201567168
> id(type(bn))
98172448
This then leads to ModuleValidator ignoring BatchNorms, because it checks the class object, not its string representation.
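To make the failure mode concrete without needing the DARTS checkpoint, here is a minimal pure-Python sketch. The two classes are stand-ins built with `type()`, and `fixers_by_class` is a hypothetical lookup table, not Opacus's actual internals. It only illustrates how a dictionary keyed by class object misses a second class object that prints identically.

```python
def make_class(name, module):
    # Build a fresh class object with the given name and module path.
    return type(name, (), {"__module__": module})

# Stand-ins: same module path and name, but two separate class objects,
# like the regular BatchNorm2d and the one attached to the loaded model.
NormalBN = make_class("BatchNorm2d", "torch.nn.modules.batchnorm")
LoadedBN = make_class("BatchNorm2d", "torch.nn.modules.batchnorm")

# Hypothetical registry keyed by class object (an assumption for illustration).
fixers_by_class = {NormalBN: "replace with GroupNorm"}

loaded_layer = LoadedBN()

same_repr = repr(type(loaded_layer)) == repr(NormalBN)   # printed form matches
same_object = type(loaded_layer) is NormalBN             # identity does not
found = type(loaded_layer) in fixers_by_class            # so the lookup misses

print(same_repr, same_object, found)
```

Classes hash by identity by default, so the lookup behaves the same way as the `type(...) == type(bn)` comparison in the session above.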
I'm not exactly sure what in the serialization process makes it create new class objects, and I'm not sure how commonplace this is.
However, I don't see a good reason why we should keep references as keys, not strings. Any ideas why switching to strings could backfire?
cc @alexandresablayrolles @karthikprasad
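A rough sketch of what string keys could look like, and one way they might backfire. `NameKeyedRegistry` and `qualified_name` are hypothetical names invented for this example and are not part of Opacus. The idea is to key handlers by `module.QualName` instead of by class object.

```python
def qualified_name(cls):
    # "module.QualName", the same text that appears inside repr(cls).
    return f"{cls.__module__}.{cls.__qualname__}"

class NameKeyedRegistry:
    """Hypothetical registry keyed by qualified class name (illustration only)."""

    def __init__(self):
        self._handlers = {}

    def register(self, cls, handler):
        self._handlers[qualified_name(cls)] = handler

    def lookup(self, obj):
        return self._handlers.get(qualified_name(type(obj)))

def make_class(name, module, attrs=None):
    # Build a fresh class object with the given name and module path.
    namespace = {"__module__": module}
    namespace.update(attrs or {})
    return type(name, (), namespace)

NormalBN = make_class("BatchNorm2d", "torch.nn.modules.batchnorm")
LoadedBN = make_class("BatchNorm2d", "torch.nn.modules.batchnorm")

registry = NameKeyedRegistry()
registry.register(NormalBN, "replace with GroupNorm")

# The duplicate class object from a loaded model is now found.
hit_for_loaded = registry.lookup(LoadedBN())

# Possible downside: any class that claims the same qualified name is matched
# too, even if it behaves differently from the registered one.
Impostor = make_class("BatchNorm2d", "torch.nn.modules.batchnorm",
                      {"behaves_differently": True})
hit_for_impostor = registry.lookup(Impostor())

print(hit_for_loaded, hit_for_impostor)
```

So string keys would fix the duplicate-class case, at the cost of trusting the name: two different implementations that share a qualified name become indistinguishable to the registry.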
The same problem occurs with the first Conv2d layer. ResNet18 works, but the DARTS model does not. So I don't know how to fix this problem. Maybe DARTS is not compatible with Opacus.
📚 Documentation
I want to use DP with NAS. When I use the pre-trained DARTS model, the ModuleValidator.fix() function doesn't work! Has anyone else met this problem?