pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Length of BatchSplittingSampler with Poisson sampling #516

Open s-zanella opened 1 year ago

s-zanella commented 1 year ago

🐛 Bug

The __len__() method of a BatchSplittingSampler that wraps a DPDataLoader is meant to return the number of physical (as opposed to logical) batches in its iterator. Because Poisson sampling produces variable-length logical batches, the actual number of physical batches varies between runs, so any length returned is necessarily approximate. However, the approximation implemented in __len__() is inaccurate:

```python
expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
```

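Since a logical batch of size $B$ is split into $\lceil B/m \rceil$ physical batches, the actual expected number of physical batches per logical batch is:

$$ \mathbb{E}\left[\lceil B/m \rceil\right] = \sum_{k=1}^\infty k \left( F(k m) - F((k - 1) m) \right) $$

where $m$ is the maximum physical batch size self.max_batch_size and $F$ is the CDF of $B$, i.e. of the binomial distribution with self.sampler.num_samples trials and success probability self.sampler.sample_rate. The current code instead computes $\mathbb{E}[B]/m$, which ignores the rounding up (and then truncates with int()), so it systematically underestimates the length.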
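To make the gap concrete, the following stdlib-only sketch evaluates both quantities for one example. The function names (binom_pmf, expected_physical_per_logical, current_len, corrected_len) are hypothetical, the sketch assumes 0 < sample_rate < 1, and it takes the number of logical batches per epoch (what len(self.sampler) returns) as a plain parameter. It sums $\lceil b/m \rceil \, P(B = b)$ over every possible batch size $b$, which is the same expectation as the sum over $k$ above, just grouped by $b$ instead of by $k$.

```python
import math


def binom_pmf(b, n, q):
    # P(B = b) for B ~ Binomial(n, q), computed in log space so large n does not overflow
    log_p = (
        math.lgamma(n + 1) - math.lgamma(b + 1) - math.lgamma(n - b + 1)
        + b * math.log(q) + (n - b) * math.log1p(-q)
    )
    return math.exp(log_p)


def expected_physical_per_logical(n, q, m):
    # E[ceil(B / m)]: a logical batch of size b becomes ceil(b / m) physical batches
    assert 0.0 < q < 1.0, "sketch assumes 0 < sample_rate < 1"
    return sum(math.ceil(b / m) * binom_pmf(b, n, q) for b in range(n + 1))


def current_len(n, q, m, num_logical):
    # Same arithmetic as the current __len__: E[B] / m per logical batch, truncated by int()
    return int(num_logical * (q * n / m))


def corrected_len(n, q, m, num_logical):
    # Expected total number of physical batches in one epoch
    return num_logical * expected_physical_per_logical(n, q, m)


# Illustrative values, not taken from the notebook
n, q, m = 60000, 256 / 60000, 32
num_logical = 234
print(current_len(n, q, m, num_logical), round(corrected_len(n, q, m, num_logical)))
```

Because $\lceil x \rceil \ge x$, the corrected value can never fall below the current one; the gap grows as the spread of $B$ relative to $m$ makes the rounding-up matter.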

This can be computed, e.g., as follows (the sum is finite because a logical batch never contains more than self.sampler.num_samples samples):

```python
import math

import numpy as np
from scipy.stats import binom


def __len__(self):
    n = self.sampler.num_samples
    q = self.sampler.sample_rate
    m = self.max_batch_size

    # Logical batch size B ~ Binomial(n, q) under Poisson sampling
    dist = binom(n, q)

    # P(ceil(B / m) = k) = F(k m) - F((k - 1) m), for k = 1, ..., ceil(n / m)
    k = np.arange(1, math.ceil(n / m) + 1)
    p = dist.cdf(k * m) - dist.cdf((k - 1) * m)

    # Expected physical batches per logical batch, times the number of logical batches
    return int(len(self.sampler) * np.sum(k * p))
```

Please reproduce using our template Colab and post here the link

Here's a notebook built from the Colab template showing the discrepancy between computed and actual lengths: https://gist.github.com/s-zanella/b70308db3d6d1b1bf15a5a2c8a1cc525
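The same kind of discrepancy can be reproduced without Opacus. The following is a hypothetical stdlib-only Monte Carlo sketch: it assumes each logical batch includes every sample independently with probability q (Poisson sampling) and counts a logical batch of size b as ceil(b / m) physical batches. The parameters are small illustrative values chosen so that it runs quickly in pure Python.

```python
import math
import random


def simulate_physical_batches(n, q, m, num_logical, rng):
    # Count the physical batches produced in one simulated epoch
    total = 0
    for _ in range(num_logical):
        # Poisson sampling: each of the n samples is included with probability q
        b = sum(1 for _ in range(n) if rng.random() < q)
        total += math.ceil(b / m)
    return total


n, q, m = 2000, 0.05, 16
num_logical = 20
estimate = int(num_logical * (q * n / m))  # same arithmetic as the current __len__
rng = random.Random(0)
runs = [simulate_physical_batches(n, q, m, num_logical, rng) for _ in range(30)]
print("estimate:", estimate)
print("simulated mean:", sum(runs) / len(runs), "min:", min(runs), "max:", max(runs))
```

The simulated counts vary from epoch to epoch and sit above the estimate on average, which is the pattern described in this issue.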

Expected behavior

It's unclear what the desired behavior is. The length approximation currently implemented is clearly incorrect, but a better approximation doesn't help much because the length of a BatchSplittingSampler with Poisson sampling is not fixed. It would be nice to at least warn that the returned length is approximate.
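For the warning suggested above, one minimal sketch is a __len__ that calls warnings.warn before returning its estimate. The class name and its constructor arguments are hypothetical stand-ins, not Opacus API; only the __len__ logic is shown.

```python
import math
import warnings


class ApproxLengthBatchSplitter:
    """Hypothetical stand-in for BatchSplittingSampler; only __len__ is sketched."""

    def __init__(self, num_logical, expected_physical_per_logical):
        self.num_logical = num_logical
        self.expected_physical_per_logical = expected_physical_per_logical

    def __len__(self):
        # Poisson sampling makes the true count random, so flag the value as an estimate
        warnings.warn(
            "With Poisson sampling the number of physical batches varies between "
            "epochs; __len__() only returns an estimate.",
            RuntimeWarning,
            stacklevel=2,
        )
        return math.ceil(self.num_logical * self.expected_physical_per_logical)


splitter = ApproxLengthBatchSplitter(num_logical=234, expected_physical_per_logical=8.5)
print(len(splitter))
```

With stacklevel=2 the warning is attributed to the line that calls len(), and under Python's default warning filters each such call site reports it only once, so a progress bar that queries the length repeatedly should not flood the log.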

From a user's point of view, if BatchMemoryManager is to be a transparent abstraction, I care less about the number of physical batches processed than about the number of logical batches. The current abstraction does not signal the beginning/end of logical batches, which makes it hard (impossible without code introspection?) to keep track of the number of logical batches processed so far. Having a mechanism to signal the beginning/end of a logical batch would solve this issue.
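A mechanism to signal logical-batch boundaries could be as simple as yielding a flag with each physical batch. The sketch below is hypothetical (split_logical_batches is not an Opacus function): it splits each logical batch into ceil(len / max_batch_size) near-equal chunks, marks the last chunk of each logical batch, and, as an assumption of this sketch, still yields one empty physical batch for an empty logical batch so that its boundary is signalled too.

```python
import math


def split_logical_batches(logical_batches, max_batch_size):
    """Hypothetical helper: yield (physical_batch, is_last_of_logical) pairs."""
    for batch in logical_batches:
        # At least one physical batch per logical batch, even if it is empty
        n_chunks = max(1, math.ceil(len(batch) / max_batch_size))
        # Near-equal chunk sizes: the first `extra` chunks get one more element
        base, extra = divmod(len(batch), n_chunks)
        start = 0
        for i in range(n_chunks):
            size = base + (1 if i < extra else 0)
            yield batch[start:start + size], i == n_chunks - 1
            start += size


logical = [list(range(5)), [], list(range(3))]
num_logical_done = 0
for physical, is_last in split_logical_batches(logical, max_batch_size=2):
    # ... forward/backward on `physical` would go here ...
    if is_last:
        num_logical_done += 1
print(num_logical_done)  # 3
```

A training loop (or a progress bar) can then count logical batches directly instead of inferring them from the physical-batch length.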

pierrestock commented 1 year ago

Hey s-zanella,

Thanks for your interest and for the well-documented issue. Based on my understanding and your notebook, I deduce that this does not influence the accounting (which is correct) but only components that need the BatchSplittingSampler length (in your example, 6249 currently versus the correct value of 7607), such as tqdm progress bars.

Labelling this as an enhancement. We can take a 2-step plan here:

  1. Add a warning to BatchSplittingSampler, as you suggest, to make clear that the returned length is approximate
  2. Add a mechanism to track the number of logical batches processed.

What would be your thoughts?

Pierre

s-zanella commented 1 year ago

Thanks @pierrestock for looking into this.

I agree with you that this issue doesn't affect the privacy accounting or the correctness of the training process. I'm not aware of any use case that would rely critically on the approximation currently returned by the __len__() method. But I would still consider this a bug rather than an enhancement, since the approximation currently returned by the __len__() method is incorrect.

The 2-step plan you proposed sounds good, except that for the first point I believe that, for Poisson sampling, __len__() should either be fixed to return a better approximation and produce a warning message, or be left unimplemented and raise an exception.
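If __len__() is left unimplemented, raising TypeError (the exception len() raises for objects without a length) lets callers that probe len() inside a try/except fall back cleanly, while PEP 424's __length_hint__ can still expose a non-binding estimate through operator.length_hint(). The class below is a hypothetical sketch of that combination, not Opacus API.

```python
import operator


class NoLenBatchSplitter:
    """Hypothetical sampler whose length is undefined under Poisson sampling."""

    def __init__(self, estimated_length):
        self.estimated_length = estimated_length

    def __iter__(self):
        # Batch generation is omitted from this sketch
        return iter(())

    def __len__(self):
        # Matches what len() raises for objects that have no length
        raise TypeError("length is random under Poisson sampling; use operator.length_hint()")

    def __length_hint__(self):
        # Non-binding estimate (PEP 424); callers must not rely on it being exact
        return self.estimated_length


splitter = NoLenBatchSplitter(estimated_length=100)

try:
    total = len(splitter)
except TypeError:
    total = None  # e.g. a progress bar would then run without a known total

print(total, operator.length_hint(splitter))  # None 100
```

operator.length_hint() first tries len() and, when that raises TypeError, falls back to __length_hint__, so the estimate stays available without pretending to be exact.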

dwahdany commented 3 months ago

> I'm not aware of any use case that would rely critically on the approximation currently returned by the __len__() method

I think PyTorch Lightning will actually stop once the number of batches reported by __len__() is reached. (That's why I opened and fixed #640: if you signal a skip and then never execute the last batch, no real optimizer step occurs at all!)

Isn't it possible to look at the PRNG and determine the actual length beforehand?

s-zanella commented 3 months ago

The fix for #640 in https://github.com/pytorch/opacus/pull/641 is incorrect. The new calculation using math.ceil instead of int can still underapproximate the actual length. The same example notebook I provided above shows that this can happen with high probability.

Computing the actual length beforehand is technically possible by pre-sampling all masks.
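Pre-sampling could look like the following hypothetical sketch: all Poisson masks for one epoch are drawn up front from a seeded random.Random, so __len__() returns exactly the number of physical batches that __iter__() will then yield. The class name, its arguments, and the rule that an empty logical batch counts as one empty physical batch are assumptions of the sketch; the length is precise for a given draw but still changes from seed to seed.

```python
import math
import random


class PresampledPoissonBatches:
    """Hypothetical sketch: draw all Poisson masks for one epoch up front."""

    def __init__(self, num_samples, sample_rate, num_logical, max_batch_size, seed):
        rng = random.Random(seed)
        # Each logical batch keeps every index independently with probability sample_rate
        self.logical_batches = [
            [i for i in range(num_samples) if rng.random() < sample_rate]
            for _ in range(num_logical)
        ]
        self.max_batch_size = max_batch_size

    def __len__(self):
        # Exact for this epoch because the masks are already fixed;
        # an empty logical batch is counted as one empty physical batch
        m = self.max_batch_size
        return sum(max(1, math.ceil(len(b) / m)) for b in self.logical_batches)

    def __iter__(self):
        m = self.max_batch_size
        for batch in self.logical_batches:
            if not batch:
                yield []
                continue
            # Consecutive chunks of at most m indices: ceil(len(batch) / m) of them
            for start in range(0, len(batch), m):
                yield batch[start:start + m]


batches = PresampledPoissonBatches(
    num_samples=1000, sample_rate=0.03, num_logical=33, max_batch_size=8, seed=0
)
print(len(batches), sum(1 for _ in batches))
```

The trade-off is memory for one epoch's worth of indices (about sample_rate × num_samples × num_logical in expectation) and the need to build a fresh instance, with fresh randomness, for every epoch.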

dwahdany commented 3 months ago

> The fix for #640 in #641 is incorrect. The new calculation using math.ceil instead of int can still underapproximate the actual length. The same example notebook I provided above shows that this can happen with high probability.

I did full-batch training, and #641 fixes the length if you just use BatchMemoryManager without subsampling, so I don't appreciate it being called flat-out incorrect, since it doesn't claim to fix this issue. I just commented to mention

s-zanella commented 3 months ago

The title of #640 is BatchSplittingSampler return wrong length and you claimed that #641 [f]ixes #640 by ceiling the number of batches, with no qualifiers. Now you are rolling back that claim and saying that it only applies when not using Poisson sampling.

The fix #641 modifies the length computation when using a generic torch.utils.data.BatchSampler and when using Opacus' UniformWithReplacementSampler. I'm taking issue with the claim that #641 fixes the latter case. Judging by your last message, I believe we agree that it does not, and that even with the fix, BatchSplittingSampler.__len__() still returns a bad approximation. If PyTorch Lightning relies on __len__() to determine when a logical batch finishes, it would work erratically when using Poisson sampling.

I already agreed with you that it's possible to compute the length precisely (but not deterministically) by pre-sampling all batches.