getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

Segfault in backward with some shapes. #299

Closed flbbb closed 1 year ago

flbbb commented 1 year ago

Hi, I'm running the code below on an A5000 (25 GB). When the batch dimension B is the last dimension (running the script as is), it works.

When running with --swap, it fails with a segfault after the backward call.

Edit: also tested on a V100 (32 GB), same issue.

import torch
from argparse import ArgumentParser
from pykeops.torch import LazyTensor

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-L", "--length", type=int, default=4096)
    parser.add_argument("-H", "--d_model", type=int, default=720)
    parser.add_argument("-N", "--n_dim", type=int, default=256)
    parser.add_argument("-B", "--batch_size", type=int, default=32)
    parser.add_argument("--swap", action="store_true")

    args = parser.parse_args()
    L = args.length
    H = args.d_model
    N = args.n_dim
    B = args.batch_size

    if args.swap:
        # Batch first.
        A = torch.randn(1, H, N, 1, 1, device="cuda")
        u = torch.randn(B, H, 1, L, 1, requires_grad=True, device="cuda")
        p = torch.randn(1, 1, 1, L, 1, device="cuda")

        A_i = LazyTensor(A)
        u_j = LazyTensor(u)
        p_j = LazyTensor(p)
        k_ij = u_j * (p_j * A_i).exp()

        res = k_ij.sum(dim=3)
        res.mean().backward()

    else:
        # Batch last.
        A = torch.randn(H, N, 1, 1, device="cuda")
        u = torch.randn(H, 1, L, B, requires_grad=True, device="cuda")
        p = torch.randn(1, 1, L, 1, device="cuda")

        A_i = LazyTensor(A)
        u_j = LazyTensor(u)
        p_j = LazyTensor(p)
        k_ij = u_j * (p_j * A_i).exp()

        res = k_ij.sum(dim=2)
        res.mean().backward()

bcharlier commented 1 year ago

Hi @flbbb ,

this is the expected behavior: in the LazyTensor framework, batch dimensions are by convention the leading dimensions (which is a fairly common convention). If your data comes in a "batch last" layout, you should manually permute the dimensions and make the data contiguous (this may incur a copy).
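For example, a "batch last" tensor like the `u` of shape `(H, 1, L, B)` in the script above can be moved to the batch-first layout along these lines (a minimal sketch using plain PyTorch; the small sizes are placeholders for the script's H=720, L=4096, B=32):

```python
import torch

# Hypothetical small sizes standing in for the script's H, L, B.
H, L, B = 4, 8, 2

# "Batch last" layout, as in the else-branch of the script: (H, 1, L, B).
u_last = torch.randn(H, 1, L, B)

# Move the batch dimension to the front: (B, H, 1, L).
# permute() returns a non-contiguous view, so follow it with
# .contiguous(), which copies the data into contiguous memory.
u_first = u_last.permute(3, 0, 1, 2).contiguous()
print(u_first.shape)  # (B, H, 1, L)

# To match the 5-d "batch first" layout of the --swap branch,
# append a trailing size-1 dimension: (B, H, 1, L, 1).
u_first_5d = u_first.unsqueeze(-1)
print(u_first_5d.shape)
```

The `.contiguous()` call matters: KeOps reads the underlying storage directly, so a strided view produced by `permute` alone is not enough.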

NB: check the actual entries of the res tensor in the "batch last" version of your code. My guess is that it does not compute what you intend...

Best,