langchain-ai / langchainjs

🦜🔗 Build context-aware reasoning applications 🦜🔗
https://js.langchain.com/docs/
MIT License
12.32k stars 2.09k forks source link

`RunnableWithFallbacks` doesn't implement streaming support #6026

Closed tomi-bigpi closed 1 month ago

tomi-bigpi commented 2 months ago

Checked other resources

Example Code

Sample code similar to the following should stream the results:

import { BedrockChat } from "@langchain/community/chat_models/bedrock";
import { concat } from "@langchain/core/utils/stream";
import type { AIMessageChunk } from "@langchain/core/messages";

/**
 * Builds the Bedrock chat model used in this reproduction: a Claude 3.5 Sonnet
 * primary wrapped with `withFallbacks`, using Claude 3 Haiku as the fallback.
 * Both models target the same AWS region.
 */
function createModelWithFallbacks() {
  const region = "us-east-1";

  const primary = new BedrockChat({
    model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
    region,
  });
  const fallback = new BedrockChat({
    model: "anthropic.claude-3-haiku-20240307-v1:0",
    region,
  });

  // Using `withFallbacks` causes streaming to not work
  return primary.withFallbacks({ fallbacks: [fallback] });
}

/**
 * Streams a prompt through the fallback-wrapped model.
 *
 * Each chunk is logged as it arrives. All chunks are merged with `concat`,
 * and the merged result is logged once the stream ends.
 */
export async function testStream() {
  const model = createModelWithFallbacks();

  // This will result in a call to `.invoke()` instead
  const stream = await model.stream("What color is the sky on Mars?");

  let gathered: AIMessageChunk | undefined = undefined;

  for await (const chunk of stream) {
    console.log(chunk);
    gathered = gathered === undefined ? chunk : concat(gathered, chunk);
  }

  console.log(gathered);
}

// Handle rejection explicitly rather than leaving a floating promise: report
// the error and mark the process as failed.
testStream().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});

However, `invoke` is currently called on the model(s) instead, so no streaming happens.

A sample patch to resolve the issue:

diff --git a/dist/runnables/base.cjs b/dist/runnables/base.cjs
index 2276cfd01b0ad7072094f634bbe7ec115189f20e..cf87a60ff48ceb57b6bcc0abfbf9e7db27a9c5f5 100644
--- a/dist/runnables/base.cjs
+++ b/dist/runnables/base.cjs
@@ -1699,6 +1699,73 @@ class RunnableWithFallbacks extends Runnable {
         await runManager?.handleChainError(firstError);
         throw firstError;
     }
