pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0
1.65k stars 328 forks

DP with NAS and meet a ‘Model contains a trainable layer that Opacus doesn't currently support’ #576

Open Drxdx opened 1 year ago

Drxdx commented 1 year ago

📚 Documentation

I want to use DP with NAS. When I use the pre-trained DARTS model, the ModuleValidator.fix() function doesn't work. Has anyone else met this problem?

lucacorbucci commented 1 year ago

Hi @Drxdx could you share the code that you're trying to run?

Drxdx commented 1 year ago

model = models.resnet18(num_classes=10)
darts_v2_model = DartsSpace.load_searched_model('darts-v2', pretrained=True, download=True)

print("============", type(darts_v2_model))
print(type(model))

errors = ModuleValidator.validate(model, strict=False)
print(errors)

For a simple example, take ResNet18 for DP training. Because ResNet18 contains BatchNorm layers, it cannot be used with differential privacy as-is. If we call ModuleValidator.fix(model, strict=False), it replaces the BatchNorm layers with GroupNorm layers, so we can use the model. But when I use the DARTS model, the fix() function does nothing.
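To make the BN-to-GN replacement concrete, here is a minimal sketch of what that fix amounts to. The helper name `replace_bn_with_gn` is hypothetical and this is not the actual Opacus implementation, just the same idea: walk the module tree and swap each BatchNorm2d for a GroupNorm over the same number of channels.

```python
import torch.nn as nn

def replace_bn_with_gn(module: nn.Module, num_groups: int = 32) -> nn.Module:
    # Sketch of what ModuleValidator.fix() does in spirit (hypothetical
    # helper, not Opacus code): recursively replace BatchNorm2d, whose
    # per-batch statistics break per-sample gradient computation, with
    # GroupNorm, which normalizes each sample independently.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # GroupNorm requires num_groups to divide num_features;
            # this simple min() assumes num_features is a multiple of 32.
            groups = min(num_groups, child.num_features)
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            replace_bn_with_gn(child, num_groups)
    return module
```

Note that this sketch only handles BatchNorm2d; the real ModuleValidator.fix() covers more incompatible layer types.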

In this case, we have a function GradSampleModule.is_supported(m) that returns True for Conv and False for BN with ResNet18. But with DARTS, both Conv and BN return False.

This problem has been bothering me for a long time. I hope you can help me.

ffuuugor commented 1 year ago

Ok, this is actually interesting. I've investigated this a bit, and it seems like this problem could appear for any deserialized model.

The problem is that when the model is loaded by DartsSpace.load_searched_model('darts-v2', ...), the class object of its batch norms is different from the class object you get normally. This is confusing, so here's an example:

> darts_v2_model = DartsSpace.load_searched_model('darts-v2', pretrained=True, download=True)
> type(darts_v2_model.stages[0][0].preprocessor.pre0[2])
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>

> bn = torch.nn.modules.batchnorm.BatchNorm2d(2)
> type(bn)
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>

> type(darts_v2_model.stages[0][0].preprocessor.pre0[2]) == type(bn)
False

> id(type(darts_v2_model.stages[0][0].preprocessor.pre0[2]))
201567168

> id(type(bn))
98172448

This then leads to ModuleValidator ignoring BatchNorms, because it checks the class object, not the string representation.

I'm not exactly sure what about the serialization process makes it create new class objects, and I'm not sure how commonplace this is.
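The failure mode is reproducible in plain Python without any serialization machinery: two class objects with identical names still compare unequal, so any dict keyed by class references misses the "other" copy. A minimal sketch (the factory function is just a stand-in for whatever path produces the duplicate class):

```python
def make_bn_class():
    # Stand-in for the deserialization path: it ends up with a freshly
    # created class object rather than the canonical one.
    class BatchNorm2d:
        pass
    return BatchNorm2d

A = make_bn_class()
B = make_bn_class()

# Same name, but different class objects: equality is identity by default.
assert A.__name__ == B.__name__
assert A is not B and A != B

# A registry keyed by class reference therefore misses the duplicate,
# which is exactly why the validator skips the DARTS batch norms.
SUPPORTED = {A: False}
assert B not in SUPPORTED
```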

However, I don't see a good reason why we should keep class references as keys rather than strings. Any ideas why switching to strings could backfire?

cc @alexandresablayrolles @karthikprasad

Drxdx commented 1 year ago

The same problem occurs with the first conv2d layer. ResNet18 works, but the DARTS model does not. So I don't know how to fix this problem; maybe DARTS is not compatible with Opacus.