xenova / transformers.js

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
https://huggingface.co/docs/transformers.js
Apache License 2.0
10.86k stars 657 forks source link

Is 'aggregation_strategy' parameter available for token classification pipeline? #633

Open boat-p opened 6 months ago

boat-p commented 6 months ago

Question

Hi, I have a question.

The Hugging Face Transformers documentation describes an 'aggregation_strategy' parameter for the token classification pipeline (Link). Does this library provide this parameter too?

Thanks.

xenova commented 4 months ago

Hi there 👋 This isn't yet supported, but if a community member is interested in contributing, I'd be happy to add it!

kallebysantos commented 3 months ago

Hi, I have implemented some helper functions, based on the original Transformers aggregation functions, that perform the simple and max aggregation strategies:

You can adapt them to perform other kinds of aggregation, but in my view max is the most useful one.

Code implementation

```typescript
type AggregationGroup = {
  score: number;
  entityGroup: string;
  tokens: TokenClassificationSingle[];
  word: string;
  start: number;
  end: number;
};

/** Label the pipeline assigns to tokens outside any entity. */
const OUTSIDE_LABEL = "O";

/** Strips a leading IOB tag ("B-PESSOA" -> "PESSOA"); untagged labels are returned unchanged. */
function entityLabel(entity: string): string {
  return entity.replace(/^[BI]-/, "");
}

/** True for WordPiece continuation tokens, which start with "##". */
function isWordFragment(token: TokenClassificationSingle): boolean {
  return token.word.startsWith("##");
}

/** The token's text without its leading "##" continuation marker. */
function tokenText(token: TokenClassificationSingle): string {
  return isWordFragment(token) ? token.word.slice(2) : token.word;
}

/**
 * Appends a token to an aggregated word: "##" fragments are glued to the previous
 * piece, whole words are separated by a single space.
 */
function appendWord(word: string, token: TokenClassificationSingle): string {
  return isWordFragment(token) ? word + tokenText(token) : `${word} ${token.word}`;
}

/** Opens a new group containing only `token`. */
function startGroup(token: TokenClassificationSingle): AggregationGroup {
  return {
    score: token.score,
    entityGroup: entityLabel(token.entity),
    word: tokenText(token),
    tokens: [token],
    // Character offsets (token.start) are not reliable in transformers.js output,
    // so token indices are stored instead.
    start: token.index,
    end: token.index,
  };
}

/**
 * Groups entity tokens, approximating the Transformers "max" aggregation strategy.
 *
 * A whole-word "B-" token (not a "##" fragment) always opens a new group. Any other
 * token joins the current group when it directly follows that group's last token
 * (index + 1); otherwise it opens a new group rather than being dropped. Each group's
 * label and score come from its highest-scoring token (the later token wins ties).
 *
 * Tokens labelled "O" are skipped, and `tokens` inside each group keep input order.
 * Nothing is logged; callers filter the result themselves, e.g.
 * `maxAggregation(out).filter((g) => g.entityGroup === "PESSOA")`.
 *
 * @param tokens Token-classification pipeline output, in token order.
 * @returns The entity groups, in input order.
 */
function maxAggregation(tokens: TokenClassificationOutput): AggregationGroup[] {
  const groups: AggregationGroup[] = [];

  for (const current of tokens) {
    if (current.entity === OUTSIDE_LABEL) continue;

    const group = groups[groups.length - 1];
    const startsWord = current.entity.startsWith("B-") && !isWordFragment(current);

    if (startsWord || group === undefined || current.index > group.end + 1) {
      groups.push(startGroup(current));
      continue;
    }

    group.tokens.push(current);
    group.word = appendWord(group.word, current);
    group.end = current.index;
    // group.score always holds the group's maximum token score; ">=" lets later tokens win ties.
    if (current.score >= group.score) {
      group.score = current.score;
      group.entityGroup = entityLabel(current.entity);
    }
  }

  return groups;
}

/**
 * Groups entity tokens following the Transformers "simple" aggregation strategy.
 *
 * A token continues the previous group only when it is not a "B-" token, carries the
 * same entity label, and directly follows the group's last token (index + 1). Every
 * other entity token opens a new group, so tokens are never silently discarded, and
 * tokens separated by an "O" token are never merged.
 *
 * NOTE(review): Transformers' "simple" strategy reports the mean of the token scores;
 * this helper keeps the maximum, as the original version did.
 *
 * @param tokens Token-classification pipeline output, in token order.
 * @returns The entity groups, in input order.
 */
function simpleAggregation(tokens: TokenClassificationOutput): AggregationGroup[] {
  const groups: AggregationGroup[] = [];

  for (const current of tokens) {
    if (current.entity === OUTSIDE_LABEL) continue;

    const group = groups[groups.length - 1];

    if (
      group === undefined ||
      current.entity.startsWith("B-") ||
      entityLabel(current.entity) !== group.entityGroup ||
      current.index > group.end + 1
    ) {
      groups.push(startGroup(current));
      continue;
    }

    group.tokens.push(current);
    group.word = appendWord(group.word, current);
    group.score = Math.max(group.score, current.score);
    group.end = current.index;
  }

  return groups;
}
```
NextJs app example ![image](https://github.com/xenova/transformers.js/assets/105971119/f9e8f570-d52f-49eb-aef0-8b285092cc98) > Web interface ### Simple Aggregation image > Simple aggregation Result ### Max Aggregation image > Max aggregation result ### Full code > `src/app/page.tsx` ```tsx "use client"; import { useState } from "react"; import { Button } from "@/components/ui/button"; import { Textarea } from "@/components/ui/textarea"; import { usePipeline } from "@/lib/hooks/use-pipeline"; import { TokenClassificationOutput, TokenClassificationSingle, } from "@xenova/transformers"; import { LoaderCircle } from "lucide-react"; import { Skeleton } from "@/components/ui/skeleton"; import clsx from "clsx"; /* NOTE(review): maxAggregation and simpleAggregation in this file appear to duplicate the helpers from the standalone snippet posted above; if that snippet is updated, update these too. As written here, both skip every token whose label starts with "O", both print their groups (and the "PESSOA" words) with console.log, and maxAggregation re-sorts each group's tokens by ascending score via .sort(), so group.tokens is not in sentence order. */ type AggregationGroup = { score: number; entityGroup: string; tokens: TokenClassificationSingle[]; word: string; start: number; end: number; }; function maxAggregation(tokens: TokenClassificationOutput) { const grouped: AggregationGroup[] = tokens .filter((token) => !token.entity.startsWith("O")) .reduce((groups, current) => { const isBeginToken = current.entity.startsWith("B-"); const isWordFragment = current.word.startsWith("##"); if (isBeginToken && !isWordFragment) { return [ ...groups, { score: current.score, entityGroup: current.entity.replace("B-", ""), word: current.word, tokens: [current], // It should be current.start but 'start' is useless in transformers.js start: current.index, end: current.index, } satisfies AggregationGroup, ]; } const lastEntry = groups.pop(); if (!lastEntry) { return groups; } // Discard if index distance is too far: n+1 if (lastEntry.end + 1 < current.index) { const result = [...groups, lastEntry]; if (isBeginToken) { result.push({ score: current.score, entityGroup: current.entity.replace("B-", ""), word: current.word.replace("##", ""), tokens: [current], // It should be current.start but 'start' is useless in transformers.js start: current.index, end: current.index, } satisfies AggregationGroup); } return result; } const tokens = 
[...lastEntry.tokens, current].sort( (a, b) => a.score - b.score ); const maxToken = tokens.at(-1); const word = lastEntry.word.concat( // Include '##' means that word is part of previous, otherwise we need to add a blank space between isWordFragment ? current.word.replace("##", "") : " " + current.word ); return [ ...groups, { ...lastEntry, entityGroup: maxToken?.entity.split("-").pop() || lastEntry.entityGroup, score: maxToken?.score || lastEntry.score, word, tokens, end: current.index, } satisfies AggregationGroup, ]; }, new Array()); console.log(grouped); console.log( "MAX:", grouped .filter((group) => group.entityGroup === "PESSOA") .map((group) => group.word) ); return grouped; } function simpleAggregation(tokens: TokenClassificationOutput) { const grouped = tokens .filter((token) => !token.entity.startsWith("O")) .reduce((groups, current) => { if (current.entity.startsWith("B")) { return [ ...groups, { score: current.score, entityGroup: current.entity.replace("B-", ""), word: current.word, tokens: [current], start: current.index, end: current.index, } satisfies AggregationGroup, ]; } const lastEntry = groups.pop(); if (!lastEntry) { return groups; } // Discard if is not same Entity Group of last entry if (lastEntry.entityGroup !== current.entity.replace("I-", "")) { return [...groups, lastEntry]; } const tokens = [...lastEntry.tokens, current]; const score = tokens.reduce( (max, token) => Math.max(max, token.score), -Infinity ); const word = lastEntry.word.concat( // Include '##' means that word is part of previous, otherwise we need to add a blank space between current.word.includes("##") ? 
current.word.replace("##", "") : " " + current.word ); return [ ...groups, { ...lastEntry, score, word, tokens, end: current.index, } satisfies AggregationGroup, ]; }, new Array()); console.log(grouped); console.log( "SIMPLE:", grouped .filter((group) => group.entityGroup === "PESSOA") .map((group) => group.word) ); return grouped; } /* NOTE(review): the JSX element markup of Token and Home appears to have been stripped from this copy (Token's return body is only the {value.word.replace("##", "")} expression), so this snippet will not compile as shown; restore the elements from the original page.tsx. Also, useState() below is called without a type argument, so under React's typings `extraction` is presumably inferred as undefined; consider useState<TokenClassificationOutput>(). */ function Token({ value }: { value: TokenClassificationSingle }) { return ( {value.word.replace("##", "")} ); } export default function Home() { const [text, setText] = useState(""); const [extraction, setExtraction] = useState(); const tokenClassification = usePipeline( "token-classification", "KallebySantos/ner-bert-large-cased-pt-lenerbr-onnx" ); const isLoading = !tokenClassification.isReady || tokenClassification.isProcessing; async function HandleExtract() { if (isLoading) { console.info("worker is loading..."); return; } const outputTokens = (await tokenClassification.pipe(text, { ignore_labels: [], })) as TokenClassificationOutput; /* const grouped = outputTokens.reduce((prev, current, idx, array) => { if (current.entity.startsWith("B")) { return [...prev, [current]]; } const a = prev.pop(); if (a) { return [...prev, [...a, current]]; } return prev; }, new Array()); const mapped = grouped.map((group) => ({ group, entity: group.at(0)?.entity.replace("B-", ""), start: group.at(0)?.index, end: group.at(-1)?.index, words: group.map((item) => item.word.replace("##", "")), })); console.log(grouped); console.log(mapped); */ setExtraction(outputTokens); } return (

Insira o texto abaixo

{!tokenClassification.isReady ? ( ) : ( )}
{extraction && ( )} {extraction && ( )}
{extraction && extraction.map((token) => )}
); } ```