Describe the bug
Calling torcheia.jit.attach_eia() on a traced Hugging Face Transformers RoBERTa model (RobertaForSequenceClassification, roberta-base) fails with the following error: RuntimeError: class '__torch__.torch.nn.modules.normalization.___torch_mangle_6.LayerNorm' already defined.
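The mangled name in the error identifies the colliding class. My understanding is that TorchScript prefixes compiled classes with `__torch__.` and, when it has to register another distinct type under a qualified name that is already taken, inserts a `___torch_mangle_N` segment to keep the two apart. The stdlib-only sketch below uses hypothetical helper names (`demangle`, `colliding_class`) and assumes that naming scheme. It strips those parts, which shows the collision is on `torch.nn.LayerNorm` and makes the message easier to match against related reports:

```python
import re
from typing import Optional

# Assumed naming scheme: TorchScript qualified names start with "__torch__." and may
# contain "___torch_mangle_<N>." segments that disambiguate same-named types.
_TORCH_PREFIX = "__torch__."
_MANGLE_SEGMENT = re.compile(r"___torch_mangle_\d+\.")
# Pulls the qualified class name out of a "class '...' already defined" message.
_ALREADY_DEFINED = re.compile(r"class '([^']+)' already defined")


def demangle(qualified_name: str) -> str:
    """Return the Python-level class path hidden inside a TorchScript qualified name."""
    name = _MANGLE_SEGMENT.sub("", qualified_name)
    if name.startswith(_TORCH_PREFIX):
        name = name[len(_TORCH_PREFIX):]
    return name


def colliding_class(error_message: str) -> Optional[str]:
    """Extract and demangle the class named in an 'already defined' error, if any."""
    match = _ALREADY_DEFINED.search(error_message)
    return demangle(match.group(1)) if match else None


if __name__ == "__main__":
    message = ("RuntimeError: class '__torch__.torch.nn.modules.normalization."
               "___torch_mangle_6.LayerNorm' already defined.")
    print(colliding_class(message))  # torch.nn.modules.normalization.LayerNorm
```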
To reproduce
```python
import torch
import torcheia
from transformers import RobertaForSequenceClassification

# Load the pretrained RoBERTa base model with TorchScript support enabled
roberta_model = RobertaForSequenceClassification.from_pretrained("roberta-base", torchscript=True)
roberta_model.eval()

# Manufacture some input data (token id 1 is padding, masked out in attention_mask)
input_ids = torch.tensor([[0, 9226, 16, 10, 1296, 2], [0, 463, 277, 2, 1, 1]] * 4, dtype=torch.long)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]] * 4, dtype=torch.long)

# Validate that the model runs on the input data
roberta_model(input_ids, attention_mask)

# Trace the model; example inputs must follow forward()'s order (input_ids, attention_mask)
traced_model = torch.jit.trace(roberta_model, (input_ids, attention_mask))

# Validate that the traced model runs on the input data
traced_model.eval()
traced_model(input_ids, attention_mask)

# Disable the profiling executor, then attach the traced model to accelerator ordinal 0
torch._C._jit_set_profiling_executor(False)
eia_model = torcheia.jit.attach_eia(traced_model, 0)
```
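torch.jit.trace passes its example inputs to forward() positionally, so their order has to match the signature of RobertaForSequenceClassification.forward, which takes input_ids first and attention_mask second. The mask should also agree with the padding in input_ids. The stdlib-only sketch below derives one from the other and orders named inputs by signature. The helper names are hypothetical, and the pad id of 1 is an assumption based on the sample data, where every masked-out position holds token id 1:

```python
import inspect
from typing import Any, Callable, List, Sequence, Tuple

# Assumption: id 1 is RoBERTa's <pad> token, matching the report's sample data where
# every masked-out position (attention_mask == 0) holds token id 1.
ROBERTA_PAD_ID = 1


def attention_mask_from_ids(input_ids: Sequence[Sequence[int]],
                            pad_id: int = ROBERTA_PAD_ID) -> List[List[int]]:
    """Build a 0/1 mask that hides padding positions."""
    return [[0 if token == pad_id else 1 for token in row] for row in input_ids]


def positional_inputs(forward: Callable[..., Any], **named: Any) -> Tuple[Any, ...]:
    """Order named example inputs by forward()'s parameters, for positional tracing."""
    params = [p for p in inspect.signature(forward).parameters if p != "self"]
    ordered: List[Any] = []
    for name in params:
        if name not in named:
            break  # inputs after a missing parameter cannot be passed positionally
        ordered.append(named[name])
    leftover = set(named) - set(params[:len(ordered)])
    if leftover:
        raise ValueError("cannot pass positionally: %s" % sorted(leftover))
    return tuple(ordered)


# Hypothetical stand-in mirroring the leading parameters of
# RobertaForSequenceClassification.forward.
def stand_in_forward(input_ids=None, attention_mask=None, token_type_ids=None):
    return input_ids, attention_mask


if __name__ == "__main__":
    ids = [[0, 9226, 16, 10, 1296, 2], [0, 463, 277, 2, 1, 1]]
    print(attention_mask_from_ids(ids))  # [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]
    print(positional_inputs(stand_in_forward, attention_mask="mask", input_ids="ids"))  # ('ids', 'mask')
```

With the real model, `positional_inputs(roberta_model.forward, input_ids=input_ids, attention_mask=attention_mask)` should yield a tuple in the right order for `torch.jit.trace`, since `inspect.signature` on a bound method already omits `self`.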
Expected behavior
I expect torcheia.jit.attach_eia to attach the traced model to the Elastic Inference accelerator without raising an error.
System information
SageMaker JupyterLab notebook with conda_amazonei_pytorch_latest_p36 environment
Additional context
A similar error is reported in the PyTorch project: https://github.com/pytorch/pytorch/issues/29170
The fix for that issue appears to have been merged before PyTorch 1.5.1.
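Because the upstream fix reportedly landed before PyTorch 1.5.1, a useful first check is which PyTorch version the notebook environment actually loads. The stdlib-only sketch below uses hypothetical helper names (`parse_version`, `at_least`) and only compares the leading numeric release. A matching version number does not prove that a given build contains the fix:

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric release of a version string such as '1.5.1+cu101'."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    return tuple(int(part) for part in match.group(0).split("."))


def _pad(parts: Tuple[int, ...], width: int) -> Tuple[int, ...]:
    # Pad with zeros so (1, 5) compares equal to (1, 5, 0).
    return parts + (0,) * (width - len(parts))


def at_least(version: str, minimum: Tuple[int, ...]) -> bool:
    """True if the version's numeric release is >= minimum."""
    parsed = parse_version(version)
    width = max(len(parsed), len(minimum))
    return _pad(parsed, width) >= _pad(minimum, width)


if __name__ == "__main__":
    for v in ("1.3.1", "1.5.1+cu101", "1.6.0.dev20200601"):
        print(v, at_least(v, (1, 5, 1)))
```

In the notebook, the same check would be `at_least(torch.__version__, (1, 5, 1))`.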