samuelbroscheit / open_knowledge_graph_embeddings

Code to train open knowledge graph embeddings and to create a benchmark for open link prediction.

Error for small batch size #5

Open FauzanFarooqui opened 1 year ago

FauzanFarooqui commented 1 year ago

I have been working on OIE and OKGs for the past few months and found this paper highly relevant and impactful for the field.

I trained and evaluated LSTMComplEx with a batch size of 4096. Since the Lookup embeddings take a lot of GPU memory, I reduced the batch size all the way down to 8 (and the word vector size to 384).

I was trying the default Lookup option, but it wouldn't run because of the following error at line 240 in openkge/model.py (error screenshot below).

After spending quite some time, I narrowed the core of the error down to the corner case where a relation batch is sampled only once, i.e. slot_item.shape = [1, 1]. When this tensor is squeezed, it becomes "shapeless": it loses even the single dimension and retains only a scalar element. The embedding lookup then returns a tensor of the word vector size [384] rather than the intended shape [1, 384] (screenshot below).

That's why, back in the ComplEx scorer, line 190 sets batch_sz to 384 instead of 1, which produces the error shown in the first screenshot.
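For reference, here is a minimal sketch (not the repository's code; the embedding size and ids are made up) of how squeeze() collapses the batch dimension when there is only a single sample:

```python
import torch

emb = torch.nn.Embedding(num_embeddings=100, embedding_dim=384)

# Normal case: two relation ids with shape [2, 1].
ids = torch.tensor([[3], [7]])
print(emb(ids.squeeze()).shape)   # squeeze() -> [2], lookup -> torch.Size([2, 384])

# Corner case: a single relation id with shape [1, 1].
ids = torch.tensor([[3]])
out = emb(ids.squeeze())          # squeeze() -> 0-dim scalar, lookup -> torch.Size([384])
print(out.shape)

# Downstream, anything like `batch_sz = out.shape[0]` (as around line 190 in the
# ComplEx scorer) then reads 384 instead of 1.
print(out.shape[0])               # 384
```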

To solve this, I modified the squeezing for this special case (screenshot of the change below).
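The change is roughly of the following form (a sketch under my assumptions about the surrounding code; the variable names and the `embedding` call are illustrative, not the exact code at openkge/model.py line 240):

```python
# Illustrative sketch of the special-case handling:
squeezed = slot_item.squeeze()
if squeezed.dim() == 0:
    # A [1, 1] slot_item squeezes down to a scalar; restore the batch dimension
    # so the embedding lookup returns [1, 384] instead of [384].
    squeezed = squeezed.unsqueeze(0)
out = embedding(squeezed)

# Alternatively, squeezing only the known singleton dimension avoids the corner
# case entirely: slot_item.squeeze(dim=1) keeps the batch dimension intact.
```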

This seems to have fixed the error for me.

Noticing that I had changed the batch size between my LSTM and Lookup runs, I checked whether this also occurs for low batch sizes with LSTM. Indeed, an error that appears to stem from the same corner case comes up when running the LSTM embedding too. The issue detailed above occurs for batch sizes 4, 8, and 16 (it works fine for 32 and above).

This means the error is caused by the low batch size, regardless of the embedding method. What could be the underlying cause, and is my solution appropriate? Though not a major concern at my end, fixing this could help someone who can "just fit" the model into their available memory only with a small batch size.

PS: The warning in the torch.squeeze documentation echoes what seems to be happening here.