+    async *_streamIterator(input, options) {
+        const callbackManager_ = await manager_js_1.CallbackManager.configure(options?.callbacks, undefined, options?.tags, undefined, options?.metadata);
+        const { runId, ...otherOptions } = options ?? {};
+        const runManager = await callbackManager_?.handleChainStart(this.toJSON(), _coerceToDict(input, "input"), runId, undefined, undefined, undefined, otherOptions?.runName);
+
+        let chunk;
+        let firstError = undefined;
+        let stream = undefined;
+        let concatSupported = false;
+        let finalOutput = undefined;
+        for await (const runnable of this.runnables()) {
+            concatSupported = true;
+            try {
+                stream = await runnable.stream(input, config_js_1.patchConfig(otherOptions, { callbacks: runManager?.getChild() }));
+                chunk = (await stream.next()).value;
+                finalOutput = chunk;
+                // Got a first chunk: it seeds the aggregate; clear any error from earlier runnables
+                firstError = undefined;
+            } catch (e) {
+                if (firstError === undefined) {
+                    firstError = e;
+                }
+            }
+
+            // Check if we successfully got a chunk from this runnable
+            if (!firstError) {
+                // Use the stream from this runnable for the rest of the chunks
+                break;
+            }
+        }
+
+        if (firstError) {
+            await runManager?.handleChainError(firstError);
+            throw firstError;
+        }
+
+        if (chunk !== undefined) {
+            yield chunk;
+        }
+
+        if (stream) {
+            try {
+                for await (chunk of stream) {
+                    if (concatSupported) {
+                        if (finalOutput === undefined) {
+                            finalOutput = chunk;
+                        } else {
+                            try {
+                                // eslint-disable-next-line @typescript-eslint/no-explicit-any
+                                finalOutput = concat(finalOutput, chunk);
+                            }
+                            catch (e) {
+                                finalOutput = undefined;
+                                concatSupported = false;
+                            }
+                        }
+                    }
+                    yield chunk;
+                }
+            } catch (e) {
+                await runManager?.handleChainError(e);
+                throw e;
+            }
+        }
+
+        await runManager?.handleChainEnd(finalOutput);
+    }
     async batch(inputs, options, batchOptions) {
         if (batchOptions?.returnExceptions) {
             throw new Error("Not implemented.");
diff --git a/dist/runnables/base.js b/dist/runnables/base.js
index d60358b742c78df8cb1edc7d3a872a1c383bc3ed..76e5ab8fcafbb5ffe134929411c09300970321b6 100644
--- a/dist/runnables/base.js
+++ b/dist/runnables/base.js
@@ -1683,6 +1683,73 @@ export class RunnableWithFallbacks extends Runnable {
         await runManager?.handleChainError(firstError);
         throw firstError;
     }
+    async *_streamIterator(input, options) {
+        const callbackManager_ = await CallbackManager.configure(options?.callbacks, undefined, options?.tags, undefined, options?.metadata);
+        const { runId, ...otherOptions } = options ?? {};
+        const runManager = await callbackManager_?.handleChainStart(this.toJSON(), _coerceToDict(input, "input"), runId, undefined, undefined, undefined, otherOptions?.runName);
+
+        let chunk;
+        let firstError = undefined;
+        let stream = undefined;
+        let concatSupported = false;
+        let finalOutput = undefined;
+        for await (const runnable of this.runnables()) {
+            concatSupported = true;
+            try {
+                stream = await runnable.stream(input, patchConfig(otherOptions, { callbacks: runManager?.getChild() }));
+                chunk = (await stream.next()).value;
+                finalOutput = chunk;
+                // Got a first chunk: it seeds the aggregate; clear any error from earlier runnables
+                firstError = undefined;
+            } catch (e) {
+                if (firstError === undefined) {
+                    firstError = e;
+                }
+            }
+
+            // Check if we successfully got a chunk from this runnable
+            if (!firstError) {
+                // Use the stream from this runnable for the rest of the chunks
+                break;
+            }
+        }
+
+        if (firstError) {
+            await runManager?.handleChainError(firstError);
+            throw firstError;
+        }
+
+        if (chunk !== undefined) {
+            yield chunk;
+        }
+
+        if (stream) {
+            try {
+                for await (chunk of stream) {
+                    if (concatSupported) {
+                        if (finalOutput === undefined) {
+                            finalOutput = chunk;
+                        } else {
+                            try {
+                                // eslint-disable-next-line @typescript-eslint/no-explicit-any
+                                finalOutput = concat(finalOutput, chunk);
+                            }
+                            catch (e) {
+                                finalOutput = undefined;
+                                concatSupported = false;
+                            }
+                        }
+                    }
+                    yield chunk;
+                }
+            } catch (e) {
+                await runManager?.handleChainError(e);
+                throw e;
+            }
+        }
+
+        await runManager?.handleChainEnd(finalOutput);
+    }
     async batch(inputs, options, batchOptions) {
         if (batchOptions?.returnExceptions) {
             throw new Error("Not implemented.");

Error Message and Stack Trace (if applicable)

No response

Description

(Modified description of the situation based on dosubot's description from a similar issue:) The stream method implementation provided in the base class (from which other runnables inherit) is designed to work with individual Runnable instances. It leverages an asynchronous generator (_streamIterator) to yield results as they become available.

`RunnableWithFallbacks` and similar constructs do not override the `stream` method or provide their own implementation of it. Such an implementation is needed to stream across all of their constituent runnables. Without it, streaming through a `RunnableWithFallbacks` uses the base class's default streaming behavior. That behavior does not call `stream` on the underlying primary or fallback models; it calls `invoke` instead.

We applied a quick local patch to the built files. I've included the patch above even though it isn't the best way to fix the issue; it is just for reference.

System Info

> rush-pnpm show langchain

langchain@0.2.9 | MIT | deps: 16 | versions: 281
Typescript bindings for langchain
https://github.com/langchain-ai/langchainjs/tree/main/langchain/

keywords: llm, ai, gpt3, chain, prompt, prompt engineering, chatgpt, machine learning, ml, openai, embeddings, vectorstores

dist
.tarball: https://registry.npmjs.org/langchain/-/langchain-0.2.9.tgz
.shasum: 1341bdd7166f4f6da0b9337f363e409a79523dbb
.integrity: sha512-iZ0l7BDVfoifqZlDl1gy3JP5mIdhYjWiToPlDnlmfHD748cw3okvF0gZo0ruT4nbftnQcaM7JzPUiNC43UPfgg==
.unpackedSize: 4.0 MB

dependencies:
@langchain/core: >=0.2.11 <0.3.0  @langchain/textsplitters: ~0.0.0  js-tiktoken: ^1.0.12              jsonpointer: ^5.0.1               langsmith: ~0.1.30                openapi-types: ^12.1.3            uuid: ^10.0.0                     zod-to-json-schema: ^3.22.3
@langchain/openai: >=0.1.0 <0.3.0 binary-extensions: ^2.2.0         js-yaml: ^4.1.0                   langchainhub: ~0.0.8              ml-distance: ^4.0.0               p-retry: 4                        yaml: ^2.2.1                      zod: ^3.22.4

maintainers:
- nfcampos <nuno@boringbits.io>
- jacoblee93 <jacoblee93@gmail.com>
- andrewnguonly <andrewnguonly@gmail.com>
- davidduong <david@duong.cz>
- hwchase17 <hw.chase.17@gmail.com>
- basproul <braceasproul@gmail.com>

dist-tags:
latest: 0.2.9     next: 0.2.3-rc.0

published 13 hours ago by jacoblee93 <jacoblee93@gmail.com>

Platform: mac

> node --version
Node v20.11.1
> rush-pnpm --version
pnpm 8.15.4
dosubot[bot] commented 2 months ago

Hey there, @tomi-bigpi! I'm here to help you out with any bugs, questions, or contributions you have in mind. Let's tackle this together!

To implement streaming support for the RunnableWithFallbacks class in LangChain.js, you need to override the stream method to handle the streaming logic, including the fallback mechanism. Here is how you can do it:

  1. Override the stream method to handle the primary runnable and fallbacks.
  2. Ensure that the method yields chunks from the primary runnable and, if it fails, from the fallback runnables.

Here is the implementation:

export class RunnableWithFallbacks<RunInput, RunOutput> extends Runnable<
  RunInput,
  RunOutput
> {
  static lc_name() {
    return "RunnableWithFallbacks";
  }

  lc_namespace = ["langchain_core", "runnables"];

  lc_serializable = true;

  /** The primary runnable; always tried first. */
  runnable: Runnable<RunInput, RunOutput>;

  /** Runnables tried, in order, after the primary fails. */
  fallbacks: Runnable<RunInput, RunOutput>[];

  constructor(fields: {
    runnable: Runnable<RunInput, RunOutput>;
    fallbacks: Runnable<RunInput, RunOutput>[];
  }) {
    super(fields);
    this.runnable = fields.runnable;
    this.fallbacks = fields.fallbacks;
  }

  /** Yields the primary runnable followed by each fallback, in order. */
  *runnables() {
    yield this.runnable;
    for (const fallback of this.fallbacks) {
      yield fallback;
    }
  }

  /**
   * Streams output from the first runnable that starts streaming successfully.
   *
   * A runnable is abandoned in favor of the next one only if it fails before
   * its first chunk arrives. Opening the stream or pulling the first chunk
   * both count. Once a chunk has been yielded, the stream is committed to that
   * runnable. Any later error is rethrown rather than restarting on a fallback,
   * which would interleave partial output from two runnables.
   *
   * @param input - Passed unchanged to each runnable's `stream`.
   * @param options - Passed unchanged to each runnable's `stream`.
   * @throws The error raised by the primary runnable when every runnable fails
   *   before producing output, or any error raised mid-stream after commit.
   */
  async *stream(
    input: RunInput,
    options?: Partial<RunnableConfig>
  ): AsyncGenerator<RunOutput> {
    const errors: unknown[] = [];
    for (const runnable of this.runnables()) {
      let iterator: AsyncIterator<RunOutput>;
      let head: IteratorResult<RunOutput>;
      try {
        const stream = await runnable.stream(input, options);
        iterator = stream[Symbol.asyncIterator]();
        head = await iterator.next();
      } catch (e) {
        // Failed before producing anything: safe to try the next fallback.
        errors.push(e);
        continue;
      }

      // The runnable finished without emitting any chunk; nothing left to do.
      if (head.done) {
        return;
      }

      yield head.value;

      // Committed: relay the rest of this stream. `yield*` forwards early
      // termination by the consumer to the underlying iterator, and errors
      // raised from here on propagate instead of triggering a fallback.
      const committed = iterator;
      const rest: AsyncIterable<RunOutput> = {
        [Symbol.asyncIterator]: () => committed,
      };
      yield* rest;
      return;
    }

    // Every runnable failed before producing output; surface the primary's
    // original error rather than a generic one.
    throw errors[0];
  }
}

This implementation ensures that the RunnableWithFallbacks class attempts to stream output from the primary runnable first and falls back to the other runnables if the primary one fails [1].

To continue talking to Dosu, mention @dosu.

tomi-bigpi commented 2 months ago

@dosu Seems like the streaming support should be implemented within Langchain, not by consumers. Please forward to a human for review.

dosubot[bot] commented 2 months ago

@jacoblee93, this one's a bit out of my wheelhouse, could you step in and provide your guidance?

To continue talking to Dosu, mention @dosu.

jacoblee93 commented 2 months ago

So this will basically check if pulling the initial chunk throws an error, and falls back if it does? Otherwise, if there's a streaming error mid-stream, it'll throw?

tomi-bigpi commented 2 months ago

Yes, that code is very similar to the Python implementation, which should be the basis for a real implementation. This version was a bit cobbled together: it was run through an LLM and modeled after other methods in the `Runnable*` classes. The main idea is to keep falling back as long as fallback models are available, and to throw only if the last one fails.

jacoblee93 commented 2 months ago

Seems reasonable - if you're up for making it a PR would appreciate that otherwise I can when I get a minute.

jsegeren commented 1 month ago

Just to comment - we would greatly value this fix/improvement as well.

jacoblee93 commented 1 month ago

Live in core@0.2.23