langchain-ai / langchainjs

🦜🔗 Build context-aware reasoning applications 🦜🔗
https://js.langchain.com/docs/
MIT License
12.16k stars 2.05k forks source link

GPT4All + langchain typescript #628

Closed sime2408 closed 1 year ago

sime2408 commented 1 year ago

I am working with TypeScript + LangChain + Pinecone and I want to use GPT4All models. Can you make this work? I tried `import { GPT4All } from 'langchain/llms';` but with no luck.

gallynaut commented 1 year ago

I ended up not needing this in my project, but here's a head start if anyone wants to polish it off.

You can initialize it like this:

import { GPT4All } from "./Gpt4All"

  // binaryPath is omitted here, so the wrapper reads it from the
  // GPT4ALL_BINARY_PATH environment variable.
  const llm = new GPT4All({
    temp: 0,
    promptTimeout: 3000,
    threads: 8,
    ctx_size: 2048,
    n_predict: 2048,
  });
  // Wait for the child process to load the model. The wrapper sets `isReady`
  // once the binary reports that it is running in chat mode. Poll for that
  // instead of sleeping a fixed interval (too short for large models, wasted
  // time for small ones), and give up after 60 seconds.
  const deadline = Date.now() + 60000;
  while (!llm.isReady) {
    if (Date.now() > deadline) {
      throw new Error("GPT4All model did not become ready within 60 seconds");
    }
    await new Promise((resolve) => setTimeout(resolve, 100));
  }
// Gpt4All.ts
import { ChildProcessWithoutNullStreams, spawn } from "child_process";
import { BaseLLM, BaseLLMParams } from "langchain/llms";
import { Generation, LLMResult } from "langchain/schema";
import path from "path";

/**
 * Model and sampling options for the GPT4All chat binary.
 *
 * Every option is optional. The `(default: …)` values below are the GPT4All
 * binary's own documented defaults; this TypeScript code does not apply them.
 */
export interface Gpt4AllInput {
  /** Path to the pre-trained GPT4All model file. (default: gpt4all-lora-quantized.bin) */
  model?: string;
  /** Whether to enable verbose output. */
  verbose?: boolean;
  /** Size of the prompt context. (default: 2048) */
  ctx_size?: number;
  /** RNG seed. (default: -1) */
  seed?: number;
  /** Number of threads to use. (default: 8) */
  threads?: number;
  /** The maximum number of tokens to generate. (default: 128) */
  n_predict?: number;
  /** The temperature to use for sampling. (default: 0.1) */
  temp?: number;
  /** The top-p value to use for sampling. (default: 0.9) */
  top_p?: number;
  /** The top-k value to use for sampling. (default: 40) */
  top_k?: number;
  /** Number of most recent tokens considered for the repeat penalty. (default: 64) */
  repeat_last_n?: number;
  /** The penalty to apply to repeated tokens. (default: 1.3) */
  repeat_penalty?: number;
}

/**
 * LangChain LLM wrapper around the GPT4All chat executable.
 *
 * The constructor spawns the binary immediately, using the binary's directory
 * as the working directory, and keeps it running. A relative `model` path
 * therefore resolves next to the executable. The wrapper becomes usable once
 * the binary prints "== Running in chat mode. ==" on stderr; at that point
 * `isInitialized` and `isReady` are set. Poll `isReady` before calling.
 *
 * NOTE(review): all calls share one stdin/stdout pair. Concurrent calls on the
 * same instance would interleave output, so issue calls one at a time.
 */
export class GPT4All extends BaseLLM implements Gpt4AllInput {
  /** Path to the GPT4All chat executable. */
  binaryPath: string;

  /** Milliseconds of stdout silence after which a reply is treated as complete. */
  promptTimeout: number;

  model?: string;
  verbose = false;
  ctx_size?: number;
  seed?: number;
  threads?: number;
  n_predict?: number;
  temp?: number;
  top_p?: number;
  top_k?: number;
  repeat_last_n?: number;
  repeat_penalty?: number;

  /** The spawned chat process. */
  client: ChildProcessWithoutNullStreams;

  /** Set once the binary has printed its chat-mode banner. */
  isInitialized = false;

