jxmorris12 / vec2text

utilities for decoding deep representations (like sentence embeddings) back to text

use llama 7b as corrector #63

Closed karanjotsv closed 2 weeks ago

karanjotsv commented 3 weeks ago

Hi, I have been trying to use the framework with llama 7b as a corrector but keep getting a size mismatch error. Could you please guide me?

jxmorris12 commented 3 weeks ago

Some questions:

  1. Are you trying to do embedding inversion or LM/logit inversion?
  2. Did you already train a first-step model and generate hypotheses? You need this for correction.
  3. Can you post your code and the stack trace?

karanjotsv commented 3 weeks ago

Hey, I am trying to do embedding inversion, motivated by "Text Embeddings Reveal (Almost) As Much As Text". I have not trained a zero-step model or done the hypothesis part; honestly, I am quite confused by the documentation. My code is taken straight from this repo's README.md. So far I have tried to invert some strings using the pre-trained model associated with the card jxm/gtr__nq__32__correct.

I tried to use jxm/t5-base__llama-7b__one-million-instructions__correct, hoping it might work for llama, but it throws a size mismatch error for the weights. How should I proceed if I want to use llama 7b as a corrector?

karanjotsv commented 2 weeks ago

Hey @jxmorris12, can you please let me know how I should proceed?

jxmorris12 commented 2 weeks ago

Hi! That paper is about text embedding models, which produce single-vector outputs, typically obtained from pooled hidden states. We train an encoder-decoder model (based on T5) that predicts text from embeddings. LLAMA is not a text embedding model or an encoder-decoder model, so it can't be straightforwardly substituted.
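To make the incompatibility concrete, here is a minimal sketch (not vec2text code) that compares parameter shapes the way a checkpoint loader does before copying weights. The helper `find_load_problems` is hypothetical, and the two shapes are illustrative: they are the `lm_head.weight` sizes implied by the published T5-base configuration (hidden size 768, vocabulary 32128) and the LLaMA-7B configuration (hidden size 4096, vocabulary 32000).

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_load_problems(checkpoint: Dict[str, Shape], model: Dict[str, Shape]) -> List[str]:
    """Hypothetical helper: list the reasons a checkpoint cannot be loaded into a model."""
    problems = []
    for name, shape in checkpoint.items():
        if name not in model:
            problems.append(f"unexpected key: {name}")
        elif model[name] != shape:
            problems.append(
                f"size mismatch for {name}: checkpoint {shape} vs model {model[name]}"
            )
    for name in model:
        if name not in checkpoint:
            problems.append(f"missing key: {name}")
    return problems


# Illustrative shapes only (assumed from the public model configs).
llama_7b_like = {"lm_head.weight": (32000, 4096)}
t5_base_like = {"lm_head.weight": (32128, 768)}

for problem in find_load_problems(llama_7b_like, t5_base_like):
    print(problem)
```

Even where two architectures share a parameter name, the tensors have different shapes, so the weights cannot simply be copied across.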

That said, you could still use a decoder-only model like LLAMA for embedding-to-text tasks -- in fact, in the paper, we tried GPT-2 for this, although it performed poorly. To do that, you could write your own class that inputs embeddings as the first few tokens and decodes the text from there. It's a non-trivial amount of effort, though, and you certainly won't just be able to load the LLAMA weights into T5 as you seem to be trying.
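A rough sketch of the "embeddings as the first few tokens" idea, assuming PyTorch. This is not part of vec2text: the class name `EmbeddingPrefix`, the single linear projection, and the default of four prefix tokens are all assumptions made for illustration.

```python
import torch
from torch import nn


class EmbeddingPrefix(nn.Module):
    """Hypothetical module: turn one sentence embedding into a few "soft token" vectors."""

    def __init__(self, embed_dim: int, hidden_dim: int, num_prefix_tokens: int = 4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_prefix_tokens = num_prefix_tokens
        # One linear map produces all prefix vectors at once.
        self.proj = nn.Linear(embed_dim, hidden_dim * num_prefix_tokens)

    def forward(self, embedding: torch.Tensor, token_embeds: torch.Tensor) -> torch.Tensor:
        # embedding:    (batch, embed_dim)           the vector to invert
        # token_embeds: (batch, seq_len, hidden_dim) decoder input-token embeddings
        prefix = self.proj(embedding).view(-1, self.num_prefix_tokens, self.hidden_dim)
        # Prepend the soft tokens so every text position can attend to them.
        return torch.cat([prefix, token_embeds], dim=1)


# Toy sizes for a quick shape check; a real decoder would use its own hidden size.
prefixer = EmbeddingPrefix(embed_dim=8, hidden_dim=16, num_prefix_tokens=4)
combined = prefixer(torch.randn(2, 8), torch.randn(2, 5, 16))
print(combined.shape)
```

With a Hugging Face causal LM, the combined tensor would be passed as `inputs_embeds` rather than `input_ids`, and the label positions covering the prefix would be set to -100 so the loss ignores them. Training the projection (and possibly the decoder) is still required before this produces useful text.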

tiandong1234 commented 2 weeks ago

> Hey, I am trying to do embedding inversion, motivated by "Text Embeddings Reveal (Almost) As Much As Text". I have not trained a zero-step model or done the hypothesis part; honestly, I am quite confused by the documentation. My code is taken straight from this repo's README.md. So far I have tried to invert some strings using the pre-trained model associated with the card jxm/gtr__nq__32__correct.
>
> I tried to use jxm/t5-base__llama-7b__one-million-instructions__correct, hoping it might work for llama, but it throws a size mismatch error for the weights. How should I proceed if I want to use llama 7b as a corrector?

I have the same issue as you. Maybe you need to use the class CorrectorEncoderFromLogitsModel to load the corrector model "jxm/t5-base__llama-7b__one-million-instructions__correct",

like vec2text.models.CorrectorEncoderFromLogitsModel.from_pretrained(path)

jxmorris12 commented 2 weeks ago

Sorry, what are you trying to do @tiandong1234? Do you want to invert language model logits or text embeddings?

tiandong1234 commented 2 weeks ago

> Sorry, what are you trying to do @tiandong1234? Do you want to invert language model logits or text embeddings?

I'm trying to invert language model logits and text embeddings using the llama-7b model, for which I tried the following code:

    inversion_model = vec2text.models.InversionFromLogitsEmbModel.from_pretrained("jxm/t5-base__llama-7b__one-million-instructions__emb").to(device)
    corrector_model = vec2text.models.CorrectorEncoderFromLogitsModel.from_pretrained("jxm/t5-base__llama-7b__one-million-instructions__correct")
    corrector = vec2text.load_corrector(inversion_model, corrector_model)

While loading corrector_model, I hit the error "No such file or directory: '/home/jxm3/research/retrieval/inversion/llama_unigram.pt'". It comes from line 29 of vec2text/models/corrector_encoder_from_logits.py:

    self.unigram = torch.load("/home/jxm3/research/retrieval/inversion/llama_unigram.pt")

That's the problem I reported in the new issue I opened.
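One generic way to stop a developer-specific file path from breaking model loading is to treat the file as optional. The sketch below is only an assumption about how such a guard could look; it is not necessarily the change made in vec2text, and `load_if_present` is a hypothetical helper.

```python
from pathlib import Path
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def load_if_present(path: str, loader: Callable[[str], T]) -> Optional[T]:
    """Hypothetical helper: return loader(path) if the file exists, else None."""
    if not Path(path).is_file():
        # The file is missing on this machine; let the caller decide on a fallback.
        return None
    return loader(path)


# Example call site (assumed, shown as a comment because it needs torch):
# self.unigram = load_if_present("/path/to/llama_unigram.pt", torch.load)
```

The caller then has to handle the `None` case explicitly, for example by skipping whatever step depends on the file.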

jxmorris12 commented 2 weeks ago

Hi @tiandong1234 – I've removed this constraint. Can you try running the code from the latest branch?