awslabs / fast-differential-privacy

Fast, memory-efficient, scalable optimization of deep learning with differential privacy
Apache License 2.0

Convert activations and backprops to float32 to allow training in mixed precision #19

Closed: lccnl closed this pull request 2 months ago

lccnl commented 7 months ago

Issue #, if available:

To train in mixed precision (for example, with torch.autocast), we need to convert the backprops and the activations to float32, because some layers run in float16 and others in float32.
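A minimal sketch of the idea follows. It assumes a linear layer whose per-sample weight gradient is an einsum over the saved activations and backprops. The helper `linear_grad_sample` and the tensor shapes are illustrative assumptions, not fastDP's actual code. The point is that both inputs are upcast to float32 before the contraction, whatever dtypes autocast produced.

```python
import torch


def linear_grad_sample(activations: torch.Tensor, backprops: torch.Tensor) -> torch.Tensor:
    """Per-sample weight gradient of an nn.Linear layer (hypothetical helper).

    activations: (batch, seq_len, in_features), the layer input saved in the forward pass
    backprops:   (batch, seq_len, out_features), the gradient w.r.t. the layer output
    returns:     (batch, out_features, in_features)
    """
    # Under torch.autocast the two tensors may arrive in different dtypes
    # (e.g. float16 activations, float32 backprops); upcast both to float32
    # so the contraction runs in a single, full-precision dtype.
    A = activations.to(torch.float32)
    B = backprops.to(torch.float32)
    # Sum over the sequence dimension t, keeping one gradient per sample b.
    return torch.einsum("bti,bto->boi", A, B)


# Usage: mismatched dtypes, as can happen under mixed precision.
acts = torch.randn(4, 3, 5).to(torch.float16)
grads = torch.randn(4, 3, 2)
g = linear_grad_sample(acts, grads)
print(g.dtype, tuple(g.shape))  # torch.float32 (4, 2, 5)
```

Upcasting rather than downcasting presumably keeps the per-sample gradients, and the norms computed from them for clipping, in full precision, at the cost of extra memory.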

Description of changes:

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

anukaal commented 6 months ago

Looks clear

kiddyboots216 commented 6 months ago

There is an issue with the promotion. Consider using fastDP with autocast for a model with an embedding layer. Running a debugger from line 132 then gives the following output:

(Pdb) !common_type
torch.float16
(Pdb) !layer.activations.dtype
torch.int64
(Pdb) !B.dtype
torch.float16

Promoting activations.dtype to float16 then throws an error in _clip_embedding_grad:

> /scratch/gpfs/ashwinee/fast-differential-privacy/fastDP/supported_layers_grad_samplers.py(311)_clip_embedding_grad()
-> A = F.one_hot(A, num_classes=layer.weight.shape[0]).to(B)  # (batch_size, seq_len, vocab_dim,)
(Pdb) n
RuntimeError: one_hot is only applicable to index tensor.

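The debugger output is consistent with ordinary PyTorch type promotion. The sketch below assumes `common_type` is computed with something like `torch.promote_types`; that is an assumption, since the PR's exact code is not shown here. Under it, promoting int64 token indices together with float16 backprops yields float16, and `F.one_hot` then rejects the floating-point "indices". The tensor shapes and values are made up for illustration.

```python
import torch
import torch.nn.functional as F

# For an embedding layer, the saved "activations" are token indices (int64).
activations = torch.tensor([[1, 4, 2]])                # (batch=1, seq_len=3), int64
backprops = torch.zeros(1, 3, 8, dtype=torch.float16)  # (batch, seq_len, embed_dim)

# Assumed promotion rule: integer + float16 -> float16.
common_type = torch.promote_types(activations.dtype, backprops.dtype)
print(common_type)  # torch.float16

# On the original int64 indices, the one-hot step works as intended.
ok = F.one_hot(activations, num_classes=10).to(backprops)

# After promotion the indices are floats, and one_hot refuses them.
promoted = activations.to(common_type)
failed = False
try:
    F.one_hot(promoted, num_classes=10)
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```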
kiddyboots216 commented 6 months ago

One (extremely bad) way to fix this is to insert a check for nn.Embedding and skip the promotion in that case. Something better would probably be to make the promotion depend specifically on float32/float16, though I imagine this gets increasingly complex with int8 and similar types.
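The second suggestion, making the promotion depend on floating-point dtypes, could be sketched as follows. This assumes the check is done with `Tensor.is_floating_point()` and that float32 is the target, as in the PR description. The helper `promote_floating` is hypothetical, not part of fastDP. It does not resolve the int8 concern either: integer tensors, int8 included, are simply passed through unchanged.

```python
import torch
import torch.nn.functional as F


def promote_floating(*tensors: torch.Tensor, target: torch.dtype = torch.float32):
    """Hypothetical helper: cast only floating-point tensors to `target`.

    Integer tensors (e.g. the int64 token indices saved for nn.Embedding)
    are returned unchanged, so later calls such as F.one_hot still receive
    an index tensor.
    """
    return tuple(t.to(target) if t.is_floating_point() else t for t in tensors)


# Usage: embedding-style inputs keep their integer indices.
A = torch.tensor([[1, 4, 2]])                  # int64 token ids
B = torch.zeros(1, 3, 8, dtype=torch.float16)  # float16 backprops
A, B = promote_floating(A, B)
print(A.dtype, B.dtype)  # torch.int64 torch.float32
one_hot = F.one_hot(A, num_classes=10).to(B)   # still valid: A is an index tensor
```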

woodyx218 commented 2 months ago

I am archiving this PR because I can confirm the incompatibility with embedding layers (see https://github.com/awslabs/fast-differential-privacy/pull/19#issuecomment-1977603420). I also note that supporting distributed learning is more complicated and would require rewriting parts of supported_differentially_private_layers.py.

woodyx218 commented 2 months ago

> One (extremely bad) way to fix this is just insert a check for nn.Embedding and if so skip the promotion. But probably something better would be to make the promotion specifically depend on float32/float16? I imagine this gets more and more complex with int8 and stuff.

I successfully tested this workaround, but it is not a formal solution, so torch.amp is not directly supported by our code.
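The layer-type workaround quoted above might look roughly like the sketch below. `maybe_promote` is a hypothetical helper, not part of fastDP, and the float32 target is an assumption taken from the PR description. Because the check is keyed on the layer type rather than on dtypes, any other layer that saves integer inputs would still hit the same one_hot error.

```python
import torch
import torch.nn as nn


def maybe_promote(layer: nn.Module, activations: torch.Tensor, backprops: torch.Tensor,
                  target: torch.dtype = torch.float32):
    """Hypothetical workaround: skip the promotion entirely for nn.Embedding."""
    if isinstance(layer, nn.Embedding):
        # Leave the int64 indices (and the backprops) exactly as they are.
        return activations, backprops
    return activations.to(target), backprops.to(target)


# Usage: an embedding layer keeps its indices; a linear layer is upcast.
emb, lin = nn.Embedding(10, 8), nn.Linear(5, 2)
ids = torch.tensor([[1, 4, 2]])
emb_A, emb_B = maybe_promote(emb, ids, torch.zeros(1, 3, 8, dtype=torch.float16))
lin_A, lin_B = maybe_promote(lin, torch.zeros(1, 3, 5, dtype=torch.float16), torch.zeros(1, 3, 2))
print(emb_A.dtype, lin_A.dtype)  # torch.int64 torch.float32
```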