huggingface / huggingface.js

Utilities to use the Hugging Face Hub API
https://hf.co/docs/huggingface.js
MIT License

Streaming mode for the inference api #5

Open TimMikeladze opened 1 year ago

TimMikeladze commented 1 year ago

https://huggingface.co/docs/api-inference/parallelism#streaming

To maximize inference speed, it is more efficient to stream your data to the API than to send many separate HTTP requests. This requires using WebSockets on your end.
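
As a rough sketch of the client side, the checks and message order used by the patch below can be pulled into one pure function: validate that the batch targets a single model with unique ids, then produce the WebSocket URL and the frames to send once the socket is open. The endpoint path, the `Bearer` first frame, and one JSON frame per input are taken from the patch, not from the API docs, and whether the server still accepts this framing is an assumption. `planBulkStream`, `BulkArgs`, and `BulkPlan` are hypothetical names.

```typescript
// Hypothetical input type mirroring the patch's `Args`: a model plus an optional request id.
interface BulkArgs {
  model: string;
  id?: string;
  [key: string]: unknown;
}

interface BulkPlan {
  url: string;
  frames: string[]; // frames to send, in order, once the socket is open
}

// Validates a batch and returns the URL and frames the patch would send.
// Assumes the endpoint layout used in the patch: /bulk/stream/cpu/<model>.
function planBulkStream(apiKey: string, args: BulkArgs[]): BulkPlan {
  if (args.length === 0) {
    throw new Error("At least one input is required");
  }
  const models = new Set(args.map((a) => a.model));
  if (models.size > 1) {
    throw new Error("Only one model per streaming request; group inputs by model");
  }
  // Ids are optional, but any non-empty ids must be unique so responses can be matched.
  const ids = args
    .map((a) => a.id)
    .filter((id): id is string => typeof id === "string" && id.trim() !== "");
  if (ids.length !== new Set(ids).size) {
    throw new Error("Duplicate ids found in args");
  }
  return {
    url: `wss://api-inference.huggingface.co/bulk/stream/cpu/${args[0].model}`,
    // First frame authenticates, then one JSON frame per input (as in the patch).
    frames: [`Bearer ${apiKey}`, ...args.map((a) => JSON.stringify(a))],
  };
}
```

With the `ws` package used in the patch, you would open `new WebSocket(plan.url)` and, in the `open` handler, call `ws.send(frame)` for each frame in order.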

Important: a Pro account is required to use and test streaming. I started a partial implementation of streaming support several months ago and am leaving the patch below for future reference.

commit 035a2ecab05097c663887ebf12f6716d9bbac6aa
Author: TimMikeladze <tim.mikeladze@gmail.com>
Date:   Thu Sep 22 17:27:14 2022 +0300

    First pass at streaming support

diff --git a/package.json b/package.json
index 7d93ea1..f8e3232 100644
--- a/package.json
+++ b/package.json
@@ -58,6 +58,7 @@
   ],
   "devDependencies": {
     "@size-limit/preset-small-lib": "7.0.8",
+    "@types/ws": "8.5.3",
     "husky": "8.0.1",
     "size-limit": "7.0.8",
     "tsdx": "0.14.1",
@@ -68,6 +69,8 @@
     "node-notifier": ">=8.0.1"
   },
   "dependencies": {
-    "isomorphic-unfetch": "3.1.0"
+    "isomorphic-unfetch": "3.1.0",
+    "isomorphic-ws": "5.0.0",
+    "ws": "8.8.1"
   }
 }
diff --git a/src/HuggingFace.ts b/src/HuggingFace.ts
index 6909702..29f3684 100644
--- a/src/HuggingFace.ts
+++ b/src/HuggingFace.ts
@@ -1,6 +1,12 @@
 import fetch from 'isomorphic-unfetch';
+import WebSocket from 'ws';

 export type Options = {
+  /**
+   * (Default: `true`) If enabled, array arguments will be sent over a WebSocket connection and the response will be streamed back.
+   */
+  use_streaming?: boolean;
+
   /**
    * (Default: false). Boolean to use GPU instead of CPU for inference (requires Startup plan at least).
    */
@@ -21,7 +27,14 @@ export type Options = {
 };

 export type Args = {
+  /**
+   * The name of the HuggingFace model to use.
+   */
   model: string;
+  /**
+   * When the `use_streaming` option is enabled, this id is included in each response to identify the request.
+   */
+  id?: string;
 };

 export type FillMaskArgs = Args & {
@@ -391,9 +404,9 @@ export class HuggingFace {
    * Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
    */
   public async fillMask(
-    args: FillMaskArgs,
+    args: FillMaskArgs | FillMaskArgs[],
     options?: Options
-  ): Promise<FillMaskReturn> {
+  ): Promise<FillMaskReturn | FillMaskReturn[]> {
     return this.request(args, options);
   }

@@ -401,19 +414,33 @@ export class HuggingFace {
    * This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
    */
   public async summarization(
-    args: SummarizationArgs,
+    args: SummarizationArgs | SummarizationArgs[],
     options?: Options
-  ): Promise<SummarizationReturn> {
+  ): Promise<SummarizationReturn | SummarizationReturn[]> {
     return (await this.request(args, options))?.[0];
   }

+  /**
+   * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
+   */
+  public async questionAnswer(
+    args: QuestionAnswerArgs[],
+    options?: Options
+  ): Promise<QuestionAnswerReturn[]>;
   /**
    * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
    */
   public async questionAnswer(
     args: QuestionAnswerArgs,
     options?: Options
-  ): Promise<QuestionAnswerReturn> {
+  ): Promise<QuestionAnswerReturn>;
+  /**
+   * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
+   */
+  public async questionAnswer(
+    args: QuestionAnswerArgs | QuestionAnswerArgs[],
+    options?: Options
+  ): Promise<QuestionAnswerReturn | QuestionAnswerReturn[]> {
     return await this.request(args, options);
   }

@@ -421,9 +448,9 @@ export class HuggingFace {
    * Don’t know SQL? Don’t want to dive into a large spreadsheet? Ask questions in plain english! Recommended model: google/tapas-base-finetuned-wtq.
    */
   public async tableQuestionAnswer(
-    args: TableQuestionAnswerArgs,
+    args: TableQuestionAnswerArgs | TableQuestionAnswerArgs[],
     options?: Options
-  ): Promise<TableQuestionAnswerReturn> {
+  ): Promise<TableQuestionAnswerReturn | TableQuestionAnswerReturn[]> {
     return await this.request(args, options);
   }

@@ -431,9 +458,9 @@ export class HuggingFace {
    * Usually used for sentiment-analysis this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
    */
   public async textClassification(
-    args: TextClassificationArgs,
+    args: TextClassificationArgs | TextClassificationArgs[],
     options?: Options
-  ): Promise<TextClassificationReturn> {
+  ): Promise<TextClassificationReturn | TextClassificationReturn[]> {
     return await this.request(args, options);
   }

@@ -441,9 +468,9 @@ export class HuggingFace {
    * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
    */
   public async textGeneration(
-    args: TextGenerationArgs,
+    args: TextGenerationArgs | TextGenerationArgs[],
     options?: Options
-  ): Promise<TextGenerationReturn> {
+  ): Promise<TextGenerationReturn | TextGenerationReturn[]> {
     return (await this.request(args, options))?.[0];
   }

@@ -451,9 +478,9 @@ export class HuggingFace {
    * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
    */
   public async tokenClassification(
-    args: TokenClassificationArgs,
+    args: TokenClassificationArgs | TokenClassificationArgs[],
     options?: Options
-  ): Promise<TokenClassificationReturn> {
+  ): Promise<TokenClassificationReturn | TokenClassificationReturn[]> {
     return HuggingFace.toArray(await this.request(args, options));
   }

@@ -461,9 +488,9 @@ export class HuggingFace {
    * This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
    */
   public async translation(
-    args: TranslationArgs,
+    args: TranslationArgs | TranslationArgs[],
     options?: Options
-  ): Promise<TranslationReturn> {
+  ): Promise<TranslationReturn | TranslationReturn[]> {
     return (await this.request(args, options))?.[0];
   }

@@ -471,9 +498,9 @@ export class HuggingFace {
    * This task is super useful to try out classification with zero code, you simply pass a sentence/paragraph and the possible labels for that sentence, and you get a result. Recommended model: facebook/bart-large-mnli.
    */
   public async zeroShotClassification(
-    args: ZeroShotClassificationArgs,
+    args: ZeroShotClassificationArgs | ZeroShotClassificationArgs[],
     options?: Options
-  ): Promise<ZeroShotClassificationReturn> {
+  ): Promise<ZeroShotClassificationReturn | ZeroShotClassificationReturn[]> {
     return HuggingFace.toArray(await this.request(args, options));
   }

@@ -482,9 +509,9 @@ export class HuggingFace {
    *
    */
   public async conversational(
-    args: ConversationalArgs,
+    args: ConversationalArgs | ConversationalArgs[],
     options?: Options
-  ): Promise<ConversationalReturn> {
+  ): Promise<ConversationalReturn | ConversationalReturn[]> {
     return await this.request(args, options);
   }

@@ -492,43 +519,111 @@ export class HuggingFace {
    * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
    */
   public async featureExtraction(
-    args: FeatureExtractionArgs,
+    args: FeatureExtractionArgs | FeatureExtractionArgs[],
     options?: Options
-  ): Promise<FeatureExtractionReturn> {
+  ): Promise<FeatureExtractionReturn | FeatureExtractionReturn[]> {
     return await this.request(args, options);
   }

-  public async request(args: Args, options?: Options): Promise<any> {
+  public async request(
+    args: Args | Args[],
+    options?: Options
+  ): Promise<any | any[]> {
     const mergedOptions = { ...this.defaultOptions, ...options };
-    const { model, ...otherArgs } = args;
-    const response = await fetch(
-      `https://api-inference.huggingface.co/models/${model}`,
-      {
-        headers: { Authorization: `Bearer ${this.apiKey}` },
-        method: 'POST',
-        body: JSON.stringify({
-          ...otherArgs,
-          options: mergedOptions,
-        }),
+
+    if (Array.isArray(args) && mergedOptions.use_streaming !== false) {
+      const models = new Set(args.map(x => x.model));
+
+      if (models.size > 1) {
+        throw new Error(
+          'You can only use one model per request when the `use_streaming` option is enabled. Please group your requests by model.'
+        );
       }
-    );
-
-    if (
-      mergedOptions.retry_on_error !== false &&
-      response.status === 503 &&
-      !mergedOptions.wait_for_model
-    ) {
-      return this.request(args, {
-        ...mergedOptions,
-        wait_for_model: true,
+
+      const model = args[0].model;
+
+      const uniqueIds = args
+        .map(x => x.id)
+        .filter(x => x !== undefined && x !== null && x?.trim() !== '');
+
+      if (uniqueIds.length !== new Set(uniqueIds).size) {
+        throw new Error('Duplicate ids found in args');
+      }
+
+      const ws = new WebSocket(
+        `wss://api-inference.huggingface.co/bulk/stream/cpu/${model}`
+      );
+
+      // @ts-ignore
+      const responses: any[] = [];
+
+      // @ts-ignore
+      return new Promise((resolve, reject) => {
+        ws.on('open', () => {
+          ws.send(`Bearer ${this.apiKey}`, { binary: true });
+
+          for (const arg of args) {
+            ws.send(JSON.stringify(arg), { binary: true });
+          }
+        });
+
+        ws.on('message', (data: any) => {
+          console.log(Buffer.from(data).toString());
+          console.log(data);
+          // const message = JSON.parse(data);
+          // if (message.type == 'results') {
+          //   responses.push(message);
+          //   if (responses.length === args.length) {
+          //     ws.close();
+          //     resolve(responses);
+          //   }
+          // }
+        });
+        ws.on('error', message => {
+          console.log(message);
+          reject(message);
+        });
       });
-    }
+    } else {
+      const httpRequest = async (args: Args) => {
+        const { model, ...otherArgs } = args;
+
+        const response = await fetch(
+          `https://api-inference.huggingface.co/models/${model}`,
+          {
+            headers: { Authorization: `Bearer ${this.apiKey}` },
+            method: 'POST',
+            body: JSON.stringify({
+              ...otherArgs,
+              options: mergedOptions,
+            }),
+          }
+        );
+
+        if (
+          mergedOptions.retry_on_error !== false &&
+          response.status === 503 &&
+          !mergedOptions.wait_for_model
+        ) {
+          return this.request(args, {
+            ...mergedOptions,
+            wait_for_model: true,
+          });
+        }
+
+        const res = await response.json();
+        if (res.error) {
+          throw new Error(res.error);
+        }
+        return res;
+      };
+
+      if (Array.isArray(args)) {
+        return Promise.all(args.map(x => httpRequest(x)));
+      }

-    const res = await response.json();
-    if (res.error) {
-      throw new Error(res.error);
+      return httpRequest(args);
     }
-    return res;
   }

   private static toArray(obj: any): any[] {
diff --git a/test/HuggingFace.test.ts b/test/HuggingFace.test.ts
index ed7a592..da0a2bd 100644
--- a/test/HuggingFace.test.ts
+++ b/test/HuggingFace.test.ts
@@ -7,7 +7,7 @@ describe('HuggingFace', () => {
   // Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.
   let hf = new HuggingFace(process.env.HF_API_KEY as string);

-  it('throws error if model does not exist', () => {
+  xit('throws error if model does not exist', () => {
     expect(
       hf.fillMask({
         model: 'this-model-does-not-exist-123',
@@ -17,7 +17,24 @@ describe('HuggingFace', () => {
       `Model this-model-does-not-exist-123 does not exist`
     );
   });
-  it('fillMask', async () => {
+  xit('throws error if multiple models are provided and use_streaming is true', () => {
+    expect(
+      hf.fillMask([
+        {
+          model: 'this-model-does-not-exist-123',
+          inputs: '[MASK] world!',
+        },
+        {
+          model: 'this-model-also-does-not-exist-123',
+          inputs: '[MASK] world!',
+        },
+      ])
+    ).rejects.toThrowError(
+      'one model per request'
+    );
+  });
+
+  xit('fillMask', async () => {
     expect(
       await hf.fillMask({
         model: 'bert-base-uncased',
@@ -34,7 +51,7 @@ describe('HuggingFace', () => {
       ])
     );
   });
-  it('summarization', async () => {
+  xit('summarization', async () => {
     expect(
       await hf.summarization({
         model: 'facebook/bart-large-cnn',
@@ -49,7 +66,7 @@ describe('HuggingFace', () => {
         'The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world.',
     });
   });
-  it('questionAnswer', async () => {
+  xit('questionAnswer', async () => {
     expect(
       await hf.questionAnswer({
         model: 'deepset/roberta-base-squad2',
@@ -65,7 +82,7 @@ describe('HuggingFace', () => {
       end: expect.any(Number),
     });
   });
-  it('table question answer', async () => {
+  xit('table question answer', async () => {
     expect(
       await hf.tableQuestionAnswer({
         model: 'google/tapas-base-finetuned-wtq',
@@ -90,7 +107,7 @@ describe('HuggingFace', () => {
       aggregator: 'AVERAGE',
     });
   });
-  it('textClassification', async () => {
+  xit('textClassification', async () => {
     expect(
       await hf.textClassification({
         model: 'distilbert-base-uncased-finetuned-sst-2-english',
@@ -105,7 +122,7 @@ describe('HuggingFace', () => {
       ])
     );
   });
-  it('textGeneration', async () => {
+  xit('textGeneration', async () => {
     expect(
       await hf.textGeneration({
         model: 'gpt2',
@@ -116,7 +133,7 @@ describe('HuggingFace', () => {
         'The answer to the universe is not a binary number that is at a certain point defined in our theory of time, but an infinite number of infinitely long points and points for which each of these points has the given form in our equation. If the given',
     });
   });
-  it(`tokenClassification`, async () => {
+  xit('tokenClassification', async () => {
     expect(
       await hf.tokenClassification({
         model: 'dbmdz/bert-large-cased-finetuned-conll03-english',
@@ -134,7 +151,7 @@ describe('HuggingFace', () => {
       ])
     );
   });
-  it(`translation`, async () => {
+  xit('translation', async () => {
     expect(
       await hf.translation({
         model: 'Helsinki-NLP/opus-mt-ru-en',
@@ -144,7 +161,7 @@ describe('HuggingFace', () => {
       translation_text: 'My name is Wolfgang and I live in Berlin.',
     });
   });
-  it(`zeroShotClassification`, async () => {
+  xit('zeroShotClassification', async () => {
     expect(
       await hf.zeroShotClassification({
         model: 'facebook/bart-large-mnli',
@@ -164,7 +181,7 @@ describe('HuggingFace', () => {
       ])
     );
   });
-  it(`conversational`, async () => {
+  xit('conversational', async () => {
     expect(
       await hf.conversational({
         model: 'microsoft/DialoGPT-large',
@@ -191,7 +208,7 @@ describe('HuggingFace', () => {
       ],
     });
   });
-  it(`featureExtraction`, async () => {
+  xit('featureExtraction', async () => {
     expect(
       await hf.featureExtraction({
         model: 'sentence-transformers/paraphrase-xlm-r-multilingual-v1',
@@ -206,4 +223,69 @@ describe('HuggingFace', () => {
       })
     ).toEqual([0.6623499393463135, 0.9382339715957642, 0.22963346540927887]);
   });
+
+  xit('use http for array input when use_streaming is false', async () => {
+    const res = await hf.questionAnswer(
+      [
+        {
+          model: 'deepset/roberta-base-squad2',
+          inputs: {
+            question: 'What is the capital of France?',
+            context: 'The capital of France is Paris.',
+          },
+        },
+        {
+          model: 'deepset/roberta-base-squad2',
+          inputs: {
+            question: 'What is the capital of England?',
+            context: 'The capital of England is London.',
+          },
+        },
+      ],
+      {
+        use_streaming: false,
+      }
+    );
+
+    expect(res).toHaveLength(2);
+
+    expect(res[0]).toEqual({
+      answer: 'Paris',
+      score: expect.any(Number),
+      start: expect.any(Number),
+      end: expect.any(Number),
+    });
+    expect(res[1]).toEqual({
+      answer: 'London',
+      score: expect.any(Number),
+      start: expect.any(Number),
+      end: expect.any(Number),
+    });
+  });
+
+  it('use websockets for array input when use_streaming is true', async () => {
+    const res = await hf.questionAnswer(
+      [
+        {
+          model: 'deepset/roberta-base-squad2',
+          inputs: {
+            question: 'What is the capital of France?',
+            context: 'The capital of France is Paris.',
+          },
+        },
+        {
+          model: 'deepset/roberta-base-squad2',
+          inputs: {
+            question: 'What is the capital of England?',
+            context: 'The capital of England is London.',
+          },
+        },
+      ],
+      {
+        use_streaming: true,
+      }
+    );
+
+    console.log(res);
+  });
 });
diff --git a/yarn.lock b/yarn.lock
index c7cfe8d..6033ed8 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -1404,6 +1404,13 @@
   resolved "https://registry.yarnpkg.com/@types/stack-utils/-/stack-utils-1.0.1.tgz#0a851d3bd96498fa25c33ab7278ed3bd65f06c3e"
   integrity sha512-l42BggppR6zLmpfU6fq9HEa2oGPEI8yrSPL3GITjfRInppYFahObbIQOQK3UGxEnyQpltZLaPe75046NOZQikw==

+"@types/ws@8.5.3":
+  version "8.5.3"
+  resolved "https://registry.yarnpkg.com/@types/ws/-/ws-8.5.3.tgz#7d25a1ffbecd3c4f2d35068d0b283c037003274d"
+  integrity sha512-6YOoWjruKj1uLf3INHH7D3qTXwFfEsg1kf3c0uDdSBJwfa/llkwIjrAGV7j7mVgGNbzTQ3HiHKKDXl6bJPD97w==
+  dependencies:
+    "@types/node" "*"
+
 "@types/yargs-parser@*":
   version "21.0.0"
   resolved "https://registry.yarnpkg.com/@types/yargs-parser/-/yargs-parser-21.0.0.tgz#0c60e537fa790f5f9472ed2776c2b71ec117351b"
@@ -3842,6 +3849,11 @@ isomorphic-unfetch@3.1.0:
     node-fetch "^2.6.1"
     unfetch "^4.2.0"

+isomorphic-ws@5.0.0:
+  version "5.0.0"
+  resolved "https://registry.yarnpkg.com/isomorphic-ws/-/isomorphic-ws-5.0.0.tgz#e5529148912ecb9b451b46ed44d53dae1ce04bbf"
+  integrity sha512-muId7Zzn9ywDsyXgTIafTry2sV3nySZeUDe6YedVd1Hvuuep5AsIlqK+XefWpYTyJG5e503F2xIuT2lcU6rCSw==
+
 isstream@~0.1.2:
   version "0.1.2"
   resolved "https://registry.yarnpkg.com/isstream/-/isstream-0.1.2.tgz#47e63f7af55afa6f92e1500e690eb8b8529c099a"
@@ -6651,6 +6663,11 @@ write@1.0.3:
   dependencies:
     mkdirp "^0.5.1"

+ws@8.8.1:
+  version "8.8.1"
+  resolved "https://registry.yarnpkg.com/ws/-/ws-8.8.1.tgz#5dbad0feb7ade8ecc99b830c1d77c913d4955ff0"
+  integrity sha512-bGy2JzvzkPowEJV++hF07hAD6niYSr0JzBNo/J29WsB57A2r7Wlc1UFcTR9IzrPvuNVO4B8LGqF8qcpsVOhJCA==
+
 ws@^7.0.0:
   version "7.5.8"
   resolved "https://registry.yarnpkg.com/ws/-/ws-7.5.8.tgz#ac2729881ab9e7cbaf8787fe3469a48c5c7f636a"
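
The patch stops at logging raw messages. One way to finish the `message` handler, following the commented-out code in the patch, is to parse each message, keep only those whose `type` is `"results"`, and resolve once every input has an answer, reordering by `id` when ids were supplied. The message shape (`type`, `id`) is an assumption carried over from those comments, not a confirmed server format, and `createResultCollector` is a hypothetical helper.

```typescript
// Assumed message shape, based on the commented-out handler in the patch.
interface StreamMessage {
  type: string;
  id?: string;
  [key: string]: unknown;
}

interface ResultCollector {
  // Call with each WebSocket message, already decoded to a string.
  onMessage(raw: string): void;
  // Settles once `expected` results have arrived (or a message is malformed).
  done: Promise<StreamMessage[]>;
}

// Hypothetical helper. With `ids`, results come back in the same order as `ids`
// (matched on each message's `id`); without them, in arrival order.
function createResultCollector(expected: number, ids?: string[]): ResultCollector {
  const results: StreamMessage[] = [];
  let resolveDone!: (value: StreamMessage[]) => void;
  let rejectDone!: (reason: unknown) => void;
  const done = new Promise<StreamMessage[]>((resolve, reject) => {
    resolveDone = resolve;
    rejectDone = reject;
  });

  function onMessage(raw: string): void {
    let message: StreamMessage;
    try {
      message = JSON.parse(raw) as StreamMessage;
    } catch (err) {
      rejectDone(err); // non-JSON frame: fail the whole batch
      return;
    }
    // Skip non-result messages and anything arriving after we are complete.
    if (message.type !== "results" || results.length >= expected) return;
    results.push(message);
    if (results.length < expected) return;
    if (!ids) {
      resolveDone(results);
      return;
    }
    const byId = new Map(results.map((r): [string | undefined, StreamMessage] => [r.id, r]));
    const ordered = ids.map((id) => byId.get(id));
    if (ordered.some((r) => r === undefined)) {
      rejectDone(new Error("Some results had no matching id"));
    } else {
      resolveDone(ordered as StreamMessage[]);
    }
  }

  return { onMessage, done };
}
```

In the patch's handler, the `console.log` calls would be replaced by `collector.onMessage(...)` on the decoded data, with `ws.close()` and `resolve` chained onto `collector.done`.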

coyotte508 commented 1 year ago

The API could be like this:

const stream = new HfInference().batch().textGeneration([...] | ...).textToImage([...] | ...);

for await (const output of stream) {
  // handle each output as soon as it is available
}

The .batch() call would clearly separate streaming from non-streaming endpoints, and results would be yielded one by one as soon as they are available.
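
A minimal sketch of how such a chained object could behave: each task method only queues inputs, and iterating the object starts all queued calls and yields each result once it settles, so fast calls are not held up by slow ones. `Batch`, `TaskRunner`, and `BatchOutput` are hypothetical; the runner stands in for whatever would actually call the Inference API, and only two task names are shown.

```typescript
// Hypothetical: performs one inference call for a task name and its input.
type TaskRunner = (task: string, input: unknown) => Promise<unknown>;

interface BatchOutput {
  task: string;
  index: number; // position of the input in the order it was queued
  output?: unknown;
  error?: unknown; // set instead of `output` when that call failed
}

class Batch implements AsyncIterable<BatchOutput> {
  private readonly queued: { task: string; input: unknown }[] = [];

  constructor(private readonly run: TaskRunner) {}

  // Task methods accept one input or an array of inputs and only queue work.
  textGeneration(inputs: unknown): this {
    return this.enqueue("textGeneration", inputs);
  }

  textToImage(inputs: unknown): this {
    return this.enqueue("textToImage", inputs);
  }

  private enqueue(task: string, inputs: unknown): this {
    for (const input of Array.isArray(inputs) ? inputs : [inputs]) {
      this.queued.push({ task, input });
    }
    return this;
  }

  // Starts every queued call when iteration begins, then repeatedly yields
  // whichever pending call settles first.
  async *[Symbol.asyncIterator](): AsyncGenerator<BatchOutput> {
    const pending = new Map<number, Promise<BatchOutput>>();
    this.queued.forEach(({ task, input }, index) => {
      pending.set(
        index,
        this.run(task, input).then(
          (output) => ({ task, index, output }),
          (error: unknown) => ({ task, index, error }),
        ),
      );
    });
    while (pending.size > 0) {
      const settled = await Promise.race(pending.values());
      pending.delete(settled.index);
      yield settled;
    }
  }
}
```

Failures are reported as items with `error` set rather than aborting the stream; throwing on the first failure would be the other reasonable choice.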

We could also accept async iterables as parameters (not just arrays), so that inputs can be uploaded as a stream too. We should also add a parameter to batch that sets the number of parallel inferences:

const stream = new HfInference().batch({ concurrency: X });

That way it only sends X inputs at a time, like the concurrency parameter of promisesQueueStreaming. The default value would need to account for both the Inference API and Inference Endpoints.
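
The concurrency option and async-iterable inputs could be combined in a small pool: pull inputs lazily from a sync or async iterable, keep at most `concurrency` calls in flight, and yield each result as soon as it settles. This is a generic sketch, not the library's promisesQueueStreaming; `streamWithConcurrency` and `toAsyncIterator` are hypothetical helpers, and failing fast on the first rejected call is an assumed policy.

```typescript
// Wraps a sync iterable so both kinds of input can be pulled the same way.
function toAsyncIterator<I>(inputs: Iterable<I> | AsyncIterable<I>): AsyncIterator<I> {
  if (Symbol.asyncIterator in inputs) {
    return (inputs as AsyncIterable<I>)[Symbol.asyncIterator]();
  }
  const it = (inputs as Iterable<I>)[Symbol.iterator]();
  return { next: async () => it.next() };
}

// Hypothetical helper: applies `worker` to each input, keeping at most
// `concurrency` calls in flight, and yields results in the order they settle.
async function* streamWithConcurrency<I, O>(
  inputs: Iterable<I> | AsyncIterable<I>,
  worker: (input: I) => Promise<O>,
  concurrency: number,
): AsyncGenerator<O> {
  if (!Number.isInteger(concurrency) || concurrency < 1) {
    throw new Error("concurrency must be a positive integer");
  }
  const source = toAsyncIterator(inputs);
  const inFlight = new Map<number, Promise<{ id: number; result: O }>>();
  let nextId = 0;
  let exhausted = false;

  for (;;) {
    // Top the pool up to `concurrency` calls, pulling inputs only as needed.
    while (!exhausted && inFlight.size < concurrency) {
      const step = await source.next();
      if (step.done) {
        exhausted = true;
        break;
      }
      const id = nextId++;
      const task = worker(step.value).then((result) => ({ id, result }));
      // Fail fast: a rejection surfaces through Promise.race below; this no-op
      // handler keeps the other in-flight calls from becoming unhandled rejections.
      task.catch(() => {});
      inFlight.set(id, task);
    }
    if (inFlight.size === 0) return;
    const { id, result } = await Promise.race(inFlight.values());
    inFlight.delete(id);
    yield result;
  }
}
```

Yielding in completion order keeps fast results flowing; callers that need input order would have to carry an index alongside each input.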

Maybe we should do two functions:

apogadaev commented 1 year ago

@coyotte508 Hi! Is the stream/bulk feature implemented? I have an issue here: https://github.com/huggingface/api-inference-community/issues/194#issuecomment-1513409183

coyotte508 commented 1 year ago

It's not implemented client-side, but it should be implemented server-side. The issue you linked is filed in the correct place; you should get an answer in the coming days :)

bkuermayr commented 1 year ago

I have the same issue (huggingface/api-inference-community#194). Do you know when this streaming feature is expected to work?