jxmorris12 / vec2text

utilities for decoding deep representations (like sentence embeddings) back to text

Why is the training not done in an iterative manner like the inference? #16

Closed kyriemao closed 8 months ago

kyriemao commented 9 months ago

Hi jxmorris,

Great work. I have a question about your training code.

I've observed that during the training of the corrector model, the input and output consistently follow the pattern (e, e^{0}, x^{0}) → x, where x is a text, e is its ground-truth embedding, and e^{0} and x^{0} are the initial hypothesis embedding and text. Is this correct? In the inference stage, however, you iteratively recover x from e.

I am just wondering why you do not also train in an iterative manner, as is done when training diffusion models. Is it for efficiency? In other words, for each training sample (e, x), you currently train only with (e, e^{0}, x^{0}, x), but you could train with additional iterative samples such as (e, e^{1}, x^{1}, x), (e, e^{2}, x^{2}, x), etc. I think this would improve model performance.

Moreover, do you have any intuition for why training only with (e, e^{0}, x^{0}, x) still works for iterative inference? Perhaps it is because learning from (e, e^{0}, x^{0}) is harder than learning from (e, e^{n}, x^{n}), so the model can still achieve good performance by learning only from the hardest case?

Thanks!

jxmorris12 commented 8 months ago

This is a great question! We don't train in an iterative manner because it's expensive, and it turns out we don't need it. However, we're definitely leaving performance on the table, and this might be one reason some of our models train very slowly.

You should try it: you could recompute hypotheses each epoch, for example, and add them to the training data. It would be slower (because computing hypotheses takes time) and use more memory (since you're effectively increasing the size of the training set), but it would certainly perform better. Training might converge faster, too.
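For concreteness, here is a rough sketch of that epoch-level regeneration loop. This is not the vec2text API: `embed`, `generate_hypothesis`, and `train_step` are hypothetical stand-ins for the real embedder, corrector inference, and optimizer step, and the toy implementations exist only so the control flow runs end to end.

```python
def embed(text):
    # Hypothetical embedder stand-in: maps text to a toy 2-d vector.
    return [float(len(text)), float(sum(map(ord, text)) % 97)]

def generate_hypothesis(e, e_hyp, x_hyp):
    # Hypothetical corrector stand-in: pretend to refine the hypothesis text.
    return x_hyp + "*"

def train_step(batch):
    # Hypothetical optimizer step over one pass of the dataset (no-op here).
    pass

def iterative_training(texts, num_epochs=3):
    # Start from round-0 samples (e, e^0, x^0, x), as in the paper's setup,
    # using an empty string as a trivial initial hypothesis.
    dataset = []
    for x in texts:
        e = embed(x)
        x0 = ""
        dataset.append((e, embed(x0), x0, x))

    for epoch in range(num_epochs):
        train_step(dataset)
        # After each epoch, regenerate hypotheses with the current model and
        # append them, yielding samples (e, e^t, x^t, x) at increasing
        # refinement levels. Note this doubles the dataset every epoch; a
        # real run would likely cap growth or keep only recent rounds.
        new_samples = []
        for (e, e_hyp, x_hyp, x) in dataset:
            x_new = generate_hypothesis(e, e_hyp, x_hyp)
            new_samples.append((e, embed(x_new), x_new, x))
        dataset.extend(new_samples)
    return dataset
```

The trade-off mentioned above is visible directly in the sketch: the extra memory cost comes from `dataset.extend(...)`, and the extra time cost from the regeneration pass inside the epoch loop.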

As for why it works, I'm not totally sure, but we do speculate in the paper that the relatively large training set covers various levels of "noise" (cosine distance or BLEU score from the ground truth), so the model learns to correct samples at various levels of difficulty and can therefore be applied iteratively.