Hi there,
I run into out-of-memory (OOM) errors when I run the code on a GPU. You have already noted this as a #TODO in the file Clause_Enhancer.py.
I worked around this by using torch_scatter.scatter_add() instead of torch.stack(scatter_deltas_list).sum(dim=0) to sum the changes produced by the clause enhancers. Is there a particular reason why you are not using a scatter operation? Did you find it slower at runtime, or does it not provide the same functionality as the solution based on torch.stack(...)?
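To make the comparison concrete, here is a minimal sketch of the two ways of summing the deltas. It is an illustration under assumptions, not the repository's code: the function names, the per-clause indices lists, and the toy shapes are hypothetical, and it uses PyTorch's built-in Tensor.index_add_ as a stand-in for torch_scatter.scatter_add() so that it runs without the extra dependency.

```python
import torch


def sum_deltas_stack(deltas, indices, n_predicates):
    """Baseline: expand every clause's deltas to full width, stack, then sum.

    Assumes the column indices within a single clause are unique.
    """
    expanded = []
    for delta, idx in zip(deltas, indices):
        full = torch.zeros(delta.shape[0], n_predicates,
                           dtype=delta.dtype, device=delta.device)
        full[:, idx] = delta  # place this clause's deltas in its columns
        expanded.append(full)
    # Holds n_clauses full-width tensors at once before reducing them.
    return torch.stack(expanded).sum(dim=0)


def sum_deltas_scatter(deltas, indices, n_predicates):
    """Scatter-style alternative: accumulate all deltas into one tensor."""
    all_deltas = torch.cat(deltas, dim=1)   # [batch, total number of literals]
    all_indices = torch.cat(indices)        # [total number of literals]
    out = torch.zeros(all_deltas.shape[0], n_predicates,
                      dtype=all_deltas.dtype, device=all_deltas.device)
    # out[:, all_indices[i]] += all_deltas[:, i]; repeated indices accumulate.
    out.index_add_(1, all_indices, all_deltas)
    return out


# Toy data (hypothetical): two clauses over four predicates, batch size two.
deltas = [torch.tensor([[0.5, -0.5], [1.0, 0.0]]),
          torch.tensor([[0.25, 0.25], [0.0, 2.0]])]
indices = [torch.tensor([0, 2]), torch.tensor([2, 3])]

stacked = sum_deltas_stack(deltas, indices, n_predicates=4)
scattered = sum_deltas_scatter(deltas, indices, n_predicates=4)
```

On this toy input both functions return the same tensor. The stacked version materialises one full-width tensor per clause before summing, so its peak memory grows with the number of clauses; the scatter-style version only allocates the concatenated deltas and a single output tensor.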
Also, I ran into the problem that not all tensors were on the same device when running on a GPU. I solved this by registering the relevant tensors as buffers in ClauseEnhancer.py:
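As a rough sketch of what registering tensors as buffers looks like (not the exact patch): the class below is a hypothetical stand-in for a clause enhancer, and its attribute names (literal_indices, signs, clause_weight) and its forward computation are assumptions made for illustration only. The relevant part is nn.Module.register_buffer, which makes the tensors follow the module when it is moved with .to(device).

```python
import torch
from torch import nn


class ClauseEnhancerSketch(nn.Module):
    """Hypothetical stand-in for a clause enhancer, not the repository's class."""

    def __init__(self, literal_indices, signs, initial_weight=0.5):
        super().__init__()
        # Buffers are module state that is not trained. Unlike plain tensor
        # attributes, they are moved together with the module by .to()/.cuda().
        self.register_buffer(
            "literal_indices", torch.tensor(literal_indices, dtype=torch.long))
        self.register_buffer(
            "signs", torch.tensor(signs, dtype=torch.float32))
        self.clause_weight = nn.Parameter(torch.tensor(initial_weight))

    def forward(self, preactivations):
        # Select this clause's literals and apply their signs (assumed logic).
        selected = preactivations[:, self.literal_indices] * self.signs
        return self.signs * self.clause_weight * torch.softmax(selected, dim=-1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ClauseEnhancerSketch([0, 2], [1.0, -1.0]).to(device)
x = torch.randn(3, 4, device=device)
delta = model(x)  # buffers, parameter and input all live on the same device
```

If literal_indices and signs were stored as ordinary attributes (self.signs = torch.tensor(...)), they would stay on the CPU after model.to(device), and the forward pass would fail with a device mismatch on a GPU.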
Thank you,
Luisa