agemagician / Ankh

Ankh: Optimized Protein Language Model

MultiLabelClassification Issue #44

Closed: lhallee closed 10 months ago

lhallee commented 12 months ago

Hello @agemagician @hazemessamm ,

In the current multi-label ConvBERT head, the output logits have shape (batch_size, seq_len, emb_dim), but they are compared against labels of shape (batch_size, emb_dim). F.binary_cross_entropy_with_logits requires the predictions and labels to have the same shape, so if I understand correctly this will not work. One potential solution is to average across the sequence dimension to get (batch_size, emb_dim), but I'm not sure that is optimal. Does it work as intended, or have I actually found a bug?

Best,
Logan
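For concreteness, a minimal sketch of the mismatch described above, with illustrative sizes (num_labels stands in for what the comment calls emb_dim):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, num_labels = 4, 128, 10

# Per-token logits, as the multi-label head currently returns them.
logits = torch.randn(batch_size, seq_len, num_labels)
# One multi-hot label vector per example.
labels = torch.randint(0, 2, (batch_size, num_labels)).float()

# The shapes differ, so this raises:
# ValueError: Target size (torch.Size([4, 10])) must match input size (torch.Size([4, 128, 10]))
# loss = F.binary_cross_entropy_with_logits(logits, labels)

# The workaround suggested above: average over the sequence dimension first.
pooled = logits.mean(dim=1)  # (batch_size, num_labels)
loss = F.binary_cross_entropy_with_logits(pooled, labels)
```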

liusfore commented 10 months ago

The current multi-label ConvBERT head is set up as a sequence labeling task. If the task is sequence labeling (like named entity recognition), the labels must also be a sequence, one label per token. For a sequence classification task, where the goal is to classify the entire sequence (like sentiment analysis), we only need a single representation of the sequence, often the representation of the first token. In that case, we can take the logits at the first position of the sequence for classification.
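To illustrate the difference, the two task types expect labels of different shapes (sizes here are illustrative):

```python
import torch

batch_size, seq_len, num_labels = 4, 128, 10

# Sequence labeling (e.g. NER): one label vector per token.
token_level_labels = torch.zeros(batch_size, seq_len, num_labels)

# Sequence classification (e.g. sentiment): one label vector per sequence.
sequence_level_labels = torch.zeros(batch_size, num_labels)
```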

Here is my solution.

```python
# Take the logits at the first position of the sequence.
logits = logits[:, 0, :]  # now the shape is torch.Size([batch_size, label_size])
loss = self._compute_loss(logits, labels)
```

lhallee commented 10 months ago

ANKH / T5 doesn't have a CLS token, so this could work, but it is unlikely to work well with a frozen backbone.
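Since there is no CLS token, one common alternative is mask-aware mean pooling over the sequence before computing the loss, so padding tokens don't dilute the average. A minimal sketch, assuming an attention_mask with 1s for real tokens and 0s for padding (the helper name and shapes are illustrative):

```python
import torch

def masked_mean_pool(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average per-token logits over real tokens only, ignoring padding.

    logits:          (batch_size, seq_len, num_labels)
    attention_mask:  (batch_size, seq_len), 1 for real tokens, 0 for padding
    returns:         (batch_size, num_labels)
    """
    mask = attention_mask.unsqueeze(-1).to(logits.dtype)  # (batch_size, seq_len, 1)
    summed = (logits * mask).sum(dim=1)                   # sum over real tokens only
    counts = mask.sum(dim=1).clamp(min=1.0)               # avoid division by zero
    return summed / counts

# Usage with illustrative shapes:
logits = torch.randn(4, 128, 10)
attention_mask = torch.ones(4, 128)
pooled = masked_mean_pool(logits, attention_mask)  # (4, 10), ready for BCE-with-logits
```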