pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

The privacy engine's make_private_with_epsilon function leads to a memory error even for a very small network. #602

Closed telegraphroad closed 7 months ago

telegraphroad commented 10 months ago

🐛 Bug

The privacy engine's `make_private_with_epsilon` function leads to a memory error even for a very small network. You can find the full code in this Colab. I have tested it on a GPU with 12 GB of RAM, and it still runs out of memory, emitting a few warnings:


```
UserWarning: Optimal order is the largest alpha. Please consider expanding the range of alphas to get a tighter privacy bound.
/home/user/anaconda3/envs/dpf/lib/python3.11/site-packages/opacus/accountants/analysis/prv/prvs.py:50: RuntimeWarning: invalid value encountered in log
  z = np.log((np.exp(t) + q - 1) / q)
/home/user/anaconda3/envs/dpf/lib/python3.11/site-packages/opacus/accountants/analysis/rdp.py:332: UserWarning: Optimal order is the smallest alpha. Please consider expanding the range of alphas to get a tighter privacy bound.
/home/user/anaconda3/envs/dpf/lib/python3.11/site-packages/opacus/accountants/analysis/prv/prvs.py:50: RuntimeWarning: overflow encountered in exp
  z = np.log((np.exp(t) + q - 1) / q)
```
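
The two `RuntimeWarning`s come from evaluating `np.log((np.exp(t) + q - 1) / q)` directly: `np.exp(t)` overflows for large `t`, and the argument of the log is negative whenever `exp(t) < 1 - q`, which NumPy reports as an invalid value. As an illustration only (this is not Opacus's code, and `stable_log_ratio` is a hypothetical name), the same expression can be evaluated without overflow by splitting on `t`, returning NaN explicitly where it is undefined:

```python
import numpy as np


def stable_log_ratio(t, q):
    """Evaluate log((exp(t) + q - 1) / q) for 0 < q <= 1 without overflow.

    Returns NaN where the expression is undefined (exp(t) <= 1 - q).
    """
    t = np.atleast_1d(np.asarray(t, dtype=float))
    out = np.full(t.shape, np.nan)

    # Large t: factor exp(t) out of the sum so nothing overflows:
    # log(exp(t) + q - 1) - log(q) = t + log1p((q - 1) * exp(-t)) - log(q).
    big = t > 1.0
    out[big] = t[big] + np.log1p((q - 1.0) * np.exp(-t[big])) - np.log(q)

    # Small t: (exp(t) + q - 1) / q = 1 + expm1(t) / q, so use log1p,
    # which is only defined for arguments greater than -1.
    small = ~big
    arg = np.expm1(t[small]) / q
    res = np.full(arg.shape, np.nan)
    ok = arg > -1.0
    res[ok] = np.log1p(arg[ok])
    out[small] = res
    return out


print(stable_log_ratio([-10.0, 0.0, 0.5, 1000.0], q=0.5))
```

The split point `t = 1` is arbitrary: any threshold above 0 keeps the `log1p` argument of the large-`t` branch inside (-1, 1], and any threshold well below ~700 keeps `expm1` from overflowing.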

On an A40 with 48 GB of memory, it gave me the following error:

```
Traceback (most recent call last):
  File "./normalizing-flows/examples/rnvp_fmnist.py", line 791, in <module>
    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/opacus/privacy_engine.py", line 517, in make_private_with_epsilon
    noise_multiplier=get_noise_multiplier(
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/opacus/accountants/utils.py", line 70, in get_noise_multiplier
    eps = accountant.get_epsilon(delta=target_delta, **kwargs)
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/opacus/accountants/prv.py", line 97, in get_epsilon
    dprv = self._get_dprv(eps_error=eps_error, delta_error=delta_error)
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/opacus/accountants/prv.py", line 126, in _get_dprv
    return compose_heterogeneous(
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/opacus/accountants/analysis/prv/compose.py", line 58, in compose_heterogeneous
    dprvs = [
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/opacus/accountants/analysis/prv/compose.py", line 59, in <listcomp>
    _compose_fourier(dprv, num_self_composition)
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/opacus/accountants/analysis/prv/compose.py", line 14, in _compose_fourier
    composed_pmf = irfft(rfft(dprv.pmf) ** num_self_composition)
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/scipy/fft/_backend.py", line 25, in __ua_function__
    return fn(*args, **kwargs)
  File "/miniconda3/envs/nflt/lib/python3.11/site-packages/scipy/fft/_pocketfft/basic.py", line 62, in r2c
    return pfft.r2c(tmp, (axis,), forward, norm, None, workers)
MemoryError: std::bad_alloc
```
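
The `MemoryError` is raised in `_compose_fourier`, which composes a discretized privacy random variable with itself by raising its FFT to the power `num_self_composition`. The accountant only sees the noise multiplier, sample rate and number of steps, never the model, so the memory this needs is driven by the length of the discretized PMF (presumably set by the accountant's error tolerances and the number of composed steps), not by the network size. The sketch below is a simplified, hypothetical illustration of FFT-based self-composition, zero-padded so the circular FFT convolution equals the linear one (Opacus may handle the domain differently), plus a rough estimate of the buffers one FFT round trip keeps alive. `self_compose` and `fft_buffer_bytes` are illustrative names, not Opacus functions.

```python
import numpy as np


def self_compose(pmf, k):
    """Distribution of the sum of k i.i.d. copies of a discrete variable.

    Zero-pads to the full support length so the circular FFT convolution
    equals the linear one.
    """
    pmf = np.asarray(pmf, dtype=float)
    n_out = k * (len(pmf) - 1) + 1           # support size of the k-fold sum
    spectrum = np.fft.rfft(pmf, n=n_out)     # complex, n_out // 2 + 1 entries
    composed = np.fft.irfft(spectrum ** k, n=n_out)
    return np.clip(composed, 0.0, None)      # drop tiny negative FFT round-off


def fft_buffer_bytes(grid_len):
    """Rough lower bound on memory for one rfft/power/irfft round trip (float64)."""
    real_bytes = 8 * grid_len                 # float64 PMF or output
    complex_bytes = 16 * (grid_len // 2 + 1)  # complex128 spectrum
    # input PMF + spectrum + powered spectrum + real output
    return 2 * real_bytes + 2 * complex_bytes


# Fair coin flipped 3 times: Binomial(3, 0.5) = [1, 3, 3, 1] / 8.
print(self_compose([0.5, 0.5], k=3))
for grid_len in (10**6, 10**8, 10**9):
    print(grid_len, f"{fft_buffer_bytes(grid_len) / 2**30:.1f} GiB")
```

Memory grows linearly with the grid length; under these assumptions a grid of 10^9 points would need roughly 30 GiB for these four buffers alone, before any temporary copies made inside the FFT library.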

To Reproduce

  1. Run the experiment in the Colab. The network size can be changed in the model definition; the current model is quite shallow, while the intended model is much deeper.

Expected behavior

Environment

PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.27.2
Libc version: glibc-2.28

Python version: 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.19.0-17-amd64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: 11.0.194
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla P100-PCIE-12GB
Nvidia driver version: 470.57.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
Address sizes:       46 bits physical, 48 bits virtual
CPU(s):              24
On-line CPU(s) list: 0-23
Thread(s) per core:  2
Core(s) per socket:  6
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Gold 6128 CPU @ 3.40GHz
Stepping:            4
CPU MHz:             3611.077
BogoMIPS:            6800.00
Virtualization:      VT-x
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            19712K
NUMA node0 CPU(s):   0,2,4,6,8,10,12,14,16,18,20,22
NUMA node1 CPU(s):   1,3,5,7,9,11,13,15,17,19,21,23
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin mba tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.0.1
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] numpy                     1.25.2                   pypi_0    pypi
[conda] torch                     2.0.1                    pypi_0    pypi
[conda] torchvision               0.15.2                   pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi
## Additional context
HuanyuZhang commented 8 months ago

I think the problem is associated with "prv" accounting, and this part is independent of the model size. Could you try explicitly setting `accountant="rdp"` when initializing `PrivacyEngine`?
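
For context, an RDP accountant tracks the privacy loss at a short, fixed list of Rényi orders α and converts it to (ε, δ) at the end, so its cost depends neither on the model nor on a fine discretization grid. The sketch below is a toy under two simplifying assumptions that differ from Opacus: it uses the RDP of a Gaussian mechanism *without* subsampling (α / (2σ²) per step for sensitivity 1), and the classic conversion ε = min over α of RDP(α) + log(1/δ) / (α − 1); Opacus analyzes the subsampled Gaussian and, as far as I know, uses a tighter conversion. It also shows where the "Optimal order is the largest/smallest alpha" warnings in the original report come from: the minimizing α lands on an edge of the order grid. `eps_from_rdp` and `ALPHAS` are hypothetical.

```python
import math
import warnings

# Hypothetical grid of Renyi orders (all alpha > 1).
ALPHAS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))


def eps_from_rdp(noise_multiplier, steps, delta, alphas=ALPHAS):
    """Toy RDP accountant: non-subsampled Gaussian mechanism, classic conversion.

    Returns (epsilon, optimal alpha) after `steps` compositions.
    """
    best_eps, best_alpha = math.inf, None
    for alpha in alphas:
        # Gaussian mechanism with L2 sensitivity 1: RDP(alpha) = alpha / (2 sigma^2);
        # RDP adds up over compositions.
        rdp = steps * alpha / (2 * noise_multiplier**2)
        eps = rdp + math.log(1 / delta) / (alpha - 1)
        if eps < best_eps:
            best_eps, best_alpha = eps, alpha
    # An optimum on the edge of the grid means a better order may lie outside it.
    if best_alpha == alphas[0]:
        warnings.warn("Optimal order is the smallest alpha; consider expanding the alphas.")
    elif best_alpha == alphas[-1]:
        warnings.warn("Optimal order is the largest alpha; consider expanding the alphas.")
    return best_eps, best_alpha


print(eps_from_rdp(noise_multiplier=1.0, steps=100, delta=1e-5))
```

Very little noise pushes the optimum to the smallest α and very large noise pushes it to the largest, which is when the warnings fire.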

telegraphroad commented 8 months ago

Thanks for the response, @HuanyuZhang. Changing the accountant to rdp fixed the previous issue, but now I get `Per sample gradient is not initialized. Not updated in backward pass?`. I run the validator and `fix` on the model, and validation reports no errors. Isn't `fix` supposed to add the grad sampler to the model, or does it only do so for certain "standard" types of layers and modules? Here is my code:

```python
import torch
from torch.utils.data import DataLoader

from opacus import PrivacyEngine
from opacus.validators import ModuleValidator

# NormalizingFlowMNist, device, learning_rate, batch_size, train_dataset,
# test_dataset, epochs, target_epsilon and MAX_GRAD_NORM are defined in the Colab.
model = NormalizingFlowMNist(num_coupling=2, num_final_coupling=2, planes=2).to(device)

# Replace modules Opacus does not support; with strict=True, validate raises
# if any incompatible module is left.
model = ModuleValidator.fix(model)
errors = ModuleValidator.validate(model, strict=True)
print(errors[-5:])

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

DELTA = 0.9 / len(train_loader.dataset)

privacy_engine = PrivacyEngine(accountant="rdp")

model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=epochs,
    target_epsilon=target_epsilon,
    target_delta=DELTA,
    max_grad_norm=MAX_GRAD_NORM,
)
```
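
`make_private_with_epsilon` has to choose a `noise_multiplier` that meets `target_epsilon`; in the first traceback this happens in `get_noise_multiplier`, which calls `accountant.get_epsilon`, presumably inside a search loop. The sketch below is a hypothetical version of such a search (double an upper bound on σ, then bisect) against a toy accountant, to show that an expensive `get_epsilon` is paid once per search step. `toy_epsilon` is an assumption, not Opacus's accountant: it is the closed-form minimum over continuous α > 1 of the non-subsampled Gaussian RDP bound, c + 2·sqrt(c·log(1/δ)) with c = steps / (2σ²). The tolerance here is on σ, which may differ from what Opacus uses.

```python
import math


def toy_epsilon(sigma, steps, delta):
    """Hypothetical accountant: min over alpha > 1 of
    steps * alpha / (2 sigma^2) + log(1/delta) / (alpha - 1), in closed form."""
    c = steps / (2 * sigma**2)
    log_inv_delta = math.log(1 / delta)
    return c + 2 * math.sqrt(c * log_inv_delta)


def find_noise_multiplier(epsilon_fn, target_eps, tol=0.01, max_sigma=1e6):
    """Double an upper bound on sigma, then bisect; returns (sigma, #epsilon calls)."""
    calls = 0
    lo, hi = 0.0, 1.0
    while True:
        calls += 1
        if epsilon_fn(hi) <= target_eps:
            break
        lo, hi = hi, 2 * hi
        if hi > max_sigma:
            raise ValueError("Privacy budget too low for the given max_sigma.")
    # Invariant: epsilon_fn(hi) <= target_eps, and epsilon_fn(lo) > target_eps unless lo == 0.
    while hi - lo > tol:
        mid = (lo + hi) / 2
        calls += 1
        if epsilon_fn(mid) <= target_eps:
            hi = mid
        else:
            lo = mid
    return hi, calls


sigma, calls = find_noise_multiplier(lambda s: toy_epsilon(s, steps=1000, delta=1e-5), 8.0)
print(f"sigma={sigma:.3f} after {calls} epsilon evaluations")
```

Returning the upper end `hi` keeps the result on the safe side: the chosen σ always satisfies the target, at the price of at most `tol` extra noise.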
HuanyuZhang commented 8 months ago

Could you kindly let me know which line triggers this error? I previously thought this error only occurred when the model is updated by an optimizer.

telegraphroad commented 8 months ago

@HuanyuZhang Previously, the memory error happened when I ran `make_private_with_epsilon`, before the training loop started. After changing the accountant to rdp, `make_private_with_epsilon` works, but now I get the error above during training, on `optimizer.step()`:

```
<ipython-input-4-d2edbffee1f2> in train_loop(dataloader, model, loss_fn, optimizer, batch_size, report_iters, num_pixels)
    580 
    581         #prev = [(name, x, x.grad) for name, x in model.named_parameters(recurse=True)]
--> 582         optimizer.step()
    583 
    584         if batch % report_iters == 0:

/content/opacus/opacus/optimizers/optimizer.py in step(self, closure)
    511                 closure()
    512 
--> 513         if self.pre_step():
    514             return self.original_optimizer.step()
    515         else:

/content/opacus/opacus/optimizers/optimizer.py in pre_step(self, closure)
    492                 returns the loss. Optional for most optimizers.
    493         """
--> 494         self.clip_and_accumulate()
    495         if self._check_skip_next_step():
    496             self._is_last_step_skipped = True

/content/opacus/opacus/optimizers/optimizer.py in clip_and_accumulate(self)
    395         """
    396 
--> 397         if len(self.grad_samples[0]) == 0:
    398             # Empty batch
    399             per_sample_clip_factor = torch.zeros((0,))

/content/opacus/opacus/optimizers/optimizer.py in grad_samples(self)
    343         ret = []
    344         for p in self.params:
--> 345             ret.append(self._get_flat_grad_sample(p))
    346         return ret
    347 

/content/opacus/opacus/optimizers/optimizer.py in _get_flat_grad_sample(self, p)
    280             )
    281         if p.grad_sample is None:
--> 282             raise ValueError(
    283                 "Per sample gradient is not initialized. Not updated in backward pass?"
    284             )

ValueError: Per sample gradient is not initialized. Not updated in backward pass?
```
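
The `ValueError` is raised in `_get_flat_grad_sample` when a parameter's `grad_sample` is still `None` at `optimizer.step()`. A DP optimizer needs one gradient per example so it can clip each one before adding noise. The NumPy sketch below shows that clip, sum, noise, average sequence for a hypothetical linear model with squared loss; `dp_sgd_step` is an illustrative name, not an Opacus function, and it averages over the actual batch size for simplicity. In Opacus the per-sample gradients are filled in during the backward pass, so a parameter that receives none in that pass (for instance, one that takes no part in the forward computation for the batch) would be one way to hit this error; that is a possibility, not a diagnosis of this report.

```python
import numpy as np


def dp_sgd_step(w, X, y, lr, max_grad_norm, noise_multiplier, rng):
    """One DP-SGD step for a linear model with per-sample loss 0.5 * (x @ w - y) ** 2.

    Returns the updated weights and the clipped per-sample gradients.
    """
    # Per-sample gradients (the role played by p.grad_sample): row i = dLoss_i/dw.
    residuals = X @ w - y                       # shape (batch,)
    grad_sample = residuals[:, None] * X        # shape (batch, dim)

    # Clip every row to L2 norm at most max_grad_norm.
    norms = np.linalg.norm(grad_sample, axis=1)
    clip_factor = np.minimum(1.0, max_grad_norm / (norms + 1e-6))
    clipped = grad_sample * clip_factor[:, None]

    # Sum, add Gaussian noise scaled to the clipping bound, and average.
    noise = rng.normal(0.0, noise_multiplier * max_grad_norm, size=w.shape)
    grad = (clipped.sum(axis=0) + noise) / len(X)
    return w - lr * grad, clipped


rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))
y = rng.normal(size=32)
w, clipped = dp_sgd_step(np.zeros(5), X, y, lr=0.1, max_grad_norm=1.0,
                         noise_multiplier=1.1, rng=rng)
print(np.linalg.norm(clipped, axis=1).max())  # never exceeds max_grad_norm
```

Without `grad_sample` there is no per-example norm to clip against, which is why the optimizer stops with an error instead of falling back to the ordinary summed gradient.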
HuanyuZhang commented 8 months ago

Thanks. Could you also provide more code, especially showing how you do the backward propagation?

HuanyuZhang commented 7 months ago

I did not get a reply, so I am closing the issue. Feel free to reopen it if needed.

helin0815 commented 5 months ago

I have also encountered this problem. Has it been resolved? It occurs when using slightly larger models, and even setting the accountant to "rdp" does not help.