advimman / lama

🦙 LaMa Image Inpainting, Resolution-robust Large Mask Inpainting with Fourier Convolutions, WACV 2022
https://advimman.github.io/lama-project/
Apache License 2.0

Seeing NaN for some images when exporting as JIT Model #249

Open josephcatrambone-crucible opened 1 year ago

josephcatrambone-crucible commented 1 year ago

Hi, all!

First: thank you so much for producing LaMa. It's a very impressive piece of tech, and there's a beautiful elegance and simplicity to it that I love.

For my question: seemingly at random, I encounter an image that, when submitted for inpainting, produces an output of all NaNs. I can't even make it happen consistently for the same image. One key detail: I'm using the JIT-exported model.

tensor(nan, device='cuda:0', grad_fn=<SumBackward0>)
tensor(nan, device='cuda:0', grad_fn=<MinBackward1>)
tensor(nan, device='cuda:0', grad_fn=<MaxBackward1>)

I can't tell precisely where in the network the NaN appears because, again, this is the JIT version and I can't trace the individual operations. Even if I could, checking the network after each inference step would probably wear me down pretty quickly. I have checked that the inputs are finite real numbers, and they look fine.
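
For reference, here's a minimal sketch of the checks I'm running. The filename and the tensor layout ([1, 3, H, W] image in [0, 1], [1, 1, H, W] binary mask) are my assumptions about the JIT export, not something taken from the repo:

import torch

# Rough sketch of the sanity checks; "big-lama.pt" is a placeholder filename.
device = torch.device("cuda:0")
model = torch.jit.load("big-lama.pt", map_location=device).eval()

img_mat = torch.rand(1, 3, 512, 512)                    # stand-in image tensor
mask_mat = (torch.rand(1, 1, 512, 512) > 0.99).float()  # stand-in (small) mask

# Inputs are finite real numbers...
assert torch.isfinite(img_mat).all() and torch.isfinite(mask_mat).all()

pred = model(img_mat.to(device), mask_mat.to(device))

# ...but the output is sometimes all NaN.
print(pred.sum(), pred.min(), pred.max())
print("NaN fraction:", torch.isnan(pred).float().mean().item())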

The operation usually works, but it sometimes fails, and I'd like to figure out why. If this is a known issue, or if the dev team has ideas about where to start, I'm all ears.

This may be a factor, but I don't think I'm using half-precision (unless that's what's being done internally): https://github.com/pytorch/pytorch/issues/33485

EDIT: To rule out the half-precision issue, I printed the sums of the inputs. It looks like a smaller mask area makes a NaN output more likely, but the correlation isn't perfect (the rough logging loop I used is sketched after the table):

| Image Sum   | Mask Sum | Success |
|-------------|----------|---------|
| 125007.125  | 891      | 0 |
| 124303.4609 | 139271   | 1 |
| 124303.4609 | 776      | 0 |
| 124303.4609 | 0        | 0 |
| 124303.4609 | 22037    | 0 |
| 124303.4609 | 194264   | 1 |
| 124303.4609 | 20340    | 1 |
| 124303.4609 | 33889    | 1 |
| 124303.4609 | 8179     | 0 |
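
For completeness, this is roughly the loop I used to gather the table; `images`, `model`, and `device` are placeholders for my actual test set and the loaded JIT model from above:

import torch

# Hypothetical logging loop: record image sum, mask sum, and whether the
# output came back finite (1) or NaN (0).
rows = []
for img_mat, mask_mat in images:
    pred = model(img_mat.to(device), mask_mat.to(device))
    success = int(torch.isfinite(pred).all().item())
    rows.append((img_mat.sum().item(), mask_mat.sum().item(), success))

for img_sum, mask_sum, success in rows:
    print(f"{img_sum:.4f}\t{mask_sum:.0f}\t{success}")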

EDIT 2: While disabling half precision (via model = model.eval().to(torch.float32).to(device)), I noticed the following warning printed from inside the predict() method. This suggests that (a) half precision is being used somewhere internally and (b) casting the model to float32 does not disable it:

.../lib/python3.10/site-packages/torch/nn/modules/module.py:1501: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343967769/work/aten/src/ATen/EmptyTensor.cpp:31.)
  return forward_call(*args, **kwargs)
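
To double-check whether half-precision weights are actually baked into the export (as opposed to autocast casting things at runtime), I also dumped the parameter dtypes. Again, the filename is a placeholder and this is just a generic inspection, not anything from the repo:

import torch
from collections import Counter

# Inspect the weight dtypes of the TorchScript export; if everything is
# float32 here, the ComplexHalf warning must come from autocast at runtime.
model = torch.jit.load("big-lama.pt", map_location="cuda:0")
print(Counter(p.dtype for p in model.parameters()))

# Casting the ScriptModule to float32 (as tried above) doesn't help if
# predict() wraps the forward call in an autocast region.
model = model.eval().to(torch.float32)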

EDIT 3: While I can't disable automatic mixed precision for the whole control flow, wrapping the predict call like so seems to have alleviated the problem. I'm going to retest with a bunch of images to see if the problem re-emerges:

with torch.cuda.amp.autocast(enabled=False):  # Added.
    pred = model(img_mat.to(device), mask_mat.to(device))
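
The retest I'm planning looks roughly like this (`images` is again just a placeholder for my set of image/mask tensor pairs):

import torch

# Count how many images still produce NaN output with autocast disabled.
failures = 0
for img_mat, mask_mat in images:
    with torch.cuda.amp.autocast(enabled=False):
        pred = model(img_mat.to(device), mask_mat.to(device))
    failures += int(not torch.isfinite(pred).all().item())
print(f"{failures} / {len(images)} images produced NaN output")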