  /** True while the child process is alive and accepting prompts. */
  isReady = false;

  /** Error reported by the child process or its stdin (e.g. a spawn failure). */
  private processError?: Error;

  /** Children still running; they are killed when this Node process shuts down. */
  private static readonly liveClients = new Set<ChildProcessWithoutNullStreams>();

  /** Ensures the process-wide shutdown listeners are installed only once. */
  private static shutdownHooksInstalled = false;

  /**
   * Spawns the chat binary and starts watching it for readiness.
   *
   * @param config - Model options plus `binaryPath` (falls back to the
   *   `GPT4ALL_BINARY_PATH` env variable) and `promptTimeout` in ms
   *   (any positive value; defaults to 1000).
   * @throws Error if no binary path is available.
   */
  constructor(
    config: Partial<Gpt4AllInput> & {
      binaryPath?: string;
      promptTimeout?: number;
    } & BaseLLMParams = {}
  ) {
    super(config);
    const binaryPath = config.binaryPath ?? process.env.GPT4ALL_BINARY_PATH;
    if (!binaryPath) {
      throw new Error(
        `binaryPath must be provided or set the env variable 'GPT4ALL_BINARY_PATH'`
      );
    }
    this.binaryPath = binaryPath;
    // Any positive timeout is honoured; missing or non-positive values fall back to 1000 ms.
    this.promptTimeout =
      config.promptTimeout !== undefined && config.promptTimeout > 0
        ? config.promptTimeout
        : 1000;

    this.model = config.model;
    this.verbose = config.verbose ?? false;
    this.ctx_size = config.ctx_size;
    this.seed = config.seed;
    this.threads = config.threads;
    this.n_predict = config.n_predict;
    this.temp = config.temp;
    this.top_p = config.top_p;
    this.top_k = config.top_k;
    this.repeat_last_n = config.repeat_last_n;
    this.repeat_penalty = config.repeat_penalty;

    const parsedBinaryPath = path.parse(this.binaryPath);
    this.client = spawn(`./${parsedBinaryPath.base}`, this.commandLineArgs(), {
      cwd: parsedBinaryPath.dir || ".",
      stdio: "pipe",
    });

    const child = this.client;
    GPT4All.installShutdownHooks();
    GPT4All.liveClients.add(child);

    // Without an 'error' listener, a spawn failure (e.g. ENOENT) is thrown as
    // an uncaught exception and crashes the host process.
    child.on("error", (error: Error) => {
      this.processError = error;
      this.isReady = false;
      GPT4All.liveClients.delete(child);
    });
    child.on("exit", () => {
      this.isReady = false;
      GPT4All.liveClients.delete(child);
    });
    // Writing to a dead child's stdin emits EPIPE on the stream; record it
    // instead of letting an unhandled stream error crash the process.
    child.stdin.on("error", (error: Error) => {
      this.processError = error;
      this.isReady = false;
    });

    const banner = "== Running in chat mode. ==";
    let pendingTail = "";
    child.stderr.on("data", (data: Buffer) => {
      const text = data
        .toString("utf-8")
        // eslint-disable-next-line no-control-regex
        .replace(/\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g, "");
      if (this.isInitialized) {
        console.error(text);
        return;
      }
      // The banner may be split across chunks, so search it together with
      // the tail of the previous chunk.
      const haystack = pendingTail + text;
      if (haystack.includes(banner)) {
        pendingTail = "";
        this.isInitialized = true;
        this.isReady = true;
      } else {
        pendingTail = haystack.slice(-banner.length);
      }
    });
  }

  /**
   * Installs process-wide listeners that kill every live child on shutdown.
   * They are registered once for all instances, so creating many models does
   * not accumulate listeners.
   */
  private static installShutdownHooks(): void {
    if (GPT4All.shutdownHooksInstalled) {
      return;
    }
    GPT4All.shutdownHooksInstalled = true;

    const killAll = () => {
      for (const child of GPT4All.liveClients) {
        child.kill();
      }
      GPT4All.liveClients.clear();
    };

    process.once("exit", killAll);
    for (const signal of ["SIGINT", "SIGTERM", "SIGQUIT"] as const) {
      process.once(signal, () => {
        killAll();
        // Listening for a signal disables Node's default "terminate" action.
        // If nobody else handles it, re-raise it (this once-listener is
        // already removed) so Ctrl+C etc. still stop the process.
        if (process.listenerCount(signal) === 0) {
          process.kill(process.pid, signal);
        }
      });
    }
  }

