zuisong / gemini-openai-proxy

OpenAI to Google Gemini https://gemini-openai-proxy.deno.dev
https://gemini-openai-proxy.zuisong.workers.dev
MIT License
306 stars 88 forks source link

Embedding #62

Closed traderpedroso closed 3 months ago

traderpedroso commented 4 months ago

Are you going to add an embedding endpoint?

DongqiShen commented 3 months ago

I wrote an embedding repo that can be deployed on a Cloudflare worker.

ekatiyar commented 3 months ago

I wanted an all-in-one solution for both chat completion and embeddings - I ended up creating a fork of a similar project with embeddings support added here. My main branch includes a change to eliminate openAI -> gemini model mapping, but if you're only interested in added embeddings support I have that on the embedding feature branch

zuisong commented 3 months ago

I don't use this endpoint myself. I submitted a PR with an initial implementation of this feature. Feedback is welcome. 😁

zuisong commented 3 months ago

The first version has been merged into the master branch, and it can be used like this currently.

curl https://gemini-openai-proxy.deno.dev/v1/embeddings \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer $YOUR_GEMINI_API_KEY" \
        -d '{
      "input": "Your text string goes here",
      "model": "text-embedding-3-small"
    }'
traderpedroso commented 2 months ago

The first version has been merged into the master branch, and it can be used like this currently.

curl https://gemini-openai-proxy.deno.dev/v1/embeddings \
        -H "Content-Type: application/json" \
        -H "Authorization: Bearer $YOUR_GEMINI_API_KEY" \
        -d '{
      "input": "Your text string goes here",
      "model": "text-embedding-3-small"
    }'

I incorporated your code and successfully integrated DuckDuckGo models. This expands the available models beyond Gemini, as DuckDuckGo models are now deployed on Cloudflare workers. The gemini-api-client.ts file and the list of models are located in utils.ts.

import { EventSourceParserStream } from "eventsource-parser/stream";
import type { ApiParam, GeminiModel } from "../utils.ts";
import { GoogleGenerativeAIError } from "./errors.ts";
import type { EnhancedGenerateContentResponse, GenerateContentRequest, GenerateContentResponse, GenerateContentResult, RequestOptions } from "./types.ts";
import { addHelpers } from "./response-helper.ts";
// Chat history entries in the shape sent to the DuckDuckGo chat endpoint.
type Messages = Array<{ role: "user" | "assistant"; content: string | object }>;

// Model identifiers that are served through DuckDuckGo instead of Google.
const DUCK_DUCK_GO_MODEL_IDS: string[] = [
  "meta-llama/Llama-3-70b-chat-hf",
  "gpt-3.5-turbo-0125",
  "claude-3-haiku-20240307",
  "mistralai/Mixtral-8x7B-Instruct-v0.1",
  "gpt-4o-mini",
];
const duckDuckGoModels = new Set(DUCK_DUCK_GO_MODEL_IDS);

/**
 * Streams generated content for `model`, routing the call to the DuckDuckGo
 * backend when the model is listed in `duckDuckGoModels` and to the Google
 * backend otherwise.
 *
 * @param apiParam - API parameters forwarded unchanged to the backend.
 * @param model - Model identifier; also selects the backend.
 * @param params - Request payload forwarded unchanged to the backend.
 * @param requestOptions - Optional request options forwarded to the backend.
 * @returns An async generator that re-yields every result of the chosen backend.
 */
export async function* generateContent(
  apiParam: ApiParam,
  model: GeminiModel,
  params: GenerateContentRequest,
  requestOptions?: RequestOptions,
): AsyncGenerator<GenerateContentResult> {
  // Both backends share one signature, so pick the generator first and delegate once.
  const backend = duckDuckGoModels.has(model) ? generateContentDuckDuckGo : generateContentGoogle;
  yield* backend(apiParam, model, params, requestOptions);
}

/**
 * Streams content from the Google endpoint built by `RequestUrl` for the
 * `streamGenerateContent` task, yielding one result per server-sent event.
 *
 * @param apiParam - API parameters passed to `RequestUrl`.
 * @param model - Model identifier placed in the request URL.
 * @param params - Request payload; serialized as the JSON request body.
 * @param requestOptions - Optional request options forwarded to `makeRequest`.
 * @returns An async generator of results; yields nothing when the response has no body.
 */
