I am trying to implement a graph for ConvNeXt, which has a custom module: LayerNorm2d. In it, the input is permuted before being sent through LayerNorm. The merge handler can use handle_layernorm for that, right? No need to implement another function, if I understand it correctly. That ties into the question of whether a torch.nn.Permute needs anything beyond a "handle_fn" handler.
Yes, handle_layernorm should already contain everything you need for LayerNorm2d.
For Permute, I believe the logic should look very similar to handle_fn, except that you will additionally have to permute the dimensions of the merge/unmerge according to the permutation defined in the module.
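To make the Permute case concrete, here is a minimal sketch of the underlying identity such a handler relies on. The helper name `apply_merge` is hypothetical (not from any handler API); the point is only that applying a merge matrix along the channel axis commutes with a permutation, provided you re-index which dimension the merge acts on:

```python
import torch

def apply_merge(merge, x, axis):
    """Apply a merge matrix along one dimension of x.
    (Hypothetical helper; a real merge handler will differ.)"""
    x = torch.movedim(x, axis, -1)  # bring the target dim to the end
    x = x @ merge.T                 # mix along that dim
    return torch.movedim(x, -1, axis)

# NCHW -> NHWC, the permutation used inside ConvNeXt's LayerNorm2d
perm = (0, 2, 3, 1)
x = torch.randn(2, 4, 5, 5)
merge = torch.randn(4, 4)  # merge matrix over the 4 channels

# After x.permute(perm), the old channel axis 1 lands at perm.index(1) == 3,
# so the Permute handler only needs to re-index the merge/unmerge dimension.
a = apply_merge(merge, x, axis=1).permute(perm)
b = apply_merge(merge, x.permute(perm), axis=perm.index(1))
assert torch.allclose(a, b)
```

So a Permute handler can pass the merge/unmerge through unchanged in value, as handle_fn would, and just permute which dimensions they apply to.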