getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

On multi-process gradient sharing #180

Closed: aluo-x closed this issue 3 years ago

aluo-x commented 3 years ago

My prior experience with KeOps has mostly involved computing point cloud divergences with the GeomLoss wrapper. For that purpose it works extremely well and scales to fairly large point clouds.

A recent project requires multiplying very large dense matrices (essentially an nn.Linear operation): a 1xMxN weight matrix and a 1xNxH input, where H is around 1e6 and M and N are both <1024. Unfortunately, this results in out-of-memory errors with stock PyTorch operations. I've managed to reduce the memory use somewhat using custom fused CUDA ops, FP16, and gradient checkpointing.
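A quick back-of-the-envelope estimate of the dense tensors involved, taking M = N = 1024 as an upper bound and H = 1e6 as stated above. `tensor_bytes` is a hypothetical helper, and the estimate ignores allocator overhead and any extra buffers a framework may keep.

```python
GIB = 2 ** 30  # bytes per GiB


def tensor_bytes(*shape, bytes_per_elem=4):
    """Storage size in bytes of a dense tensor (4 bytes/elem = fp32, 2 = fp16)."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem


# Upper bounds from the post: M, N < 1024 and H around 1e6.
M, N, H = 1024, 1024, 1_000_000

shapes = {
    "weight W": (1, M, N),
    "input X": (1, N, H),
    "output Y = W @ X": (1, M, H),
}
for name, shape in shapes.items():
    fp32 = tensor_bytes(*shape, bytes_per_elem=4) / GIB
    fp16 = tensor_bytes(*shape, bytes_per_elem=2) / GIB
    print(f"{name:18s} fp32: {fp32:7.3f} GiB   fp16: {fp16:7.3f} GiB")
```

The weight is only about 4 MiB, but each 1 x 1024 x 1e6 activation is about 3.8 GiB in fp32 (about 1.9 GiB in fp16). Autograd also has to save the layer input X to compute the weight gradient, so the activations, not the weights, dominate memory here.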

Since the memory-expensive parts of the network are nn.Linear layers, I was considering writing a few of those layers with KeOps.

My idea was simply to define a custom nn.Module, with a few of the more complex ops still using traditional PyTorch/CUDA:

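import torch
import torch.nn as nn

class keops_Linear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # 1 x M x N weight, stored as an ordinary PyTorch parameter
        self.weight = nn.Parameter(torch.randn(1, out_features, in_features))

    def forward(self, x):  # x: 1 x N x H
        weights_k = keops(self.weight)  # hypothetical placeholder: wrap weights for KeOps
        return keops_matmul(weights_k, x)  # hypothetical placeholder: KeOps-backed matmul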
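Written out with the batch dimension of size 1 dropped, the product such a layer has to compute for a weight W of shape M x N and an input X of shape N x H, together with the gradients backward needs for a loss L, is:

```latex
Y = W X, \qquad Y_{m,h} = \sum_{n=1}^{N} W_{m,n}\, X_{n,h},
\qquad m = 1,\dots,M,\; h = 1,\dots,H,

\frac{\partial L}{\partial W} = \frac{\partial L}{\partial Y}\, X^{\top},
\qquad
\frac{\partial L}{\partial X} = W^{\top}\, \frac{\partial L}{\partial Y}.
```

Each of the M·H outputs is a length-N dot product, and the weight gradient needs the full N x H input, which is why that input has to be kept (or recomputed, as with gradient checkpointing) for the backward pass.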

On this topic, I was wondering how KeOps works with DataParallel or DistributedDataParallel, and whether anyone has thoughts on this approach.

Edit: A brief search of the issues doesn't turn up anyone who has used KeOps with either PyTorch DP or DDP before.
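For context on what "gradient sharing" means in DDP: DistributedDataParallel keeps one model replica per process and, during the backward pass, all-reduces the gradients and divides by the number of processes, so every replica applies the same averaged update. The sketch below simulates that averaging in a single process with plain Python, assuming a toy scalar model y = w * x, a mean-reduced squared-error loss, and equal-sized shards. `grad_mse` and `all_reduce_mean` are hypothetical helpers, not torch.distributed or KeOps APIs.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for an all-reduce (sum) followed by division by world size."""
    return sum(values) / len(values)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
world_size = 2
shard = len(xs) // world_size  # equal-sized shards, one per simulated process

# Each "process" computes a gradient on its own slice of the batch.
local_grads = [
    grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]
synced = all_reduce_mean(local_grads)  # what every replica sees after the all-reduce
full = grad_mse(w, xs, ys)             # gradient on the whole batch in one process
print(local_grads, synced, full)
```

With equal shards and a mean loss, the averaged local gradients match the full-batch gradient (both are -22.5 here), which is the property DDP relies on when it averages gradients across processes.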

aluo-x commented 3 years ago

Closing this issue. It seems this is exactly the problem tackled by the primitives in fairscale/fairseq.