Open sandertom opened 2 years ago
By the way, I also noticed this issue. The culprit seems to be that the dataloader isn't split across devices the way `torch.utils.data.distributed.DistributedSampler` does it. If you look at this function here, there's essentially no special processing in the non-Poisson mode.
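To make the missing behavior concrete, here is a rough pure-Python sketch of the kind of partitioning `DistributedSampler` performs (an epoch-seeded shuffle followed by an interleaved slice per rank). This is an illustration, not the sampler's actual implementation; `shard_indices` is a made-up helper.

```python
import random

def shard_indices(n, world_size, rank, epoch, drop_last=True):
    """Roughly mimic DistributedSampler: shuffle with an epoch-seeded
    generator, then give each rank an interleaved, disjoint slice."""
    g = random.Random(epoch)  # epoch-dependent seed, as set_epoch() provides
    indices = list(range(n))
    g.shuffle(indices)
    if drop_last:
        # drop the tail so every rank gets the same number of examples
        indices = indices[: (n // world_size) * world_size]
    return indices[rank::world_size]

shards = [shard_indices(10, world_size=2, rank=r, epoch=0) for r in range(2)]
# the two ranks' shards are disjoint and together cover all 10 examples
```

Without this sharding step, every rank iterates the full dataset, which is exactly the behavior observed in the bug report below.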
I think it's fine to state that under the addition/removal DP definition, subsampling that isn't Poisson-based would break the stated privacy guarantee, but in practice it's still helpful to be able to run with fixed-size batches to quickly gauge performance. In addition, the performance difference between Poisson and non-Poisson sampling isn't huge when the batch size is reasonably large (which is typical these days; the community is slowly converging on the view that large batches are helpful for DP training).
There's also a replacement-based definition of DP under which fixed-size subsampled batches still give a reasonable privacy guarantee, but that's an entirely different story (and relies on different accounting procedures).
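For readers unfamiliar with the distinction: under Poisson subsampling, each example is included in a batch independently with probability `q`, so the batch size itself is random, fluctuating around `q * n` rather than being fixed. A minimal sketch (`poisson_batch` is a made-up helper, not an Opacus API):

```python
import random

def poisson_batch(n, q, rng):
    """Poisson subsampling: include each of the n examples
    independently with probability q; batch size is random."""
    return [i for i in range(n) if rng.random() < q]

rng = random.Random(0)
sizes = [len(poisson_batch(1000, 0.1, rng)) for _ in range(200)]
# sizes fluctuate around q * n = 100 rather than being fixed at 100
```

It is this independence across examples that the standard addition/removal privacy amplification analysis relies on, and that a fixed-size, shuffled batch does not provide.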
Regardless, it seems the library should at least throw an exception for this use case, or document it somewhere, so that people don't silently get erroneous results (and end up reporting those in conference submissions!).
Since I've been running exploratory experiments, I'm posting my tentative fix for the issue. For context, I'm trying to run the non-Poisson version in a distributed setting, keeping the algorithmic steps exactly the same as in the single-device non-Poisson setting. These are the usual steps to get DDP working:
```python
if world_size > 1 and not args.poisson_sampling:
    from torch.utils.data import DistributedSampler

    # TODO: DistributedSampler does not accept a generator argument,
    # which makes setting a secure PRNG impossible.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # split the global batch evenly across ranks
        batch_size=int(args.sample_rate * len(train_dataset)) // world_size,
        sampler=DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=device,  # `device` holds this process's integer rank
            shuffle=True,
            drop_last=True,
        ),
    )

# call at the start of every epoch so each epoch gets a fresh shuffle
train_loader.sampler.set_epoch(epoch)
```
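As a sanity check on the batch-size arithmetic above (illustrative numbers, not from the original report): the per-rank batch size should reconstruct the same expected per-step example count that Poisson sampling would give, namely `sample_rate * n`. One caveat worth noting is that `int()` truncates, so floating-point error in `sample_rate * len(train_dataset)` can shave off an example; `round()` is the safer choice.

```python
n = 50_000                 # dataset size (illustrative)
world_size = 2
sample_rate = 512 / n      # target expected batch of 512

# round() before truncating integer-division so float error in
# sample_rate * n cannot silently drop an example
per_rank = int(round(sample_rate * n)) // world_size
assert per_rank == 256
assert per_rank * world_size == 512  # matches the Poisson expectation q * n
```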
## 🐛 Bug
## To Reproduce
If we set the batch size to 512 in the dataloader (before making it private) and use two GPUs, we expect each GPU to process approximately 256 examples between optimizer steps. This is what happens with Poisson sampling. Without Poisson sampling, each GPU instead processes 512 examples, for a total of 1024 examples between optimizer steps.
If using a `BatchMemoryManager`, `is_updated = not optimizer._check_skip_next_step(pop_next=False)` can tell you whether you are at the end of a "true" optimizer step. To reproduce the error, count how many examples each worker processed between two of these steps.
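The mismatch can be illustrated without GPUs by modeling the two loader configurations over plain index lists. This is a schematic sketch, not Opacus code; the numbers mirror the scenario above.

```python
def batches(indices, batch_size):
    """Chunk an index list into consecutive batches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]

n, batch_size, world_size = 2048, 512, 2
full = list(range(n))

# Without sharding, every rank iterates the full dataset:
per_rank_unsharded = len(batches(full, batch_size)[0])  # 512 per rank
# => 512 * 2 = 1024 examples across ranks per optimizer step

# With DistributedSampler-style sharding plus batch_size // world_size:
shard = full[0::world_size]  # one rank's half of the data
per_rank_sharded = len(batches(shard, batch_size // world_size)[0])  # 256 per rank
# => 256 * 2 = 512 examples across ranks per step, as intended
```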
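The counting described above can be sketched with a mock in place of the real DP optimizer. `MockDPOptimizer` below is a made-up stand-in for Opacus's `DPOptimizer` (whose real `_check_skip_next_step` consumes an internal queue); only the counting pattern around `is_updated` is the point here.

```python
class MockDPOptimizer:
    """Schematic stand-in for Opacus's DPOptimizer: with gradient
    accumulation, only every k-th micro-batch triggers a real step."""
    def __init__(self, accumulation_steps):
        self.k = accumulation_steps
        self.calls = 0

    def _check_skip_next_step(self, pop_next=True):
        skip = (self.calls + 1) % self.k != 0
        if pop_next:
            self.calls += 1  # consume this micro-batch's decision
        return skip

optimizer = MockDPOptimizer(accumulation_steps=4)
counts, since_step = [], 0
for _ in range(8):                        # 8 micro-batches of 64 examples
    since_step += 64
    # peek without consuming, as described in the report
    is_updated = not optimizer._check_skip_next_step(pop_next=False)
    optimizer._check_skip_next_step()     # consume, as the training loop would
    if is_updated:
        counts.append(since_step)         # examples since the last "true" step
        since_step = 0
# counts == [256, 256]: examples seen between consecutive "true" optimizer steps
```

Running the same counter on each rank in the buggy configuration would show each worker accumulating the full global batch rather than its share.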