texttron / tevatron

Tevatron - A flexible toolkit for neural retrieval research and development.
http://tevatron.ai
Apache License 2.0

Long queries throw a "Dtype object" error due to the predefined max_length in the batch #95

Open salrowili opened 8 months ago

salrowili commented 8 months ago

I think there is a bug in src/tevatron/driver/jax_train.py, on this line:

https://github.com/texttron/tevatron/blob/0e939457444f78284ab0471da74a0c74bc76a833/src/tevatron/driver/jax_train.py#L147C43-L147C56

The issue is caused by hard-coding max_length to 32, which assumes that no query exceeds this length. That breaks when data_args.q_max_len is set above 32. I have a custom dataset with a couple of examples where queries reach a length of about 128. It would be great if you could fix this, because the error thrown by python3 is misleading and gives no indication that this line is the cause. It took me two days to realize that this line is the root of the problem. I worked around the issue by setting max_length to 128 instead of 32. I think one solution would be to replace 32 with data_args.q_max_len:

return dict(tokenizer.pad(qq, max_length=data_args.q_max_len, padding='max_length', return_tensors='np')), dict(
    tokenizer.pad(dd, max_length=data_args.p_max_len, padding='max_length', return_tensors='np'))
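To illustrate why a fixed max_length of 32 might lead to this kind of error, here is a minimal pure-Python sketch. It does not use the real tokenizer: pad_to_max_length and is_rectangular are hypothetical helpers, and the sketch assumes that padding to max_length only extends short sequences and never truncates long ones. Under that assumption, any query longer than 32 leaves the batch with rows of unequal length, which cannot be turned into a regular integer array.

```python
def pad_to_max_length(seqs, max_length, pad_id=0):
    # Hypothetical stand-in for padding='max_length'.
    # Assumption: short sequences are padded, long ones are left untouched.
    return [seq + [pad_id] * max(0, max_length - len(seq)) for seq in seqs]


def is_rectangular(batch):
    # True when every row has the same length, so the batch can form a 2-D array.
    return len({len(row) for row in batch}) == 1


# Two fake queries: 10 token ids and 48 token ids.
queries = [list(range(1, 11)), list(range(1, 49))]

ragged = pad_to_max_length(queries, max_length=32)   # fixed value from the issue
fixed = pad_to_max_length(queries, max_length=128)   # stands in for q_max_len

print([len(row) for row in ragged], is_rectangular(ragged))  # [32, 48] False
print([len(row) for row in fixed], is_rectangular(fixed))    # [128, 128] True
```

With max_length=32 the second row keeps its 48 ids, so the batch is ragged. Padding to a value at least as large as the longest query, as with data_args.q_max_len, keeps every row the same length.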

Thank you, Sultan