flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

[Question]: Controlling the CUDA OOM error for SequenceTagger via the use of tokens threshold #3338

Open adhadse opened 1 year ago

adhadse commented 1 year ago

Question

While trying to deploy the NER model, the length of the generated token sequence often cannot be controlled, even though the number of characters can be limited.

So I tried capping the total number of tokens per batch instead, but it still didn't play nicely, even though the total tokens per batch were held to a fixed amount.

I'll write down the piece of code that I tried to work with:

import humanize
import pandas as pd
import torch
from flair.data import Sentence
from flair.models import SequenceTagger


class NERExtractionError(Exception):
    def __init__(self, error, articles_tokens_count=0):
        super().__init__(error)
        self.name = "NERExtractionError"
        self.error = error
        self.batch_tokens_count = articles_tokens_count

class NERExtractionWithAutoBatching:
    def __init__(self, expected_total_tokens_in_batch=35_000):
        self.articles_tokens_count: list[int] = []
        self.expected_total_tokens_in_batch = expected_total_tokens_in_batch

    def generate_batches_of_n_token_lengths(self, df: pd.DataFrame):
        """Greedily group rows into batches that stay under the token threshold."""
        batches = []
        df["tokens_count_cumsum"] = df["token_count"].cumsum(axis="index")

        while df.shape[0] != 0:
            if df.iloc[0]["tokens_count_cumsum"] > self.expected_total_tokens_in_batch:
                # A single article already exceeds the threshold: make it its own batch.
                batch = df.iloc[[0]]
            else:
                # Take every leading article whose cumulative token count still fits.
                batch = df[df["tokens_count_cumsum"] <= self.expected_total_tokens_in_batch]
            batches.append(batch)
            df = df.drop(index=batch.index)
            df["tokens_count_cumsum"] = df["token_count"].cumsum(axis="index")
        return batches

    def inference(self, articles: pd.DataFrame):
        total_tokens_in_each_article: list[int] = articles["token_count"].tolist()
        sentences = articles["sentence"].tolist()

        flair_model = SequenceTagger.load("flair/ner-english-ontonotes-large")
        all_articles_ner = []

        try:
            # embedding_storage_mode expects the string "none" (its default), not None
            flair_model.predict(
                sentences, mini_batch_size=len(sentences),
                embedding_storage_mode="none"
            )
        except Exception as e:
            del sentences
            torch.cuda.empty_cache()
            raise NERExtractionError(
                e,
                articles_tokens_count=total_tokens_in_each_article
            )

        # each sentence is an article instance
        for sentence in sentences:
            article_ner = []
            for entity in sentence.get_labels('ner'):
                ner_text_and_label = {
                    "text": entity.data_point.text,
                    "labels": entity.to_dict()
                }
                article_ner.append(ner_text_and_label)
            all_articles_ner.append(article_ner)
            sentence.clear_embeddings()

        torch.cuda.empty_cache()

        return all_articles_ner

    def flair_inference(self, df):
        """Main function, pass it a pd.DataFrame containing "inputs" col filled with hundreds and thousands of characters
        CHARACTERS ARE CAPPED AT 20_000
        """
        flair_output = []

        df["sentence"] = df["inputs"].apply(lambda x: Sentence(x))
        df["token_count"] = df["sentence"].apply(lambda x: len(x.tokens))

        batches = self.generate_batches_of_n_token_lengths(df)

        for i, batch in enumerate(batches):
            print(f"[INFO] Batch #{i} token count:"
                  f"{humanize.intcomma(sum(batch['token_count'].tolist()))} "
                  f"of dynamic batch size of {len(batch)}")
            output = self.inference(batch)
            flair_output.extend(output)

        return flair_output
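
For reference, calling it looks roughly like this (the DataFrame contents here are just placeholders):

import pandas as pd

# Hypothetical input: one row per article, each capped at 20,000 characters.
df = pd.DataFrame({"inputs": ["First article text ...", "Second article text ..."]})

extractor = NERExtractionWithAutoBatching(expected_total_tokens_in_batch=35_000)
entities_per_article = extractor.flair_inference(df)  # list of per-article entity lists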

Is this a valid path to take, or are there better options out there for running NER over a huge amount of data?

helpmefindaname commented 1 year ago

Hi @adhadse, I think there is no definitive guarantee of never going OOM; however, you could use a sentence splitter to split your 20k-character paragraphs into smaller sentences. The smaller sentences should then be easier to process using batch inference.
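
For example, something like this (a rough sketch using flair's SegtokSentenceSplitter; any other SentenceSplitter works the same way):

from flair.models import SequenceTagger
from flair.splitter import SegtokSentenceSplitter

splitter = SegtokSentenceSplitter()
tagger = SequenceTagger.load("flair/ner-english-ontonotes-large")

# article_text: one of your long 20k-character paragraphs
# split it into many short Sentence objects ...
sentences = splitter.split(article_text)

# ... and predict on the short sentences with a fixed mini-batch size
tagger.predict(sentences, mini_batch_size=32)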

adhadse commented 1 year ago

@helpmefindaname Thanks for the response. With that approach I have no idea how I would keep track of hundreds of articles, because those few batches of articles will result in thousands of Sentence objects, and it isn't possible to trace them back to the originating article.

I guess the only approach with the sentence splitter would be to iterate over each article and create a batch of Sentences per article. That would obviously be expected to have a lower chance of OOM errors, but it might reduce inference speed (a trade-off, I guess).

Please correct me if I'm wrong. By the way, why can't tokens be used as a measure to control OOM occurrences?

helpmefindaname commented 1 year ago

For tracking, you can use metadata:

sentences = []
for article_id, article_text in enumerate(articles):
    for sentence_id, sentence in enumerate(sentence_splitter.split(article_text)):
        sentence.add_metadata("article_id", article_id)
        sentence.add_metadata("sentence_id", sentence_id)
        sentences.append(sentence)

and then later use sentence.get_metadata("article_id") and sentence.get_metadata("sentence_id") to trace the sentence back to its origin. If you only need the extracted entities, this is very simple. If you also want the position of the entity within the article, you need to add up the entity positions as described in https://github.com/flairNLP/flair/issues/3322#issuecomment-1742687899
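
For example, to regroup the predicted entities per article afterwards (a sketch that builds on the snippet above; `tagger` is the loaded SequenceTagger):

from collections import defaultdict

tagger.predict(sentences, mini_batch_size=32)

entities_per_article = defaultdict(list)
for sentence in sentences:
    article_id = sentence.get_metadata("article_id")
    for label in sentence.get_labels("ner"):
        entities_per_article[article_id].append(
            {"text": label.data_point.text, "labels": label.to_dict()}
        )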

Since the model you are using relies on transformer embeddings, which use the attention mechanism with O(n²) complexity with respect to the sequence length, you should gain speed by using more, smaller sentences instead of fewer, larger ones. Keep in mind that the total amount of tokens to compute stays about the same, so there is no reason to assume a speed reduction. As a rough illustration, one 512-subtoken sequence costs on the order of 512² ≈ 262k attention scores per layer, whereas four 128-subtoken sequences cost only 4 × 128² ≈ 66k.

About the tokens: as said before, you are using a model that uses transformer embeddings. Those embeddings use an embedding-specific sub-token tokenization and compute on that level. For example, when you use German electra-base and want to compute the text donaudampfschifffahrtselektrizitätenhauptbetriebswerkbauunterbeamtengesellschaft, you will see 1 flair token, but the tokenizer of gelectra will compute embeddings for the following subtokens: ['don', '##aud', '##ampf', '##schiff', '##fahrts', '##elekt', '##riz', '##itäten', '##haupt', '##betriebs', '##werk', '##bau', '##unter', '##beamten', '##gesellschaft']. The largest memory footprint occurs during that computation, hence the 15 sub-tokens matter more than the final embedding for the 1 token.

You might be able to estimate the number of subtokens, e.g. by assuming that on average 1 token has 3 subtokens, but you cannot guarantee that; some inputs will have many more subtokens.
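
If you want to check the actual subtoken count of a text, you can ask the respective Hugging Face tokenizer directly (a sketch for the gelectra example above; the exact split depends on the tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepset/gelectra-base")
text = "donaudampfschifffahrtselektrizitätenhauptbetriebswerkbauunterbeamtengesellschaft"
subtokens = tokenizer.tokenize(text)
print(len(subtokens))  # the number of subtokens that actually hits the GPU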

adhadseKavida commented 1 year ago

I am counting tokens based on the Sentence object, which has already tokenized the article:

df["sentence"] = df["inputs"].apply(lambda x: Sentence(x))
df["token_count"] = df["sentence"].apply(lambda x: len(x.tokens))