jxmorris12 / vec2text

utilities for decoding deep representations (like sentence embeddings) back to text

Error with vec2text.invert_embeddings when setting the num_steps parameter #54

Closed Maitouer closed 1 month ago

Maitouer commented 1 month ago

Hi @jxmorris12, thanks so much for this insightful work! After trying the demo, I got the following error when I set the "num_steps" parameter in vec2text.invert_embeddings:

################## Error Message ##################

Traceback (most recent call last):
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/trainers/corrector.py", line 252, in generate
    hypothesis_input_ids = inputs["hypothesis_input_ids"]
KeyError: 'hypothesis_input_ids'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 478, in __call__
    result = fn(*args, **kwargs)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/utils/utils.py", line 204, in get_embeddings_openai_vanilla_multithread
    client = OpenAI()
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/openai/_client.py", line 104, in __init__
    raise OpenAIError(
openai.OpenAIError: The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 478, in __call__
    result = fn(*args, **kwargs)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/utils/utils.py", line 266, in embed_api
    embeddings = get_embeddings_func(
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 336, in wrapped_f
    return copy(f, *args, **kw)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 475, in __call__
    do = self.iter(retry_state=retry_state)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 376, in iter
    result = action(retry_state)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 419, in exc_check
    raise retry_exc from fut.exception()
tenacity.RetryError: RetryError[<Future at 0x7f4a70665330 state=finished raised OpenAIError>]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/zhdd/home/jqzhang/DoGE/src/main.py", line 52, in <module>
    vec2text.invert_embeddings(
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/api.py", line 106, in invert_embeddings
    regenerated = corrector.generate(
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/trainers/corrector.py", line 261, in generate
    ) = self._get_hypothesis_uncached(inputs=inputs)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/trainers/corrector.py", line 624, in _get_hypothesis_uncached
    hypothesis_embedding = self.embed_generated_hypothesis(
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/trainers/corrector.py", line 586, in embed_generated_hypothesis
    return self.get_frozen_embeddings(
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/trainers/corrector.py", line 567, in get_frozen_embeddings
    frozen_embeddings = self.inversion_trainer.call_embedding_model(
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/vec2text/models/inversion.py", line 195, in call_embedding_model
    embeddings = embed_api(
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 336, in wrapped_f
    return copy(f, *args, **kw)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 475, in __call__
    do = self.iter(retry_state=retry_state)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 376, in iter
    result = action(retry_state)
  File "/data/jqzhang/Miniconda3/envs/dge/lib/python3.10/site-packages/tenacity/__init__.py", line 419, in exc_check
    raise retry_exc from fut.exception()
tenacity.RetryError: RetryError[<Future at 0x7f4a70664be0 state=finished raised RetryError>]

################## My Code ##################

import os
import time

import openai
import torch
import vec2text

corrector = vec2text.load_pretrained_corrector("text-embedding-ada-002")

def get_embeddings_openai(
    text_list,
    model="text-embedding-ada-002",
    max_retries=20,
    backoff_factor=1.5,
) -> torch.Tensor:
    client = openai.OpenAI(
        api_key="########",
        base_url="#########",
    )

    retries = 0
    while retries < max_retries:
        try:
            response = client.embeddings.create(
                input=text_list,
                model=model,
                encoding_format="float",  # override default base64 encoding...
            )
            outputs = []
            outputs.extend(e.embedding for e in response.data)
            return torch.tensor(outputs)
        except Exception:
            wait = backoff_factor * (2**retries)
            time.sleep(wait)
            retries += 1
    raise Exception("API call failed after maximum number of retries")

embeddings = get_embeddings_openai(
    [
        "The user watched movies: Bohemian Rhapsody, Nightmare Before Christmas, The Peppermint.",
        "The user watched books: Oregon Atlas and Gazetteer, California Atlas & Gazetteer, The Complete Pebble Mosaic Handbook.",
    ]
)

print(
    vec2text.invert_embeddings(
        embeddings=embeddings.cuda(),
        corrector=corrector,
        num_steps=5,
    )
)
jxmorris12 commented 1 month ago

@Maitouer were you just rate-limited by OpenAI? Why'd you close the issue?

Maitouer commented 1 month ago

> @Maitouer were you just rate-limited by OpenAI? Why'd you close the issue?

Yes, it was entirely my own issue. The problem was resolved when I set the api_key and base_url using os.environ. Now, I can successfully use the contributions of this work to solve my own problems. Thank you very much!
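
For readers who hit the same traceback, here is a minimal sketch of the fix described above. The traceback shows vec2text building its own `OpenAI()` client with no arguments, so credentials passed only to a separate, hand-built client are not visible to it. The helper name `configure_openai_env` and the placeholder values are hypothetical. `OPENAI_API_KEY` is the variable named in the error message, and `OPENAI_BASE_URL` is assumed here to be the variable the OpenAI Python client reads for a custom endpoint.

```python
import os
from typing import Optional


def configure_openai_env(api_key: str, base_url: Optional[str] = None) -> None:
    """Export credentials so a client created as OpenAI() with no arguments can find them.

    Hypothetical helper for illustration; it is not part of vec2text or openai.
    """
    os.environ["OPENAI_API_KEY"] = api_key
    if base_url is not None:
        # Assumption: the OpenAI client falls back to this variable for its base URL.
        os.environ["OPENAI_BASE_URL"] = base_url


# Placeholder values only; substitute real credentials before any embedding call.
configure_openai_env("sk-placeholder", "https://example.invalid/v1")

# From here on the script in the original post can run unchanged, e.g.:
# vec2text.invert_embeddings(embeddings=embeddings.cuda(), corrector=corrector, num_steps=5)
```

Setting the variables at the top of the script, before the first call that embeds text, keeps the user-built client and the one vec2text creates internally pointed at the same account and endpoint.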

jxmorris12 commented 1 month ago

awesome!