async function* generateContentGoogle(
  apiParam: ApiParam,
  model: GeminiModel,
  params: GenerateContentRequest,
  requestOptions?: RequestOptions,
): AsyncGenerator<GenerateContentResult> {
  const endpoint = new RequestUrl(model, Task.STREAM_GENERATE_CONTENT, true, apiParam);
  const { body: stream } = await makeRequest(endpoint, JSON.stringify(params), requestOptions);
  if (stream == null) {
    return;
  }

  // Decode bytes to text, then let the SSE parser split the text into events.
  const events = stream.pipeThrough(new TextDecoderStream()).pipeThrough(new EventSourceParserStream());
  for await (const { data } of events) {
    const parsed: GenerateContentResponse = JSON.parse(data);
    yield { response: addHelpers(parsed) };
  }
}

/**
 * Wraps one parsed DuckDuckGo chat event in the Gemini-style response shape
 * used by the rest of this module (single candidate, usage metadata and a
 * `result()` text accessor).
 *
 * @param response - Parsed JSON payload of one stream event. Missing fields
 *   (or a `null`/`undefined` payload) fall back to neutral defaults.
 * @returns A response whose only candidate carries `response.message` as text.
 */
function addHelpersDuckDuckGo(response: any): EnhancedGenerateContentResponse {
  return {
    candidates: [
      {
        content: {
          // An event without a `message` field becomes an empty text chunk so
          // that `result()` always returns a string rather than `undefined`.
          parts: [{ text: response?.message ?? "" }],
        },
        finishReason: response?.finish_reason || "STOP",
        index: response?.index || 0,
        safetyRatings: response?.safety_ratings || [],
      },
    ],
    usageMetadata: {
      promptTokenCount: response?.usage?.prompt_tokens || 0,
      candidatesTokenCount: response?.usage?.completion_tokens || 0,
      totalTokenCount: response?.usage?.total_tokens || 0,
    },
    // Returns the text of the first candidate, or "" when there are none.
    result: function () {
      if (this.candidates && this.candidates.length > 0) {
        if (this.candidates.length > 1) {
          console.warn(
            `This response had ${this.candidates.length} candidates. Returning text from the first candidate only. Access response.candidates directly to use the other candidates.`,
          );
        }
        return this.candidates[0].content.parts[0].text;
      }
      return "";
    },
  };
}

/**
 * Streams a chat completion from the DuckDuckGo chat endpoint and yields each
 * event converted by `addHelpersDuckDuckGo`.
 *
 * Only `user` and `model` entries of `params.contents` are used: the text of
 * all `model` entries is joined and prefixed to the text of every `user`
 * entry, and the resulting `user` messages are sent as the conversation.
 * NOTE(review): entries with any other role are not forwarded — confirm that
 * this is intended.
 *
 * @param apiParam - Unused by this backend; kept for signature parity with `generateContentGoogle`.
 * @param model - Model identifier sent in the chat request body.
 * @param params - Request whose `contents` supply the conversation text.
 * @param requestOptions - Currently unused by this backend.
 * @returns An async generator with one result per streamed event; yields
 *   nothing when the response has no body.
 * @throws Error when the status request or the chat request returns a non-OK response.
 */
async function* generateContentDuckDuckGo(
  apiParam: ApiParam,
  model: GeminiModel,
  params: GenerateContentRequest,
  requestOptions?: RequestOptions,
): AsyncGenerator<GenerateContentResult> {
  // Joined text of all 'model' entries; prefixed to every 'user' message below.
  const modelText = params.contents
    .filter((content) => content.role === "model")
    .map((content) => content.parts.map((part) => part.text).join(" "))
    .join(" ");

  const messages: Messages = params.contents
    .filter((content) => content.role === "user")
    .map((content): Messages[number] => ({
      role: "user",
      content: modelText + " " + content.parts.map((part) => part.text).join(" "),
    }));

  // The chat endpoint is called with the `x-vqd-4` value returned by the status endpoint.
  const statusResponse = await fetch("https://duckduckgo.com/duckchat/v1/status", {
    method: "GET",
    headers: {
      "x-vqd-accept": "1",
    },
  });
  if (!statusResponse.ok) {
    throw new Error(`Failed to fetch status from DuckDuckGo`);
  }
  const vqd = statusResponse.headers.get("x-vqd-4") ?? "";

  const chatResponse = await fetch("https://duckduckgo.com/duckchat/v1/chat", {
    method: "POST",
    headers: {
      "x-vqd-4": vqd,
      "Content-Type": "application/json",
    },
    body: JSON.stringify({ model, messages }),
  });
  // Fail loudly instead of trying to parse an error body as an event stream.
  if (!chatResponse.ok) {
    throw new Error(`DuckDuckGo chat request failed: [${chatResponse.status} ${chatResponse.statusText}]`);
  }

  const body = chatResponse.body;
  if (body == null) {
    return;
  }

  // The SSE parser buffers across network chunks, so an event that is split
  // between two reads is still delivered (and JSON-parsed) as a whole.
  for await (const event of body.pipeThrough(new TextDecoderStream()).pipeThrough(new EventSourceParserStream())) {
    if (event.data === "[DONE]") {
      continue;
    }
    yield {
      response: addHelpersDuckDuckGo(JSON.parse(event.data)),
    };
  }
}

