huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Using BERTModel for learning a siamese encoder #11094

Closed: drunkinlove closed this issue 3 years ago

drunkinlove commented 3 years ago

Hi! This is more of a question than a bug report. Can I use BertModel, without any modifications, to train a BERT-based siamese encoder?

(Not sure if this really is a BERT-specific question, but I will tag @LysandreJik just in case)

This is what my training step looks like:

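import torch
import torch.nn as nn

# Assumed definitions for the two helpers used below
cossim = nn.CosineSimilarity(dim=1)
bcentropy = nn.BCELoss()  # expects probabilities in [0, 1]

optimizer.zero_grad()

# The same model (shared weights) encodes both sides of each pair
outputs_a = model(input_ids_a, attention_mask=attention_mask_a)
outputs_b = model(input_ids_b, attention_mask=attention_mask_b)
# Mean-pool token embeddings into one vector per sequence
a = torch.mean(outputs_a['last_hidden_state'], dim=1)
b = torch.mean(outputs_b['last_hidden_state'], dim=1)
# Map cosine similarity from [-1, 1] to [0, 1] for binary cross-entropy
cossim_normalized = (cossim(a, b) + 1) / 2

loss = bcentropy(cossim_normalized, labels.float())
loss.backward()
optimizer.step()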
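One detail in the training step above: `torch.mean(..., dim=1)` averages over every position in `last_hidden_state`, including padding tokens, so padded batches shift the sentence vectors. A minimal sketch of mask-aware mean pooling follows; `masked_mean_pool` is a hypothetical helper name, not a transformers API, and the toy tensors stand in for real model outputs.

```python
import torch


def masked_mean_pool(last_hidden_state: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings, ignoring positions where attention_mask == 0.

    Hypothetical helper, not part of transformers.
    """
    # (batch, seq) -> (batch, seq, 1) so the mask broadcasts over the hidden dim
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Clamp avoids division by zero for a row that is entirely padding
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


# Toy batch: two sequences of length 3, hidden size 2; the second ends in padding
hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]],
                       [[2.0, 0.0], [4.0, 0.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 1],
                     [1, 1, 0]])

pooled = masked_mean_pool(hidden, mask)
print(pooled.tolist())  # [[3.0, 3.0], [3.0, 0.0]]

# The plain mean lets the padding vector pull the second row far off
naive = hidden.mean(dim=1)
print(naive.tolist())
```

In the training step, this would replace the two `torch.mean` calls, for example `a = masked_mean_pool(outputs_a['last_hidden_state'], attention_mask_a)`.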

Should this work? Most other examples of siamese models in PyTorch simply modify the forward pass to include the second input, but I don't see why gradients shouldn't accumulate properly in my case.
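On the gradient question: autograd records both calls of a shared module in one graph, so a single `loss.backward()` adds the contributions from both branches into the same parameters' `.grad`. The sketch below checks that calling the encoder twice gives the same gradients as a wrapper whose `forward` takes both inputs. It assumes a tiny hypothetical `TinyEncoder` as a stand-in for BertModel (any `nn.Module` behaves the same under autograd); `SiameseWrapper` and `siamese_loss` are also illustrative names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for BertModel: embedding + linear + mean pooling."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # (batch, seq) -> (batch, seq, hidden) -> (batch, hidden)
        return self.proj(self.emb(input_ids)).mean(dim=1)


class SiameseWrapper(nn.Module):
    """The "modified forward pass" style: one forward call takes both inputs."""

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, ids_a, ids_b):
        return self.encoder(ids_a), self.encoder(ids_b)


def siamese_loss(a, b, labels):
    # Same loss as in the question: BCE on cosine similarity mapped to [0, 1]
    prob = (nn.functional.cosine_similarity(a, b, dim=1) + 1) / 2
    return nn.functional.binary_cross_entropy(prob, labels)


def grads(module):
    return [p.grad.clone() for p in module.parameters()]


encoder = TinyEncoder()
ids_a = torch.randint(0, 100, (4, 7))
ids_b = torch.randint(0, 100, (4, 5))
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Style 1: call the shared encoder twice, as in the question
encoder.zero_grad()
siamese_loss(encoder(ids_a), encoder(ids_b), labels).backward()
grads_two_calls = grads(encoder)

# Style 2: one wrapper forward pass that takes both inputs
encoder.zero_grad()
a, b = SiameseWrapper(encoder)(ids_a, ids_b)
siamese_loss(a, b, labels).backward()
grads_wrapper = grads(encoder)

same = all(torch.allclose(g1, g2) for g1, g2 in zip(grads_two_calls, grads_wrapper))
print(same)  # True

# Sanity check: detaching branch b removes its contribution, so the grads change
encoder.zero_grad()
siamese_loss(encoder(ids_a), encoder(ids_b).detach(), labels).backward()
grads_a_only = grads(encoder)
print(all(torch.allclose(g1, g2) for g1, g2 in zip(grads_two_calls, grads_a_only)))
```

Under these assumptions, the two styles differ only in how the code is organized. The wrapper is mainly convenient when a training loop expects a single `model(...)` call.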

NielsRogge commented 3 years ago

Hi,

Can you please ask this question on the forum rather than here? For example, this comment might help you already.

The Hugging Face maintainers prefer to keep GitHub issues for bug reports and feature requests.

Thank you!

drunkinlove commented 3 years ago

Thank you, closing this.