Closed nnjnjn closed 1 month ago
Hi @nnjnjn, we don't support F.gelu, but you may use the following GELU module. We will add it to this repo soon; until then, you can use it by copying it into your code.
import math

import torch
import torch.nn as nn


class GELUOp(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x):
        # Export as a single custom::Gelu node.
        return g.op('custom::Gelu', x)

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.nn.functional.gelu(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx GELU(x) = Phi(x) + x * phi(x)
        grad = (0.5 * (1 + torch.erf(x / math.sqrt(2)))
                + x * torch.exp(-0.5 * x ** 2) / math.sqrt(2 * math.pi))
        return grad_output * grad


class GELU(nn.Module):
    def forward(self, x):
        return GELUOp.apply(x)
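The closed-form gradient used in `backward` above can be sanity-checked without PyTorch. The following is a minimal sketch, not part of auto_LiRPA: the helper names (`gelu`, `gelu_grad`, `central_difference`, `max_grad_error`) are made up for illustration, and the check compares the formula Phi(x) + x * phi(x) against a central finite difference at a few points.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_grad(x: float) -> float:
    """Closed-form derivative of GELU: Phi(x) + x * phi(x)."""
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


def central_difference(f, x: float, h: float = 1e-5) -> float:
    """Numerical derivative of f at x using a symmetric difference."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


def max_grad_error(points) -> float:
    """Largest gap between the closed form and the numerical derivative."""
    return max(abs(gelu_grad(x) - central_difference(gelu, x)) for x in points)


if __name__ == "__main__":
    points = [-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0]
    print("max |analytic - numeric| =", max_grad_error(points))
```

If the printed gap is tiny (well below 1e-6), the analytic expression matches the function being differentiated.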
I encountered the same issue. Thanks for addressing it.
However, after I copied and used the new GELU class, I get a new error message:
ERROR 10:03:01 The node has an unsupported operation: Node(name='/8', ori_name=None, inputs=['/x'], attr={}, op='custom::Gelu', param=OrderedDict(), input_index=None, bound_node=None, output_index=0, perturbation=None)
ERROR 10:03:01 The node has an unsupported operation: Node(name='/10', ori_name=None, inputs=['/x.3'], attr={}, op='custom::Gelu', param=OrderedDict(), input_index=None, bound_node=None, output_index=0, perturbation=None)
ERROR 10:03:01 Unsupported operations:
ERROR 10:03:01 Name: custom::Gelu, Attr: {}
ERROR 10:03:01 Name: custom::Gelu, Attr: {}
It seems that the new custom::Gelu operation is still not supported?
@TonghanWang Sorry about that. Some of the GeLU code (the implementation for custom::Gelu) is still missing. We will soon release a new version of the code containing GeLU.
Sorry, I'd like to ask if the implementation of GELU is currently supported?
Yes, it has been included in the latest release.
Thank you! I'd like to know if there is something wrong with my implementation because it still indicates that there are unsupported operations.
My code:
import torch
from auto_LiRPA import BoundedModule
import torch.nn.functional as F
import torch.nn as nn


class FNN(nn.Module):
    def __init__(self):
        super(FNN, self).__init__()
        self.fc1 = nn.Linear(2, 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        x = x.float()
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


model_ori = FNN()
batch_size = 100
x = torch.normal(0, 0.2, (batch_size, 2))
dummy_input = x[:2]
model = BoundedModule(model_ori, dummy_input)
Error:
ERROR 14:38:10 The node has an unsupported operation: Node(name='/9', ori_name=None, inputs=['/8'], attr={}, op='onnx::Erf', param=OrderedDict(), input_index=None, bound_node=None, output_index=0, perturbation=None)
ERROR 14:38:10 Unsupported operations:
ERROR 14:38:10 Name: onnx::Erf, Attr: {}
Traceback (most recent call last):
File "/data/majianan/ProvRepair/gelu.py", line 25, in <module>
model = BoundedModule(model_ori, dummy_input)
File "/home/hdu/miniconda3/envs/repair/lib/python3.9/site-packages/auto_LiRPA/bound_general.py", line 130, in __init__
self._convert(model, global_input)
File "/home/hdu/miniconda3/envs/repair/lib/python3.9/site-packages/auto_LiRPA/bound_general.py", line 848, in _convert
nodesOP, nodesIn, nodesOut, template = self._convert_nodes(
File "/home/hdu/miniconda3/envs/repair/lib/python3.9/site-packages/auto_LiRPA/bound_general.py", line 740, in _convert_nodes
raise NotImplementedError('There are unsupported operations')
NotImplementedError: There are unsupported operations
Hi @nninjn, you still need to use our customized GeLU module instead of nn.GELU: https://github.com/Verified-Intelligence/auto_LiRPA/blob/master/auto_LiRPA/operators/gelu.py#L420
Thank you for your help, it works now. By the way, I noticed that compared to activation functions like ReLU and Sigmoid, the bound computation for GELU seems to take significantly more time (in my toy example). Is this due to the inherent complexity of GELU, or is there something I haven't set up correctly?
Yes, this is expected. The implementation of linear relaxation for GeLU is more complicated, as the function itself is also more complicated than ReLU/Sigmoid (introduced in this paper: https://openreview.net/forum?id=ivokwVKY4o).
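One way to see why GeLU is harder to bound than ReLU or Sigmoid: it is not monotonic, so evaluating the function at the two ends of an input interval does not give a valid lower bound. The sketch below is only an illustration of that property. It is not how auto_LiRPA computes its relaxation, and the helper `grid_min` and the interval [-2, 0] are chosen here for the example.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def grid_min(f, lo: float, hi: float, steps: int = 10_000):
    """Return (x, f(x)) for the smallest value of f on a uniform grid over [lo, hi]."""
    best_x, best_y = lo, f(lo)
    for i in range(1, steps + 1):
        x = lo + (hi - lo) * i / steps
        y = f(x)
        if y < best_y:
            best_x, best_y = x, y
    return best_x, best_y


lo, hi = -2.0, 0.0

# For a monotonic activation (ReLU, Sigmoid) the endpoints bound the output.
endpoint_lower = min(gelu(lo), gelu(hi))

# GeLU dips below both endpoint values inside the interval.
x_min, true_lower = grid_min(gelu, lo, hi)

if __name__ == "__main__":
    print("lower bound from endpoints:", endpoint_lower)
    print("grid minimum:", true_lower, "at x =", x_min)
```

On this interval the grid minimum lies strictly inside the interval and below both endpoint values, so a sound bound has to account for the dip as well as the endpoints.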
I want to use auto_LiRPA to compute bounds for a neural network with a GELU activation function. Here is my code:
And error: