HEmile / KENN-PyTorch

PyTorch implementation of the KENN model
BSD 3-Clause "New" or "Revised" License

OOM in Knowledge Enhancer on GPU #4

Open LuisaWerner opened 2 years ago

LuisaWerner commented 2 years ago

Hi there,

I am running into out-of-memory (OOM) errors when I run the code on a GPU. You already mention this as a #TODO in ClauseEnhancer.py.
I solved this by using torch_scatter.scatter_add() instead of torch.stack(scatter_deltas_list).sum(dim=0) to sum the changes (deltas) produced by the clause enhancers. Is there a particular reason you don't use a scatter operation? Did you find it slower at runtime, or does it not provide the same functionality as the torch.stack(...) solution?
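A minimal sketch of why the two approaches give the same result but differ in memory use. It assumes each clause yields a (batch, n_literals) tensor of deltas plus a 1-D tensor of the predicate indices those deltas belong to; the function names are hypothetical and not taken from the repository. To avoid the extra dependency, it uses PyTorch's built-in Tensor.index_add_ as a stand-in for torch_scatter.scatter_add.

```python
import torch

def sum_deltas_stack(deltas_list, index_list, num_predicates):
    # Dense approach (hypothetical helper): one full (batch, num_predicates)
    # tensor per clause, stacked into (num_clauses, batch, num_predicates)
    # before summing over the clause dimension.
    batch = deltas_list[0].shape[0]
    scattered = []
    for deltas, idx in zip(deltas_list, index_list):
        full = torch.zeros(batch, num_predicates,
                           dtype=deltas.dtype, device=deltas.device)
        full.index_add_(1, idx, deltas)
        scattered.append(full)
    return torch.stack(scattered).sum(dim=0)

def sum_deltas_scatter(deltas_list, index_list, num_predicates):
    # Scatter approach (hypothetical helper): concatenate every clause's deltas
    # and target indices, then accumulate them into one output tensor.
    all_deltas = torch.cat(deltas_list, dim=1)  # (batch, total_literals)
    all_idx = torch.cat(index_list)             # (total_literals,)
    out = torch.zeros(all_deltas.shape[0], num_predicates,
                      dtype=all_deltas.dtype, device=all_deltas.device)
    return out.index_add_(1, all_idx, all_deltas)

torch.manual_seed(0)
deltas = [torch.randn(4, 2), torch.randn(4, 3)]            # two clauses
indices = [torch.tensor([0, 2]), torch.tensor([1, 2, 4])]  # predicate 2 is in both
dense = sum_deltas_stack(deltas, indices, num_predicates=5)
sparse = sum_deltas_scatter(deltas, indices, num_predicates=5)
print(torch.allclose(dense, sparse))  # True
```

In this sketch the dense path allocates num_clauses × batch × num_predicates elements for the stacked tensor. The scatter path only needs the concatenated deltas (batch × total_literals) plus one output of batch × num_predicates, which matters when there are many clauses over many predicates. With torch_scatter, the last step would be roughly scatter_add(all_deltas, all_idx, dim=1, dim_size=num_predicates).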

Also, when running on a GPU, I got errors because not all tensors were on the same device. I fixed this by registering the relevant tensors as buffers in ClauseEnhancer.py, so that they are moved together with the module:

```python
# In ClauseEnhancer.__init__: buffers follow the module on .to(device) / .cuda()
self.register_buffer('signs', signs)
self.register_buffer('gather_literal_indices', gather_literal_indices)
self.register_buffer('scatter_literal_indices', scatter_literal_indices)
```
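A minimal sketch of why registering buffers fixes the device mismatch. TinyClauseEnhancer is a hypothetical, simplified module: its update rule is only loosely modeled on a clause enhancer and is not KENN's implementation. Only the three buffer names come from the snippet in this issue. The plain_signs attribute is included purely for contrast, to show a tensor that Module.to() does not move.

```python
import torch
from torch import nn

class TinyClauseEnhancer(nn.Module):
    """Hypothetical, simplified clause module used only to illustrate buffers."""

    def __init__(self, literal_indices, signs, init_weight=0.5):
        super().__init__()
        # Learnable clause weight: a Parameter, so the optimizer updates it.
        self.clause_weight = nn.Parameter(torch.tensor(init_weight))
        # Constant tensors registered as buffers: not trained, but moved/cast by
        # .to()/.cuda() together with the parameters and saved in state_dict().
        self.register_buffer('signs', signs)
        self.register_buffer('gather_literal_indices', literal_indices)
        self.register_buffer('scatter_literal_indices', literal_indices.clone())
        # For contrast only: a plain tensor attribute is ignored by .to().
        self.plain_signs = signs.clone()

    def forward(self, preactivations):
        # Select this clause's literals and apply their signs (-1 for negated).
        literals = preactivations.index_select(1, self.gather_literal_indices) * self.signs
        deltas = self.clause_weight * torch.softmax(literals, dim=-1) * self.signs
        return deltas, self.scatter_literal_indices

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clause = TinyClauseEnhancer(torch.tensor([0, 2]), torch.tensor([1.0, -1.0])).to(device)
x = torch.randn(4, 3, device=device)
deltas, idx = clause(x)  # no device mismatch: all buffers followed .to(device)
```

The same mechanism shows up with dtype casts: after .to(torch.float64), the floating-point signs buffer becomes float64 and the integer index buffers stay int64, while plain_signs is left untouched as float32. Had signs been stored as a plain attribute, forward() on a GPU would mix a CPU tensor with CUDA tensors.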

Thank you, Luisa