pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

per_sample gradient is None but grad is populated #578

Open anirban-nath opened 1 year ago

anirban-nath commented 1 year ago

I have a particular LayerNorm in my code that prevents me from running Opacus successfully. This LayerNorm is defined just like 3 or 4 others in my code and is used in 2 places. When I execute loss.backward(), the grad of the layer is populated but the per-sample grad isn't, which leads Opacus to throw the error "Per sample gradient is not initialized. Not updated in backward pass?"

Under what circumstances is this possible?

PS: This is how the norm is defined

    decoder_norm = nn.LayerNorm(d_model)
    self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec)

This is how it is used. The usages are shown with comments beside them

```python
from typing import Optional

from torch import Tensor, nn


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm  # HERE
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            # print(output.shape)
            if self.return_intermediate:
                intermediate.append(self.norm(output))  # HERE

        if self.norm is not None:
            output = self.norm(output)  # HERE
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)
```

alexandresablayrolles commented 1 year ago

Thanks for raising this issue. The reason is that Opacus computes grad_samples using "hooks", so it only works for standard layers. You can pass grad_sample_mode="functorch" to make_private(), which will make Opacus use functorch to automatically compute grad_samples for new layers (it is not guaranteed to work but most of the time it does the job).
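For reference, the suggestion above could look roughly like the sketch below. The helper name `make_private_functorch` and the `DP_KWARGS` defaults are hypothetical and only for illustration; the noise and clipping values are placeholders, not recommendations. Only `PrivacyEngine.make_private` and its `grad_sample_mode` argument come from Opacus itself.

```python
# Illustrative defaults (assumptions): tune noise_multiplier and max_grad_norm
# for your own task. grad_sample_mode="functorch" is the setting under discussion.
DP_KWARGS = {
    "noise_multiplier": 1.0,
    "max_grad_norm": 1.0,
    "grad_sample_mode": "functorch",
}


def make_private_functorch(model, optimizer, data_loader, **overrides):
    """Hypothetical helper: wrap training objects with Opacus in functorch mode."""
    # Imported lazily so the helper can be defined without Opacus installed.
    from opacus import PrivacyEngine

    engine = PrivacyEngine()
    kwargs = {**DP_KWARGS, **overrides}
    model, optimizer, data_loader = engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=data_loader,
        **kwargs,
    )
    return engine, model, optimizer, data_loader
```

A call would then be `engine, model, optimizer, train_loader = make_private_functorch(model, optimizer, train_loader)`, where `train_loader` stands for your own DataLoader. `make_private_with_epsilon` should accept the same `grad_sample_mode` argument, with `target_epsilon`, `target_delta` and `epochs` in place of `noise_multiplier`.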

anirban-nath commented 1 year ago

> Thanks for raising this issue. The reason is that Opacus computes grad_samples using "hooks", so it only works for standard layers. You can pass grad_sample_mode="functorch" to make_private(), which will make Opacus use functorch to automatically compute grad_samples for new layers (it is not guaranteed to work but most of the time it does the job).

Hi. I was using the make_private_with_epsilon function and I tried "functorch" but it did not work.

alexandresablayrolles commented 1 year ago

It should also work with make_private_with_epsilon. Do you still have the same error message?

anirban-nath commented 1 year ago

> It should also work with make_private_with_epsilon. Do you still have the same error message?

Exact same error message. No difference. I tried with both make_private and make_private_with_epsilon. I even tried replacing that LayerNorm with a GroupNorm but none of these have made any difference.
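
When the error persists like this, it can help to find out exactly which parameters end up with `grad` but without `grad_sample` after `loss.backward()`. The function below is a minimal diagnostic sketch, not part of Opacus: the name `params_missing_grad_sample` is hypothetical, and it only assumes that Opacus stores per-sample gradients in a `grad_sample` attribute on each parameter.

```python
def params_missing_grad_sample(model):
    """Return names of trainable parameters that have .grad but no .grad_sample."""
    missing = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not expected to have grad_sample
        has_grad = param.grad is not None
        # grad_sample may be absent entirely or present but set to None
        has_grad_sample = getattr(param, "grad_sample", None) is not None
        if has_grad and not has_grad_sample:
            missing.append(name)
    return missing
```

Calling `print(params_missing_grad_sample(model))` right after `loss.backward()` and before `optimizer.step()`, on the model returned by `make_private`, lists the affected parameters so they can be matched against the modules marked `# HERE` above.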