pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Making a custom transformer architecture work with opacus #644

Open nhianK opened 3 months ago

nhianK commented 3 months ago

I am trying to make a custom architecture work with Opacus. It consists of two encoders that use self-attention and produce the context embeddings x_t and y_t, plus a "knowledge retriever" that uses masked attention. I suspect there are a few issues with it. The original model uses a modified multi-head attention: an exponential decay function is applied to the scaled dot product, together with a distance adjustment factor gamma that requires no gradient. The distance adjustments are obtained from model parameters that have already been calculated. This conflicts with Opacus, and I will create a separate issue for it later. For simplicity, I have used plain multi-head attention here to avoid those conflicts. Here is the notebook that can be used to reproduce the problem: https://colab.research.google.com/drive/1Sp3jILzB3HvizIAw3OTiQGnVq7LB5gee?usp=sharing

And this still produces the following error:

    /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1352: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
      warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "


    ValueError                                Traceback (most recent call last)
    in ()
    ----> 1 best_epoch = train_one_dataset(train_q_data, train_qa_data, train_pid, valid_q_data, valid_qa_data, valid_pid)
          2

    5 frames

    in train_one_dataset(train_q_data, train_qa_data, train_pid, valid_q_data, valid_qa_data, valid_pid)
         37     for idx in range(max_iter):
         38         # Train Model
    ---> 39         train_loss, train_accuracy, train_auc = train(
         40             dp_model, dp_optimizer, train_q_data, train_qa_data, train_pid, accountant, label='Train')
         41         # Validation step

    in train(net, optimizer, q_data, qa_data, pid_data, accountant, label)
         89             net.parameters(), max_norm=maxgradnorm)
         90
    ---> 91         optimizer.step()
         92
         93         # correct: 1.0; wrong 0.0; padding -1.0

    /usr/local/lib/python3.10/dist-packages/opacus/optimizers/optimizer.py in step(self, closure)
        516                 closure()
        517
    --> 518         if self.pre_step():
        519             return self.original_optimizer.step()
        520         else:

    /usr/local/lib/python3.10/dist-packages/opacus/optimizers/optimizer.py in pre_step(self, closure)
        494         # The corner case when the optimizer has no trainable parameters.
        495         # Essentially the DPOptimizer act as a normal optimizer
    --> 496         if self.grad_samples is None or len(self.grad_samples) == 0:
        497             return True
        498

    /usr/local/lib/python3.10/dist-packages/opacus/optimizers/optimizer.py in grad_samples(self)
        343         ret = []
        344         for p in self.params:
    --> 345             ret.append(self._get_flat_grad_sample(p))
        346         return ret
        347

    /usr/local/lib/python3.10/dist-packages/opacus/optimizers/optimizer.py in _get_flat_grad_sample(self, p)
        280             )
        281         if p.grad_sample is None:
    --> 282             raise ValueError(
        283                 "Per sample gradient is not initialized. Not updated in backward pass?"
        284             )

    ValueError: Per sample gradient is not initialized. Not updated in backward pass?

There is also some behavior I should note. The transformer layers are initialized in the architecture class. In the forward pass, the x and y embeddings are passed into the encoders. A flag controls when the knowledge retriever block (masked attention) is executed.
This is clearer in the forward pass of the transformer layer, where the "if" block handles the masked attention (knowledge retriever) and the "else" block corresponds to the encoders on the left (see the picture in the notebook). All three components use the same forward pass (see the forward calls of the Architecture and Transformer Layer classes). Training (the optimizer step) only seems to run when I leave out the if/else conditions and use a single forward path for all three parts of the model: the two encoders and the knowledge retriever that uses masked attention. Is there a way around this? Could this be reimplemented in a way that allows per-sample gradient computation? Notebook without Opacus: https://colab.research.google.com/drive/1jg-ygK7Vfou-IaJqNaujk-CHfdMruup3?usp=sharing

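One possible reading of the behavior described above, offered as an assumption and not as a confirmed diagnosis: when a branch of the forward pass skips some layers, their parameters are still registered and trainable but never take part in the backward pass. The sketch below uses plain PyTorch only, with hypothetical names (`BranchingLayer`, `is_retriever`), and shows that such parameters end up with `grad` equal to `None`. Per-sample gradient bookkeeping that only runs for layers reached in the forward pass could leave `grad_sample` unset in the same way.

```python
import torch
import torch.nn as nn


class BranchingLayer(nn.Module):
    """Hypothetical layer whose forward pass takes one of two branches."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.common = nn.Linear(dim, dim)
        self.retriever_only = nn.Linear(dim, dim)  # used only in the "if" branch
        self.encoder_only = nn.Linear(dim, dim)    # used only in the "else" branch

    def forward(self, x: torch.Tensor, is_retriever: bool) -> torch.Tensor:
        h = torch.relu(self.common(x))
        if is_retriever:
            return self.retriever_only(h)
        return self.encoder_only(h)


def params_without_grad(module: nn.Module) -> list[str]:
    """Names of trainable parameters that received no gradient."""
    return [
        name
        for name, p in module.named_parameters()
        if p.requires_grad and p.grad is None
    ]


torch.manual_seed(0)
layer = BranchingLayer()

# Take the "else" branch only, as an encoder would.
layer(torch.randn(4, 5, 8), is_retriever=False).sum().backward()

missing = params_without_grad(layer)
print(missing)  # parameters of the branch that was never executed
```

If this matches what happens in the notebook, any optimizer-side check that expects a gradient for every trainable parameter would fail on the skipped branch.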
HuanyuZhang commented 3 months ago

A quick question/guess: is there any model parameter that goes through several (X) forward passes but fewer than X backward passes? For such parameters, per_sample_grad will not be calculated and stored properly, which might lead to this issue.
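The case described here can be illustrated in plain PyTorch, as a rough sketch and not a reproduction of the notebook. The hypothetical `TwoPass` module calls one shared layer twice but lets only the first output reach the loss, so that layer sees two forward calls and one backward call. The counters rely on `register_forward_hook` and `register_full_backward_hook`; Opacus keeps its own bookkeeping, which is not shown here.

```python
import torch
import torch.nn as nn


class TwoPass(nn.Module):
    """Hypothetical module: `shared` runs twice, but only one output reaches the loss."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.inp = nn.Linear(dim, dim)
        self.shared = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.inp(x)
        used = self.shared(h)
        _unused = self.shared(h)  # second forward call, never part of the loss
        return used


counts = {"forward": 0, "backward": 0}


def count_forward(module, inputs, output):
    counts["forward"] += 1


def count_backward(module, grad_input, grad_output):
    counts["backward"] += 1


model = TwoPass()
model.shared.register_forward_hook(count_forward)
model.shared.register_full_backward_hook(count_backward)

model(torch.randn(3, 4)).sum().backward()
print(counts)  # forward and backward call counts for `shared`
```

Attaching the same two counters to each submodule of the real model would show which layers, if any, have more forward calls than backward calls.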

nhianK commented 3 months ago

Thank you for responding. I am not sure what you mean by X forward passes and fewer than X backward passes. Could you give me a general, reproducible case like that? In the meantime, here are some details about the architecture that could help with understanding the issue.

Referring to the architecture in the picture: the input data, labeled Rasch embeddings, are processed by the question encoder and the knowledge encoder to obtain xhat_t and yhat_t respectively (contextualized embeddings; both encoders use self-attention). In the knowledge retriever (masked attention), xhat_t is used for the key and query and yhat_t for the value. In the code, the inputs are passed sequentially through the transformer blocks.

Initially there was one transformer class for all three components, and it held all the necessary conditional logic and flags. So one thing I did was remove the operations nested in the if-else statements that were meant for the knowledge retriever and knowledge encoder, because they were causing issues. These conditional statements dictated when masks or certain layers would be applied. After the removal, the original features of the model were still preserved.

However, one component is still difficult to handle: the Transformer layer had one condition. This condition was meant only for the knowledge retriever and knowledge encoder, as shown in the illustration, where the linear layers and the activation are applied. The question encoder is the exception:

    if not question_encoder:  # written as "if not question encoder" in the post
        query2 = self.linear2(self.dropout(
            self.activation(self.linear1(query))))
        query = query + self.dropout2(query2)
        query = self.layer_norm2(query)

Here are the notebooks reproducing the architecture with and without Opacus.

Model without Opacus (it runs): https://colab.research.google.com/drive/1CjPdzUaThLKrY0vVUUM-__zrLgMsFxe7?usp=sharing
Model with Opacus: https://colab.research.google.com/drive/1D0TwshmEzhc3_ymKo9PQPXFATyvFMWhM?usp=sharing

I defined the encoders and the knowledge retriever separately because I needed to rule out problems with conditional statements and computation paths. One behavior I observed is that, with Opacus, the layers above need to be applied in all three forward computation paths, or the model cannot compute per-sample gradients. The question encoder cannot be an exception (see the Question encoder class in the second notebook and my comment there). So my question is: why do you think I am seeing this strange behavior? I am trying to get the model to run without the linear layers in the question encoder. In other words, why does the absence of these layers in the question encoder prevent per-sample gradient computation? In the picture, I annotated where the layers should and should not be in the original architecture.

(Image: akt-ann)
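If the missing per-sample gradients come from layers that exist in the question encoder but are skipped in its forward pass (an assumption, not something confirmed in this thread), one way to restructure the layer is to decide at construction time whether the feed-forward sublayers exist at all. The sketch below is plain PyTorch and is not taken from the notebooks: the class name `EncoderLayer`, the `use_ffn` argument, and the use of `nn.MultiheadAttention` as a stand-in for the custom attention are all hypothetical.

```python
import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
    """Hypothetical layer that only creates the feed-forward block when it is used."""

    def __init__(self, d_model=8, n_heads=2, d_ff=16, dropout=0.1, use_ffn=True):
        super().__init__()
        self.use_ffn = use_ffn
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        if use_ffn:  # question encoder would pass use_ffn=False
            self.linear1 = nn.Linear(d_model, d_ff)
            self.activation = nn.ReLU()
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(d_ff, d_model)
            self.dropout2 = nn.Dropout(dropout)
            self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, query, key, value):
        attn_out, _ = self.attn(query, key, value)
        query = self.layer_norm1(query + self.dropout1(attn_out))
        if self.use_ffn:
            query2 = self.linear2(self.dropout(self.activation(self.linear1(query))))
            query = self.layer_norm2(query + self.dropout2(query2))
        return query


def all_params_have_grad(module: nn.Module) -> bool:
    """True if every trainable parameter received a gradient."""
    return all(p.grad is not None for p in module.parameters() if p.requires_grad)


torch.manual_seed(0)
x = torch.randn(4, 5, 8)
results = {}
for use_ffn in (False, True):
    layer = EncoderLayer(use_ffn=use_ffn)
    layer(x, x, x).sum().backward()
    results[use_ffn] = all_params_have_grad(layer)
print(results)
```

With this layout, a layer built with `use_ffn=False` has no `linear1`, `linear2`, or `layer_norm2` parameters, so nothing trainable is left out of the backward pass. Freezing the unused parameters with `requires_grad_(False)` would be another option to try.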

HuanyuZhang commented 2 months ago

Thanks for your detailed response. By any chance, could we experiment with one single block at a time (for example, the question encoder) to see whether the problem replicates?

Specifically, let x be the output of the question encoder, define some dummy label y and any loss function L, and do the backward pass of L(x, y). Then we can see whether per_sample_grad is empty or not.
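The experiment suggested above could be organized roughly as follows. This is a sketch in plain PyTorch with hypothetical stand-in blocks, so it checks the ordinary `grad` attribute. The helper takes the attribute name as an argument so that the same check could be pointed at `grad_sample`, the attribute that appears in the traceback, once a block is wrapped by Opacus.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def params_missing(module: nn.Module, attr: str = "grad") -> list[str]:
    """Trainable parameters whose `attr` is unset after a backward pass."""
    return [
        name
        for name, p in module.named_parameters()
        if p.requires_grad and getattr(p, attr, None) is None
    ]


def check_block(block: nn.Module, inputs: torch.Tensor, attr: str = "grad") -> list[str]:
    """Run one block alone against a dummy label and report parameters left unset."""
    block.zero_grad(set_to_none=True)
    x = block(inputs)              # output of the block under test
    y = torch.zeros_like(x)        # dummy label
    F.mse_loss(x, y).backward()    # any loss L(x, y) would do
    return params_missing(block, attr)


# Hypothetical stand-ins for the real encoder blocks.
blocks = {
    "question_encoder": nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)),
    "knowledge_encoder": nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8)),
}

torch.manual_seed(0)
inputs = torch.randn(4, 5, 8)
report = {name: check_block(block, inputs) for name, block in blocks.items()}
print(report)  # an empty list means every trainable parameter was updated
```

Running the real blocks through `check_block` one at a time should narrow down which component, if any, leaves parameters without a gradient.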