Issue:
In `class INSTRUCTOR_Transformer`, inside `def forward()`, the attention mask corresponding to the instruction tokens is set to 0 in the following manner:
```python
if context_masks is not None:
    import torch
    assert len(context_masks) == len(attention_mask)
    n = len(attention_mask)
    # print('n ',n)
    for local_idx in range(n):
        assert torch.sum(attention_mask[local_idx]).item() >= context_masks[local_idx].item(),\
            f'{attention_mask[local_idx]}, {context_masks[local_idx]}, ' \
            f'{torch.sum(attention_mask[local_idx]).item()}, {context_masks[local_idx].item()}'
        attention_mask[local_idx][:context_masks[local_idx]] = 0
```
I want to draw attention to the line `n = len(attention_mask)`. This Python int is treated as a constant during ONNX conversion, which leads to incorrect inference when the instruction token length changes.
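To illustrate, here is a hypothetical, self-contained toy module that mirrors the quoted loop (the class name and the commented export call are mine, not from the repository): tracing-based ONNX export evaluates Python-level quantities such as `n` once against the example input and freezes them into the graph, unrolling the loop for that particular value.

```python
import torch

class LoopMask(torch.nn.Module):
    """Toy module reproducing the loop above, for illustration only."""

    def forward(self, attention_mask: torch.Tensor, context_masks: torch.Tensor) -> torch.Tensor:
        n = len(attention_mask)  # plain Python int: evaluated once and unrolled during tracing
        for i in range(n):
            # The per-row slice bound is likewise read from the example input while tracing.
            attention_mask[i][:context_masks[i]] = 0
        return attention_mask

# Tracing-based export records the graph for one concrete example, so other batch sizes
# and instruction lengths are not reflected in the exported model.
# mask = torch.ones(2, 8, dtype=torch.long)
# ctx = torch.tensor([3, 5])
# torch.onnx.export(LoopMask(), (mask, ctx), "loop_mask.onnx",
#                   input_names=["attention_mask", "context_masks"])
```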
Solution:
Instead of getting the instruction token length and manually iterating over `attention_mask` to set the values to 0, I have introduced a `prepare_input_features()` function under `class Instructor` that carries out the same task using tensor manipulations. This way, inference with the exported ONNX model works as expected for any instruction.
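For reference, a minimal sketch of how the same masking can be expressed with broadcasted tensor operations (an illustrative rewrite, not necessarily the exact body of `prepare_input_features()`; it assumes `context_masks` holds one instruction-token count per example):

```python
import torch

def zero_out_instruction_tokens(attention_mask: torch.Tensor,
                                context_masks: torch.Tensor) -> torch.Tensor:
    """Zero the attention mask for the first context_masks[i] tokens of row i.

    attention_mask: (batch, seq_len) tensor of 0/1 values
    context_masks:  (batch,) or (batch, 1) tensor of instruction-token counts
    """
    # 0, 1, 2, ... position index per row, built from the mask itself so no
    # Python-side shape value gets frozen into the traced graph.
    positions = torch.cumsum(torch.ones_like(attention_mask), dim=1) - 1
    # True where the token lies past the instruction prefix.
    keep = positions >= context_masks.view(-1, 1)
    return attention_mask * keep.to(attention_mask.dtype)
```

Because every shape here is derived from the input tensors themselves, the exported ONNX graph remains valid for arbitrary batch sizes and instruction lengths.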
Other change set:
There are many other diffs in the pull request; those are a result of adhering to formatting/linting standards.