awslabs / fast-differential-privacy

Fast, memory-efficient, scalable optimization of deep learning with differential privacy
Apache License 2.0

Support for llama and another custom module? #36

Closed alidadsetan closed 1 month ago

alidadsetan commented 2 months ago

Hello again. I am using PrivacyEngine_Distributed_extending to do private training on a llama2 model. As part of my model, I am using a custom module, a "label attention" block, as defined here. The label attention block has three linear layers, but its forward pass applies the @ operator and tensor.mul when using the parameters of the third linear layer.

When training without privacy, I get decent performance on my training data, but private training basically fails, with a very high training loss. I am wondering whether I am doing everything correctly or whether I have to extend the library for my application. I suspect that I may have to change both fastDP/transformers_support.py and fastDP/supported_differentially_private_layers.py. Can you please clarify whether my guess is right and, if so, what modifications are required? I would appreciate any pointers on what to change and how to do it.

alidadsetan commented 1 month ago

Hello again. I made some small changes so that the PyTorch extending privacy engine supports the Llama RMSNorm layer. I am posting them here so you can check them.

@torch.jit.script
def _rsm_weight_grad_smaple(input, grad_weight, max_grad_norm: float):
    grad_norm2 = grad_weight.norm(2, dim=-1)**2
    clip_factor = 1/(torch.sqrt(grad_norm2.to(input))+1e-4)
    grad_weight = torch.einsum('B,B...->...',clip_factor,grad_weight)
    return grad_weight


class RSMWeightDPFunction(Function):
    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return weight * input

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors

        grad_input = grad_weight = None

        if ctx.needs_input_grad[1]:
            grad_weight=torch.einsum('B...d,B...d->Bd', input, grad_output)
            grad_weight = _rsm_weight_grad_smaple(input, grad_weight, weight.max_grad_norm)
            # rescale by the number of layers, then add Gaussian noise
            grad_weight = grad_weight / math.sqrt(weight.n_layers) \
                + torch.randn_like(weight) * weight.noise

        if ctx.needs_input_grad[0]:
            grad_input = torch.einsum('d,B...d->B...d', weight, grad_output)

        # one gradient per forward() argument: (input, weight)
        return grad_input, grad_weight


class DPLlamaRMSNorm(Module):
    def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.weight = Parameter(torch.ones(hidden_size, device=device, dtype=dtype))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return RSMWeightDPFunction.apply(hidden_states.to(input_dtype), self.weight)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def replace_llama_rsmnorm(module):
    for layer_str in dir(module):
        layer = getattr(module, layer_str)
        if type(layer) == transformers.models.llama.modeling_llama.LlamaRMSNorm and requires_grad(layer):
            new_layer = DPLlamaRMSNorm(
                layer.weight.data.shape[0],layer.variance_epsilon, device=layer.weight.device, dtype=layer.weight.dtype)
            new_layer.weight = layer.weight
            del layer

            setattr(module, layer_str, new_layer)

    # Iterate over immediate child modules. The recursion happens in this function, so named_modules() is not needed.

    if hasattr(module,'children'):
        for immediate_child_module in module.children():
            replace_llama_rsmnorm(immediate_child_module)

and adding a call to it in the PyTorch extending engine (the last line below):

replace_Embedding(module)
replace_Linear(module)
replace_Conv2d(module)
replace_LayerNorm(module)
replace_GroupNorm(module)
replace_transformersConv1D(module)
replace_llama_rsmnorm(module)

If that looks correct to you guys, I can do a pull request.
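
As a sanity check on the backward pass above, the following minimal sketch compares the einsum-based per-sample weight gradient with a slow reference that calls autograd once per sample, and then applies the same 1/(norm + 1e-4) normalization before summing over the batch. The helper names, tensor shapes, and toy data are assumptions made for illustration and are not part of fastDP. Noise addition is left out so that the check stays deterministic.

```python
import torch


def per_sample_weight_grads(normed, grad_output):
    # normed, grad_output: (B, ..., d). All dims between batch and feature are summed.
    return torch.einsum('B...d,B...d->Bd', normed, grad_output)


def normalize_and_sum(per_sample_grads, stability=1e-4):
    # Scale each sample's gradient to (just under) unit norm, then sum over the batch.
    norms = per_sample_grads.norm(2, dim=-1)            # shape (B,)
    clip_factor = 1.0 / (norms + stability)
    return torch.einsum('B,Bd->d', clip_factor, per_sample_grads)


torch.manual_seed(0)
B, T, d = 3, 5, 4                      # hypothetical batch, sequence, hidden sizes
normed = torch.randn(B, T, d)          # stands in for the normalized hidden states
grad_output = torch.randn(B, T, d)     # stands in for the upstream gradient
weight = torch.ones(d, requires_grad=True)

fast = per_sample_weight_grads(normed, grad_output)

# Slow reference: differentiate each sample's contribution separately.
slow = torch.stack([
    torch.autograd.grad((weight * normed[i] * grad_output[i]).sum(), weight)[0]
    for i in range(B)
])

summed = normalize_and_sum(fast)
print(torch.allclose(fast, slow, atol=1e-5), summed.shape)
```

If the two per-sample gradients agree on random inputs of several shapes, that is reasonable evidence that the einsum subscripts match the layer's forward pass, although it does not by itself verify the privacy accounting.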

woodyx218 commented 1 month ago

Hi, as a rule of thumb, you need to change fastDP/supported_differentially_private_layers.py to support new modules that this library does not currently cover. A possibly easier alternative is to rewrite your label attention using torch.nn.Linear layers instead of the @ operator, and then let the privacy engine privatize it. Another alternative is to freeze any unsupported layers.
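
The freezing alternative could look roughly like the sketch below. It is only an illustration: the SUPPORTED tuple is an assumption loosely based on the replace_* calls listed earlier in this thread, ToyScale is a hypothetical stand-in for an unsupported custom module, and freeze_unsupported is not a fastDP function. Presumably the freezing would be done before the privacy engine is constructed.

```python
import torch
from torch import nn

# Assumption: layer types the privacy engine can privatize (not an official list).
SUPPORTED = (nn.Linear, nn.Embedding, nn.LayerNorm, nn.GroupNorm, nn.Conv2d)


class ToyScale(nn.Module):
    """Hypothetical custom layer that owns a parameter directly."""

    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * self.weight


def freeze_unsupported(model, supported=SUPPORTED):
    """Disable gradients for parameters owned directly by unsupported modules."""
    frozen = []
    for name, module in model.named_modules():
        own_params = list(module.parameters(recurse=False))
        if own_params and not isinstance(module, supported):
            for p in own_params:
                p.requires_grad_(False)
            frozen.append(name)
    return frozen


model = nn.Sequential(nn.Linear(4, 4), ToyScale(4), nn.Linear(4, 2))
frozen = freeze_unsupported(model)
print(frozen)
```

Only parameters registered directly on an unsupported module are frozen, so supported layers nested inside a custom block (for example the linear layers inside a label attention block) would stay trainable.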

woodyx218 commented 1 month ago

Thanks for sharing the Llama RMSNorm code. It looks correct to me, but I am curious: does using DP RMSNorm give better results than simply freezing these layers?

alidadsetan commented 1 month ago

I got some improvements, but I changed a lot of things at once, so I am not sure whether this contributed. I was just thinking that supporting all of the Llama layers would be nice because it is a popular model.

woodyx218 commented 1 month ago

Thank you! Please do a PR and I will review it.