shtoshni / fast-coref

Code for the CRAC 2021 paper "On Generalization in Coreference Resolution" (Best short paper award)

A clarification in the code-base #19

Open KawshikManikantan opened 1 year ago

KawshikManikantan commented 1 year ago

In src/model/mention_proposal/utils.py, line 9: sort_scores = ment_starts + 1e-5 * ment_starts. Should it be + 1e-5 * ment_ends, so that shorter mentions (starting at the same index) come first? Or is it just a coding convenience?
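The ordering being asked about here (by start, with ties broken by end) can be sketched in plain Python. This is only an illustration of the intended result, not the repository's code; the span values are made up and the tuple key stands in for the float sort score.

```python
# Hypothetical (start, end) spans; the values are illustrative only.
ment_starts = [306, 960, 306, 960]
ment_ends = [308, 961, 306, 960]

# Sort indices by start, breaking ties by end, so that among mentions
# sharing a start index the shorter one comes first.
sorted_indices = sorted(
    range(len(ment_starts)),
    key=lambda i: (ment_starts[i], ment_ends[i]),
)

sorted_spans = [(ment_starts[i], ment_ends[i]) for i in sorted_indices]
print(sorted_spans)
```

With a key of ment_starts + 1e-5 * ment_starts, every mention sharing a start gets the same score, so the order among them is left to the sort implementation.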

shtoshni commented 1 year ago

@KawshikManikantan Thanks for raising this issue. You're right; I have fixed it in the latest commit. Let me know if it looks alright.

KawshikManikantan commented 1 year ago

@shtoshni Thanks for the immediate reply. Although the above change should work, I still find some discrepancies, for example in the figure below. The first tensor is ment_starts[sorted_indices] and the second is ment_ends[sorted_indices]. The 306 306 pair in the first tensor has a correctly ordered 306 308 pair in the second, but the 960 960 pair has a corresponding 961 960, which is out of order. Can you think of any reason why this happens? [image]

KawshikManikantan commented 1 year ago

@shtoshni It looks like an error due to floating point inaccuracies. This works as far as I can see: sort_scores = ment_starts.to(torch.float64) + 1e-5 * ment_ends.to(torch.float64)
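A possible alternative to widening the dtype is to avoid floating point in the sort key altogether. The sketch below uses plain Python lists with made-up spans. The name base is hypothetical, and the same arithmetic would presumably carry over to integer tensors.

```python
# Hypothetical spans, chosen to include the 960/961 case from this thread.
ment_starts = [960, 306, 960, 306]
ment_ends = [961, 308, 960, 306]

# Multiply the start by a base larger than any end, so the start dominates
# the key and the end only breaks ties. Integer arithmetic is exact.
base = max(ment_ends) + 1
sort_keys = [s * base + e for s, e in zip(ment_starts, ment_ends)]

sorted_indices = sorted(range(len(sort_keys)), key=sort_keys.__getitem__)
sorted_ends = [ment_ends[i] for i in sorted_indices]
print(sorted_indices, sorted_ends)
```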

shtoshni commented 1 year ago

@KawshikManikantan: Must be due to a difference in torch version. I'm not running into any such issue. Regarding the sorting, yeah that looks weird tbh, no idea why that's the case.

KawshikManikantan commented 1 year ago
import torch

ment_starts = torch.tensor([960.0, 960.0])  # float32 by default
ment_ends = torch.tensor([961.0, 960.0])
torch.set_printoptions(precision=10)
sort_scores = ment_starts + 0.00001 * ment_ends
print(sort_scores, sort_scores.dtype)

Output: tensor([960.0095825195, 960.0095825195]) torch.float32

ment_starts = torch.tensor([960.0, 960.0], dtype=torch.float64)
ment_ends = torch.tensor([961.0, 960.0], dtype=torch.float64)
torch.set_printoptions(precision=10)
sort_scores = ment_starts + 0.00001 * ment_ends
print(sort_scores, sort_scores.dtype)

Output: tensor([960.0096100000, 960.0096000000], dtype=torch.float64) torch.float64

This is the result I get in multiple environments, and it is the reason for my sorting troubles. Casting to float64 resolves it, though.
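One likely explanation for the float32 output above is that float32 keeps 24 significand bits, so adjacent representable values between 512 and 1024 are 2**-14 (about 6.1e-5) apart. That is coarser than the 1e-5 step the tie-breaker relies on. The sketch below emulates float32 rounding with the standard library only. It assumes IEEE 754 round-to-nearest and that each intermediate result is rounded to float32, which is an assumption about the computation and has not been checked against torch internals. The helper names are hypothetical.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def float32_sort_score(start: float, end: float) -> float:
    # Emulates start + 1e-5 * end with every intermediate rounded to float32.
    scaled_end = to_float32(to_float32(1e-5) * end)
    return to_float32(start + scaled_end)


# Spacing between adjacent float32 values in [512, 1024).
spacing = 2.0 ** -14

score_a = float32_sort_score(960.0, 961.0)
score_b = float32_sort_score(960.0, 960.0)
print(spacing, score_a, score_b, score_a == score_b)
```

Under these assumptions both scores round to the same float32 value, so the two mentions tie and their relative order is no longer determined by the end index. In float64 the two scores stay distinct.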