  /**
   * Builds the argv for the chat binary: a `--flag`, `value` pair (two
   * separate argv entries) for every option that is set.
   */
  private commandLineArgs(): string[] {
    const flags: ReadonlyArray<keyof Gpt4AllInput> = [
      "model",
      "seed",
      "threads",
      "n_predict",
      "ctx_size",
      "temp",
      "top_p",
      "top_k",
      "repeat_last_n",
      "repeat_penalty",
    ];
    return flags.flatMap((flag) => {
      const value = this[flag];
      return value === undefined ? [] : [`--${flag}`, String(value)];
    });
  }

  /**
   * Sampling/runtime options rendered as `--name value` strings.
   *
   * Kept for backward compatibility. Each entry combines flag and value in one
   * string, and `model`/`n_predict` are not included.
   */
  get args(): Array<string> {
    const llamaParams = new Set([
      "seed",
      "top_p",
      "top_k",
      "repeat_last_n",
      "repeat_penalty",
      "ctx_size",
      "temp",
      "batch_size",
      "threads",
    ]);
    return Object.entries(this)
      .filter(([key, value]) => value !== undefined && llamaParams.has(key))
      .map(([key, value]) => `--${key} ${value}`);
  }

  _llmType(): string {
    return "gpt4all";
  }

  /** Cuts `text` at the earliest occurrence of any non-empty stop sequence. */
  private static truncateAtStop(text: string, stop?: string[]): string {
    let end = text.length;
    for (const sequence of stop ?? []) {
      if (sequence === "") {
        continue;
      }
      const index = text.indexOf(sequence);
      if (index !== -1 && index < end) {
        end = index;
      }
    }
    return text.slice(0, end);
  }

  /**
   * Sends one prompt to the running GPT4All process.
   *
   * @param prompt - The prompt to pass into the model.
   * @param stop - Optional stop sequences; the response is cut at the first one found.
   *
   * @returns The response, prefixed with a newline.
   * @throws Error if the binary has not finished loading or is no longer running.
   *
   * @example
   * ```ts
   * import { GPT4All } from "./Gpt4All";
   * const gpt4All = new GPT4All({ binaryPath: "/path/to/gpt4all-binary" });
   * // ...wait until gpt4All.isReady is true...
   * const response = await gpt4All.call("Tell me a joke.");
   * ```
   */
  async _call(prompt: string, stop?: string[]): Promise<string> {
    const cause = this.processError ? `: ${this.processError.message}` : "";
    if (!this.isInitialized) {
      throw new Error(`GPT4All model is not initialized${cause}`);
    }
    if (!this.isReady) {
      throw new Error(`GPT4All model is not ready${cause}`);
    }
    if (this.verbose) {
      console.log(
        `###########################################\nPROMPT:\n${prompt}\n###########################################\n`
      );
    }
    const response = await sendMessageAndWaitForResult(
      this.client,
      prompt,
      this.promptTimeout
    );
    if (this.verbose) {
      console.log(
        `###########################################\nRESPONSE:\n${response}\n###########################################\n`
      );
    }
    return "\n" + GPT4All.truncateAtStop(response, stop);
  }

  /**
   * Runs each prompt through `_call`.
   *
   * @param prompts - The prompts to pass into the model.
   * @param stop - Optional stop sequences applied to every response.
   *
   * @returns One single-element generation list per prompt, in input order.
   *
   * @example
   * ```ts
   * import { GPT4All } from "./Gpt4All";
   * const gpt4All = new GPT4All({ binaryPath: "/path/to/gpt4all-binary" });
   * const response = await gpt4All.generate(["Tell me a joke."]);
   * ```
   */
  async _generate(prompts: Array<string>, stop?: string[]): Promise<LLMResult> {
    const generations: Array<Array<Generation>> = [];
    // Sequential on purpose: all prompts go through the same child process.
    for (const prompt of prompts) {
      generations.push([{ text: await this._call(prompt, stop) }]);
    }
    return { generations };
  }
}

