wigwit opened 1 year ago
This inconsistency produces the following test failure:

```
FAILED tests/pipelines/test_pipelines_token_classification.py::TokenClassificationPipelineTests::test_sliding_window - TypeError: 'BatchEncoding' object is not an iterator
```
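That message is the `TypeError` Python raises when `next()` is called on an object that is iterable but not itself an iterator. One plausible reading, which still needs to be checked against the updated parent class, is that the stride-aware code path calls `next()` on whatever `preprocess` returns and expects a generator of chunks, while our `preprocess` returns a single `BatchEncoding`. The toy below reproduces the message with a hypothetical stand-in class named `BatchEncoding` (not the transformers class) and hypothetical `preprocess_*` functions.

```python
from collections import UserDict


class BatchEncoding(UserDict):
    """Hypothetical stand-in for transformers' BatchEncoding: a mapping, not an iterator."""


def preprocess_single(text):
    # Returns one mapping, as a non-chunking preprocess would
    return BatchEncoding({"input_ids": [101, 102]})


def preprocess_chunked(text):
    # Yields one mapping per chunk, the shape a chunk-aware caller expects
    yield BatchEncoding({"input_ids": [101, 102]})
    yield BatchEncoding({"input_ids": [101, 103]})


def first_chunk(preprocess, text):
    # A chunk-aware caller advances the preprocess result with next()
    return next(preprocess(text))


try:
    first_chunk(preprocess_single, "some text")
except TypeError as err:
    print(err)  # 'BatchEncoding' object is not an iterator

print(first_chunk(preprocess_chunked, "some text")["input_ids"])  # [101, 102]
```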
`TokenClassificationPipeline` has been modified so that its `preprocess` function can handle `stride`, and if `stride` is passed in, the tokens are aggregated again in the `aggregation_overlapping_entities` function here.
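When windows overlap, the same entity can be predicted once per window, which is why a second aggregation pass is needed. The sketch below shows one plausible merge rule: sort by start offset and, among overlapping spans, keep the longer one, breaking ties by score. It is an illustration under that assumption, not a copy of `aggregation_overlapping_entities`; the function name `merge_overlapping_entities` and the dict keys `start`, `end`, `score`, `entity_group` are assumptions.

```python
def merge_overlapping_entities(entities):
    """Hypothetical merge: keep one entity per group of overlapping character spans."""
    if not entities:
        return []
    # Walk entities left to right by start offset
    ordered = sorted(entities, key=lambda e: e["start"])
    merged = []
    current = ordered[0]
    for entity in ordered[1:]:
        if entity["start"] < current["end"]:
            # Overlap: prefer the longer span, then the higher score
            current_len = current["end"] - current["start"]
            entity_len = entity["end"] - entity["start"]
            if (entity_len, entity["score"]) > (current_len, current["score"]):
                current = entity
        else:
            merged.append(current)
            current = entity
    merged.append(current)
    return merged


windows = [
    {"entity_group": "PER", "start": 0, "end": 10, "score": 0.91},  # seen in window 1
    {"entity_group": "PER", "start": 0, "end": 10, "score": 0.95},  # same span, window 2
    {"entity_group": "LOC", "start": 20, "end": 26, "score": 0.88},
]
print(merge_overlapping_entities(windows))  # keeps the 0.95 PER span and the LOC span
```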
I tried a few things to see whether this problem can be overcome. Here is what I did:

Removed our `preprocess` override, since the parent class can already handle `stride`, and forwarded `stride` to the parent's `__init__` instead:
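```python
from typing import Optional


# Method of our TokenClassificationPipeline subclass (super() needs the class context)
def __init__(self, window_length: Optional[int] = None, stride: Optional[int] = None, *args, **kwargs):
    # Window size in tokens; defaults to 512
    self.window_length = window_length or 512
    if stride is None:
        # Default step: half a window, so consecutive windows overlap by half
        self.stride = self.window_length // 2
    elif stride == 0:
        # stride == 0 is treated as a full-window step (no overlap)
        self.stride = self.window_length
    elif 0 < stride <= self.window_length:
        self.stride = stride
    else:
        raise ValueError("`stride` must be a positive integer no greater than `window_length`")
    # Hand stride to the parent, which now processes it itself
    super().__init__(*args, stride=self.stride, **kwargs)
```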
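Pulled out of `__init__` as a standalone helper, the stride rules are easier to check in isolation: `None` means half a window, `0` means a full-window step, and anything outside `(0, window_length]` is rejected. `resolve_stride` is a hypothetical name; it does not exist in the fork.

```python
from typing import Optional, Tuple


def resolve_stride(window_length: Optional[int] = None, stride: Optional[int] = None) -> Tuple[int, int]:
    """Hypothetical helper reproducing the stride rules of the __init__ above."""
    window_length = window_length or 512
    if stride is None:
        # Half-window step by default
        stride = window_length // 2
    elif stride == 0:
        # Full-window step, i.e. no overlap
        stride = window_length
    elif not 0 < stride <= window_length:
        raise ValueError("`stride` must be a positive integer no greater than `window_length`")
    return window_length, stride


print(resolve_stride())         # (512, 256)
print(resolve_stride(128, 0))   # (128, 128)
print(resolve_stride(128, 32))  # (128, 32)
```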
Here is the resulting test error in the `postprocess` function. It seems that `TokenClassificationPipeline` also changes the model output it returns when `stride` is passed in; this needs further investigation.
```
FAILED tests/pipelines/test_pipelines_token_classification.py::TokenClassificationPipelineTests::test_sliding_window - TypeError: list indices must be integers or slices, not str
```
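That `TypeError` is what Python raises when a list is indexed with a string key, which fits the guess above: with `stride` set, `postprocess` may be receiving a list of per-chunk outputs rather than a single mapping. The toy below uses plain dicts and hypothetical function names (not the transformers signatures) to show the failure and a `postprocess` that accepts either shape.

```python
from collections.abc import Mapping


def postprocess_single(model_outputs):
    # Written for one mapping of model outputs
    return model_outputs["logits"]


def postprocess_any(model_outputs):
    # Wrap a single mapping so both shapes become a list of per-chunk outputs
    chunks = [model_outputs] if isinstance(model_outputs, Mapping) else list(model_outputs)
    # One result per chunk; merging overlapping chunks is a separate step
    return [chunk["logits"] for chunk in chunks]


one = {"logits": [0.1, 0.9]}
many = [{"logits": [0.1, 0.9]}, {"logits": [0.8, 0.2]}]

try:
    postprocess_single(many)
except TypeError as err:
    print(err)  # list indices must be integers or slices, not str

print(postprocess_any(one))   # [[0.1, 0.9]]
print(postprocess_any(many))  # [[0.1, 0.9], [0.8, 0.2]]
```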
After our branch was rebased, the test case failed because the updated `TokenClassificationPipeline` has added `stride` as an optional argument, which conflicts with how we define `stride`.

Currently we define `stride` here: if `stride` is `None`, it defaults to `window_length // 2`, and the pipeline processes the task with a sliding window built from `stride` and `window_length`.

https://github.com/connor-qingxia/transformers/blob/a8fc5aa6ab7b941d961d512ae3e9bdf3b5f99e8c/src/transformers/pipelines/token_classification.py#L578

This conflicts with how the `stride` argument is processed here: if `stride` is `None`, the pipeline simply treats the task as a usual token classification task.
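As we read the transformers documentation for `TokenClassificationPipeline`, its `stride` is the number of tokens shared by consecutive chunks (the model moves forward `model_max_length - stride` tokens per step), whereas our `stride` is the step between window starts. If that reading is right, the two could be reconciled by translating ours before handing it to the parent. The helpers below are hypothetical and only do that arithmetic; they assume our `window_length` plays the role of `model_max_length`.

```python
def step_to_overlap(window_length: int, step: int) -> int:
    """Hypothetical: our `stride` (step between window starts) -> tokens shared by neighbours."""
    if not 0 < step <= window_length:
        raise ValueError("step must be in (0, window_length]")
    return window_length - step


def overlap_to_step(window_length: int, overlap: int) -> int:
    """Hypothetical inverse: tokens shared by neighbouring windows -> step between starts."""
    if not 0 <= overlap < window_length:
        raise ValueError("overlap must be in [0, window_length)")
    return window_length - overlap


print(step_to_overlap(512, 256))  # 256: the half-window default
print(step_to_overlap(512, 512))  # 0: a full-window step shares nothing
print(overlap_to_step(512, 128))  # 384
```

The two conventions also disagree on `None`: ours means half a window, the parent's means no chunking at all, so `None` needs an explicit decision either way.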