xlang-ai / instructor-embedding

[ACL 2023] One Embedder, Any Task: Instruction-Finetuned Text Embeddings
Apache License 2.0

Modified masking before pooling - Fixes issue in ONNX conversion #92

Closed: ashokrajab closed this 7 months ago

ashokrajab commented 1 year ago

Issue: In class INSTRUCTOR_Transformer, inside def forward(), the attention mask entries corresponding to the instruction tokens are set to 0 in the following manner:

if context_masks is not None:
    import torch
    assert len(context_masks) == len(attention_mask)
    n = len(attention_mask)
    for local_idx in range(n):
        assert torch.sum(attention_mask[local_idx]).item() >= context_masks[local_idx].item(), \
            f'{attention_mask[local_idx]}, {context_masks[local_idx]}, ' \
            f'{torch.sum(attention_mask[local_idx]).item()}, {context_masks[local_idx].item()}'
        attention_mask[local_idx][:context_masks[local_idx]] = 0

I want to draw attention to the line n = len(attention_mask). During ONNX export this Python int is traced as a constant, so the loop is unrolled and the per-example slice bounds are baked into the graph, which leads to incorrect inference whenever the instruction token length changes.
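
To make the failure concrete, a minimal repro along these lines shows the constant-folding (an illustrative sketch, not code from this repo; zero_prefix and the tensor shapes are made up):

import torch

def zero_prefix(attention_mask, context_masks):
    n = len(attention_mask)          # Python int, frozen at trace time
    for i in range(n):               # loop is unrolled for the traced batch size
        attention_mask[i][:context_masks[i]] = 0   # slice length also baked in
    return attention_mask

mask = torch.ones(2, 8, dtype=torch.long)
ctx = torch.tensor([3, 5])
traced = torch.jit.trace(zero_prefix, (mask.clone(), ctx))
# The traced graph hardcodes n == 2 and the slice lengths 3 and 5, so calling
# it with a different batch size or instruction length gives wrong results.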

Solution: Instead of getting the instruction token length and manually iterating over attention_mask to zero out values, I have introduced a prepare_input_features() function under class Instructor that carries out the same task using tensor manipulations. This way, inference with the ONNX model works as expected for any instruction.
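
The idea is equivalent to the following sketch (illustrative only; the function name and shapes are assumptions, not the exact code in this PR). It assumes attention_mask is (batch, seq_len) and context_masks is (batch,), holding the number of leading instruction tokens per example:

import torch

def mask_instruction_tokens(attention_mask: torch.Tensor,
                            context_masks: torch.Tensor) -> torch.Tensor:
    # Derive per-token position indices from the mask tensor itself so the
    # graph stays shape-dynamic under ONNX tracing (no Python ints involved).
    positions = torch.cumsum(torch.ones_like(attention_mask), dim=1) - 1
    # keep[b, t] is True where token t is at or beyond example b's instruction length
    keep = positions >= context_masks.unsqueeze(-1)
    return attention_mask * keep.to(attention_mask.dtype)

Because every operation here is a tensor op, the exported ONNX graph handles any batch size, sequence length, and instruction length.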

Other changes: There are many other diffs in the pull request; those are the result of adhering to formatting/linting standards.

ashokrajab commented 1 year ago

@Harry-hash @hongjin-su Looking forward to your input...

ashokrajab commented 1 year ago

@hongjin-su @Harry-hash Just a gentle reminder...

ashokrajab commented 11 months ago

Following up on this.

ashokrajab commented 8 months ago

@hongjin-su @Harry-hash just a reminder