/**
 * Sends one prompt to a running GPT4All chat process and collects its reply.
 *
 * Newline handling: the prompt is written as a single line, so embedded
 * newlines in `message` are escaped as a literal `\n`. Literal `\n` sequences
 * in the reply are turned back into real newlines.
 *
 * The reply is considered complete in either of two cases:
 *  - the binary prints its `> ` input prompt again at the end of the output;
 *    that marker is stripped from the result;
 *  - no stdout data arrives for `promptTimeout` ms after some text has been
 *    received.
 *
 * While nothing has been received, it keeps waiting. stderr is not read here;
 * the owner of the process is expected to handle it.
 *
 * @param childProcess - The spawned chat binary with piped stdio.
 * @param message - The prompt text.
 * @param promptTimeout - Milliseconds of stdout silence that ends a reply.
 * @returns The reply text.
 * @throws Rejects if the process has already exited, emits `error`, or exits
 *   before the reply completes.
 */
function sendMessageAndWaitForResult(
  childProcess: ChildProcessWithoutNullStreams,
  message: string,
  promptTimeout = 1000
): Promise<string> {
  // eslint-disable-next-line no-control-regex
  const ansiEscape = /\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])/g;
  // The binary's "> " input prompt (alone, or at the start of a line) at the
  // very end of the output received so far.
  const trailingPrompt = /(?:^|\n)\s?> $/;

  return new Promise((resolve, reject) => {
    let result = "";
    let timeout: ReturnType<typeof setTimeout> | null = null;
    let settled = false;

    const cleanup = () => {
      settled = true;
      if (timeout !== null) {
        clearTimeout(timeout);
        timeout = null;
      }
      childProcess.stdout.off("data", onData);
      childProcess.off("error", onError);
      childProcess.off("exit", onExit);
    };

    const fail = (error: Error) => {
      if (settled) {
        return;
      }
      cleanup();
      reject(error);
    };

    const armTimer = () => {
      if (timeout !== null) {
        clearTimeout(timeout);
      }
      timeout = setTimeout(finish, promptTimeout);
    };

    const finish = () => {
      if (settled) {
        return;
      }
      if (result === "") {
        // Nothing generated yet (e.g. only the idle prompt was seen): keep waiting.
        armTimer();
        return;
      }
      cleanup();
      resolve(result.replace(/\\n/g, "\n"));
    };

    const onData = (data: Buffer) => {
      result += data.toString("utf-8").replace(ansiEscape, "");
      const withoutPrompt = result.replace(trailingPrompt, "");
      if (withoutPrompt !== result) {
        // The binary is asking for the next prompt: generation has finished.
        result = withoutPrompt;
        finish();
      } else {
        armTimer();
      }
    };

    const onError = (error: Error) => fail(error);

    const onExit = (code: number | null, signal: string | null) =>
      fail(
        new Error(
          `GPT4All process exited (code: ${code}, signal: ${signal}) before the response completed`
        )
      );

    // An already-dead child would never emit 'exit' again, so fail fast.
    if (childProcess.exitCode != null || childProcess.signalCode != null) {
      reject(new Error("GPT4All process has already exited"));
      return;
    }

    childProcess.stdout.on("data", onData);
    childProcess.on("error", onError);
    childProcess.on("exit", onExit);
    childProcess.stdin.write(message.replace(/\n/g, "\\n") + "\n");
  });
}
sime2408 commented 1 year ago

You must have access to GPT-4 to do this, and the models use the names as documented https://platform.openai.com/docs/models/gpt-4

@hkd987 I wanted to avoid calling OpenAI since it's costly, but thanks for your reply!

sime2408 commented 1 year ago

Thanks, @gallynaut! I am a backend dev with some basic knowledge of TypeScript, and this will help me a lot. We're chatting on the Discord channel of the gpt4-pdf-chatbot-langchain project about how to use it with gpt4all; feel free to join us if you're interested 🙇🏻

hatkyinc2 commented 1 year ago

hkd987

gpt4all is a local model; the Python implementation already supports it: https://python.langchain.com/en/latest/modules/models/llms/integrations/gpt4all.html