/**
 * POSTs `body` as JSON to `url` and returns the response when it is OK.
 *
 * @param url - Request target; `url.toURL()` supplies the final URL.
 * @param body - Already-serialized JSON request body.
 * @param requestOptions - Optional options turned into fetch options by `buildFetchOptions`.
 * @returns The successful fetch response (body not yet consumed).
 * @throws GoogleGenerativeAIError when the fetch fails or the response is not
 *   OK; the message includes the HTTP status and, when the error body can be
 *   parsed, the server-provided message and details.
 */
async function makeRequest(url: RequestUrl, body: string, requestOptions?: RequestOptions): Promise<Response> {
  let fetchResponse: Response;
  try {
    fetchResponse = await fetch(url.toURL(), {
      ...buildFetchOptions(requestOptions),
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body,
    });
    if (!fetchResponse.ok) {
      let message = "";
      try {
        const errResp = await fetchResponse.json();
        message = errResp.error.message;
        if (errResp.error.details) {
          message += ` ${JSON.stringify(errResp.error.details)}`;
        }
      } catch (_e) {
        // Best effort only: an unparsable error body leaves the message empty.
      }
      throw new Error(`[${fetchResponse.status} ${fetchResponse.statusText}] ${message}`);
    }
  } catch (e: unknown) {
    console.log(e);
    // A thrown value is not guaranteed to be an Error; normalize it before
    // reading `message`/`stack` so non-Error throwables are reported sensibly.
    const cause = e instanceof Error ? e : new Error(String(e));
    const err = new GoogleGenerativeAIError(`Error fetching from google -> ${cause.message}`);
    err.stack = cause.stack;
    throw err;
  }
  return fetchResponse;
}

/**
 * Describes one API call (model, task, streaming flag, API parameters) and
 * renders it as a request URL.
 */
export class RequestUrl {
  constructor(
    public model: GeminiModel,
    public task: Task,
    public stream: boolean,
    public apiParam: ApiParam,
  ) {}

  /**
   * Builds `<BASE_URL>/v1beta/models/<model>:<task>` with the API key as the
   * `key` query parameter, plus `alt=sse` when streaming is requested.
   */
  toURL(): URL {
    const { model, task, stream, apiParam } = this;
    const endpoint = new URL([BASE_URL, API_VERSION.v1beta, `models/${model}:${task}`].join("/"));
    endpoint.searchParams.append("key", apiParam.apikey);
    if (stream) {
      endpoint.searchParams.append("alt", "sse");
    }
    return endpoint;
  }
}

/**
 * API method names. `RequestUrl.toURL` places the chosen value after the
 * model name in the request path (`models/<model>:<task>`).
 */
enum Task {
  GENERATE_CONTENT = "generateContent",
  STREAM_GENERATE_CONTENT = "streamGenerateContent",
  COUNT_TOKENS = "countTokens",
  EMBED_CONTENT = "embedContent",
  BATCH_EMBED_CONTENTS = "batchEmbedContents",
}

// Origin that `RequestUrl.toURL` prefixes to every request path.
const BASE_URL = "https://generativelanguage.googleapis.com";

/**
 * API version path segments. `RequestUrl.toURL` currently always uses `v1beta`.
 */
enum API_VERSION {
  v1beta = "v1beta",
  v1 = "v1",
}

/**
 * Translates the user-defined request options into fetch options.
 *
 * When a truthy `timeout` is given, the returned options carry an abort
 * signal that fires after that many milliseconds; otherwise the options are
 * empty.
 *
 * @param requestOptions - The user-defined request options.
 * @returns The fetch options derived from `requestOptions`.
 */
function buildFetchOptions(requestOptions?: RequestOptions): RequestInit {
  const timeoutMs = requestOptions?.timeout;
  if (!timeoutMs) {
    return {};
  }

  const controller = new AbortController();
  // The timer is not cancelled when the request finishes first; aborting an
  // already-completed request is harmless.
  setTimeout(() => controller.abort(), timeoutMs);
  return { signal: controller.signal };
}