There is a mismatch between the dimension of the last layer's inputs and what the last layer itself expects.
The best solution seems to be letting the user specify which kind of reduction they use. Common choices: `first` (the `<CLS>` token in BERT), `last` (in causal LMs), and `average` (https://arxiv.org/abs/2402.05015). We can use an enum for this.
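A minimal sketch of what such an enum could look like, written in plain Python (lists instead of tensors) so it runs without dependencies. `Reduction`, `reduce_hidden_states`, and `linear_head` are hypothetical names, not an existing API. It assumes the model returns one hidden vector per token and that the regression head is a single linear layer whose input size equals the hidden size. Whichever reduction is chosen, the head then receives exactly `hidden_dim` features.

```python
from enum import Enum
from typing import List, Optional, Sequence


class Reduction(Enum):
    """Hypothetical enum: how per-token hidden states become one vector."""
    FIRST = "first"      # e.g. the <CLS> token in BERT-style encoders
    LAST = "last"        # final real token, typical for causal LMs
    AVERAGE = "average"  # mean over all real (non-padding) tokens


def reduce_hidden_states(
    hidden: Sequence[Sequence[float]],
    reduction: Reduction,
    mask: Optional[Sequence[int]] = None,
) -> List[float]:
    """Collapse a (seq_len, hidden_dim) sequence into one hidden_dim vector.

    `mask` marks real tokens with 1 and padding with 0; if omitted,
    every position is treated as a real token.
    """
    if mask is None:
        mask = [1] * len(hidden)
    positions = [i for i, m in enumerate(mask) if m]
    if not positions:
        raise ValueError("sequence has no unmasked tokens")
    if reduction is Reduction.FIRST:
        return list(hidden[positions[0]])
    if reduction is Reduction.LAST:
        return list(hidden[positions[-1]])
    if reduction is Reduction.AVERAGE:
        dim = len(hidden[0])
        return [sum(hidden[i][d] for i in positions) / len(positions) for d in range(dim)]
    raise ValueError(f"unsupported reduction: {reduction}")


def linear_head(vec: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Hypothetical single-output regression head: w . x + b."""
    if len(vec) != len(weights):
        raise ValueError(f"head expects {len(weights)} features, got {len(vec)}")
    return sum(w * x for w, x in zip(weights, vec)) + bias
```

Because the members carry string values, a user-supplied string maps directly onto the enum, e.g. `Reduction("last")` is `Reduction.LAST`. With `states = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]`, `FIRST` gives `[1.0, 2.0]`, `LAST` gives `[5.0, 6.0]`, and `AVERAGE` gives `[3.0, 4.0]`; with `mask=[1, 1, 0]`, `LAST` picks the last real token, `[3.0, 4.0]`.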
Say, you have an LLM with a regression head on top. Then this code
outputs