lean-dojo / ReProver

Retrieval-Augmented Theorem Provers for Lean
https://leandojo.org
MIT License

Retriever embeddings depend on padding #34

Closed: darabos closed this issue 7 months ago

darabos commented 9 months ago

I think that because T5 is not a causal model, the padding tokens also influence the hidden states of the earlier tokens, and thus the whole premise embedding. In datamodule.py the tokenizer is called with padding='longest', so the amount of padding a premise gets depends on how long the longest premise in its batch is.
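To illustrate the mechanism suspected above, here is a toy, dependency-free sketch of one bidirectional self-attention layer. It uses identity Q/K/V projections and a made-up pad vector, so it is not T5's actual implementation. Without a key mask, appended pad positions receive attention weight and shift the outputs at the earlier, real positions; with a mask (scores set to -inf) they are ignored.

```python
import math
from typing import List, Optional

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Numerically stable softmax; a score of -inf gets probability exactly 0.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(xs: List[Vector], mask: Optional[List[int]] = None) -> List[Vector]:
    """One toy bidirectional attention layer with identity Q/K/V projections.

    Every position attends to every position (no causal mask). Keys whose
    mask entry is 0 are excluded by giving them a score of -inf.
    """
    d = len(xs[0])
    out = []
    for q in xs:
        scores = [
            -math.inf if mask is not None and mask[j] == 0
            else sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
            for j, k in enumerate(xs)
        ]
        weights = softmax(scores)
        out.append([sum(w * v[i] for w, v in zip(weights, xs)) for i in range(d)])
    return out


tokens = [[1.0, 0.0], [0.0, 1.0]]  # two "real" token embeddings
pad = [0.5, -1.0]                  # hypothetical pad-token embedding
padded = tokens + [pad, pad]       # same sequence with two pad tokens appended

short = self_attention(tokens)
print(short[0] == self_attention(padded)[0])                     # → False
print(short[0] == self_attention(padded, mask=[1, 1, 0, 0])[0])  # → True
```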

This could be a minor effect, but it seems fairly significant to me. Check out this simple experiment: https://gist.github.com/darabos/3084b8571b17875076a63045496e2bd5

I took the example from the README, ran it with different batch sizes, and got 4 different rankings. Maybe it's worth trying padding='max_length'?
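Why the batch size matters follows from padding='longest' semantics: a premise is padded up to the longest premise in whichever batch it lands in. A minimal sketch with made-up premise lengths (not taken from the gist) shows the same premise receiving different amounts of padding as the batch size changes. Under padding='max_length' every premise would instead get max_length minus its own length, whatever the batch size.

```python
from typing import List


def pad_longest(lengths: List[int], batch_size: int) -> List[int]:
    """Pad tokens each sequence gets under padding='longest' semantics.

    Sequences are batched in order, and each one is padded up to the longest
    sequence in its own batch, so the amount depends on its batch-mates.
    """
    pads = []
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        longest = max(batch)
        pads.extend(longest - n for n in batch)
    return pads


lengths = [12, 40, 25, 7]  # made-up token counts for four premises
for bs in (1, 2, 4):
    print(bs, pad_longest(lengths, bs))
# 1 [0, 0, 0, 0]
# 2 [28, 0, 0, 18]
# 4 [28, 0, 15, 33]
```

The last premise gets 0, 18 or 33 pad tokens depending only on the batch size, which is enough to change its embedding if the pads are not masked.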

yangky11 commented 9 months ago

Just an update: I'm a bit swamped at the moment (due to the upcoming NeurIPS deadline) and will only be able to take a look around the Christmas holidays.

yangky11 commented 8 months ago

I was able to reproduce the problem. Switching to padding='max_length' does make the output more deterministic. However, it also becomes significantly slower, as the length of the longest sequence in the batch is often much shorter than max_length.
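The slowdown can be estimated by counting the token positions the encoder processes after padding. The sketch below uses made-up premise lengths and 2300, the max_length used in the code later in this thread; it models only padding, not truncation. Since self-attention cost grows roughly quadratically with the padded sequence length, the real compute gap is larger than these position counts alone suggest.

```python
from typing import List, Optional


def padded_positions(lengths: List[int], batch_size: int,
                     max_length: Optional[int] = None) -> int:
    """Total token positions the encoder processes after padding.

    max_length=None mimics padding='longest' (pad to the longest sequence in
    each batch); an integer mimics padding='max_length' (pad everything to it).
    Truncation is not modelled, so every length should be <= max_length.
    """
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        width = max(batch) if max_length is None else max_length
        total += width * len(batch)
    return total


lengths = [120, 400, 250, 70]  # made-up premise lengths in tokens
print(padded_positions(lengths, batch_size=2))                   # 1300
print(padded_positions(lengths, batch_size=2, max_length=2300))  # 9200
```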

Purewhite2019 commented 5 months ago

It seems that the difference is due to not passing the attention mask to the model. You can resolve this by modifying your test code as follows:

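```python
from typing import List, Union

import torch

# Assumes `tokenizer` and `model` are loaded as in the README example.

@torch.no_grad()
def encode(s: Union[str, List[str]], batch_size: int) -> torch.Tensor:
    """Encode texts into mean-pooled feature vectors, batch_size texts at a time."""
    if isinstance(s, str):
        s = [s]  # a single string is one text, not a list of characters
    batches = [s[i:i + batch_size] for i in range(0, len(s), batch_size)]
    batch_results = []
    for batch in batches:
        tokenized_s = tokenizer(
            batch,
            return_tensors='pt',
            padding='longest',
            truncation=True, max_length=2300,
        )
        # Pass the attention mask so the encoder does not attend to pad tokens.
        hidden_state = model(
            tokenized_s.input_ids, attention_mask=tokenized_s.attention_mask
        ).last_hidden_state
        # Mean-pool over real (non-pad) tokens only.
        lens = tokenized_s.attention_mask.sum(dim=1)
        features = (hidden_state * tokenized_s.attention_mask.unsqueeze(2)).sum(dim=1) / lens.unsqueeze(1)
        batch_results.append(features)
    return torch.cat(batch_results)
```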
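For reference, the pooling line in the snippet averages hidden states over real tokens only: multiplying by `attention_mask.unsqueeze(2)` zeroes the pad positions, and dividing by `lens` (the number of real tokens) turns the sum into a mean. Below is a dependency-free sketch of that arithmetic for a single sequence; `masked_mean` is a hypothetical helper, not code from retrieval/model.py. Pooling alone only keeps pad vectors out of the average; it cannot undo their influence on the real tokens' hidden states inside the encoder, which is why the attention mask also has to be passed to the model.

```python
from typing import List


def masked_mean(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Mean of the hidden vectors whose mask entry is 1 (pad positions are 0)."""
    n_real = sum(mask)  # plays the role of `lens` in the snippet
    dim = len(hidden[0])
    return [sum(h[i] * m for h, m in zip(hidden, mask)) / n_real for i in range(dim)]


hidden = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]  # last row is a pad position
print(masked_mean(hidden, [1, 1, 0]))  # [2.0, 3.0]
```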
darabos commented 5 months ago

Fantastic, thanks! This is actually how retrieval/model.py does it, so everything is good. 👍