NVlabs / nvdiffrast

Nvdiffrast - Modular Primitives for High-Performance Differentiable Rendering

Possible memory leak when using nn.DataParallel #23

Closed Turlan closed 3 years ago

Turlan commented 3 years ago

Hi, when I use your code to implement multi-GPU training with the provided rasterizer, the GPU memory keeps increasing.

In the __init__ function of my PyTorch nn.Module subclass, I first create a list of RasterizeGLContext instances, one for each GPU. During forward, I select the RasterizeGLContext instance that matches the current device ID. The GPU memory keeps increasing whenever I use two or more GPUs.
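The per-device selection described above can be sketched as a small cache keyed by device. This is a minimal sketch, not the poster's code: `PerDeviceCache` is a hypothetical name, and `factory` stands in for something like `lambda dev: dr.RasterizeGLContext(device=dev)`. Creation is guarded by a lock because under nn.DataParallel the per-device forward calls run concurrently in separate threads.

```python
import threading
from typing import Callable, Dict, Hashable


class PerDeviceCache:
    """Lazily create one object per device key and reuse it on later calls.

    `factory` is a hypothetical stand-in for something like
    `lambda dev: dr.RasterizeGLContext(device=dev)`.
    """

    def __init__(self, factory: Callable[[Hashable], object]):
        self._factory = factory
        self._items: Dict[Hashable, object] = {}
        self._lock = threading.Lock()

    def get(self, device: Hashable) -> object:
        # Creation happens under the lock, so two threads asking for the
        # same device at the same time cannot both create an object.
        with self._lock:
            if device not in self._items:
                self._items[device] = self._factory(device)
            return self._items[device]

    def __len__(self) -> int:
        return len(self._items)
```

Inside forward, `ctx = self.contexts.get(vertices.device)` would then return the same object for every call on that device, since torch.device values are hashable.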

I don't know whether I am using the code incorrectly or there is a bug in your implementation. If possible, could you provide some sample code for multi-GPU training? Thanks!

s-laine commented 3 years ago

Hi Turlan! Are you certain that you don't create more and more GL contexts as the training progresses? I believe Torch's nn.DataParallel respawns the worker threads between every epoch, and even though this should destroy the contexts and release GPU memory, perhaps something there doesn't work as it should. If you call set_log_level(0) you can see when GL contexts are created and destroyed, which may help in pinpointing the problem.
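One way to act on this suggestion is to save the program's console output and count how many contexts were created per device. A minimal sketch, assuming log lines in the `Creating GL context for Cuda device N` format that appears in the logs posted in this thread; `count_context_creations` is a hypothetical helper, and destruction messages are not parsed because their exact wording is not shown here.

```python
import re
from collections import Counter

# Matches lines such as:
# [I glutil.cpp:322] Creating GL context for Cuda device 0
_CREATE_RE = re.compile(r"Creating GL context for Cuda device (\d+)")


def count_context_creations(log_text: str) -> Counter:
    """Map CUDA device index -> number of GL contexts created in the log."""
    counts = Counter()
    for line in log_text.splitlines():
        match = _CREATE_RE.search(line)
        if match:
            counts[int(match.group(1))] += 1
    return counts
```

If any device's count keeps rising over a long run, contexts are being recreated rather than reused.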

I have only tested nn.DataParallel with custom loaders that avoid the epoch changes and worker thread restarts altogether. I haven't tested with stock nn.DataParallel, but I would expect things to work there too — there are no known bugs. Would you be able to provide a minimal repro? Also, are you on Linux or Windows?

yangjiaolong commented 3 years ago

Hi s-laine, I'm a collaborator of Turlan's. We are also testing nn.DataParallel with a simple custom loader (it just sends the same data repeatedly), so there's definitely no epoch issue.

I just tried set_log_level(0) and found that GL contexts were correctly created once for each of the two devices, cuda:0 and cuda:1. No context-destruction messages were printed before the program crashed.

[I glutil.cpp:322] Creating GL context for Cuda device 0
[I glutil.cpp:370] EGL 5.1 OpenGL context created (disp: 0x0000555c0b12f4c0, ctx: 0x0000555c0b138f91)
[I rasterize.cpp:91] OpenGL version reported as 4.6
[I glutil.cpp:322] Creating GL context for Cuda device 1
[I glutil.cpp:370] EGL 5.1 OpenGL context created (disp: 0x0000555c0d6e93e0, ctx: 0x0000555c0d6bcc21)
[I rasterize.cpp:91] OpenGL version reported as 4.6
[I rasterize.cpp:332] Increasing position buffer size to 393216 float32
[I rasterize.cpp:343] Increasing triangle buffer size to 98304 int32
[I rasterize.cpp:368] Increasing frame buffer size to (width, height, depth) = (1280, 1280, 1)
[I rasterize.cpp:394] Increasing range array size to 64 elements
[I rasterize.cpp:332] Increasing position buffer size to 393216 float32
[I rasterize.cpp:343] Increasing triangle buffer size to 98304 int32
[I rasterize.cpp:368] Increasing frame buffer size to (width, height, depth) = (1280, 1280, 1)
[I rasterize.cpp:394] Increasing range array size to 64 elements
0 0.177
1 0.117
2 0.058
3 0.070
4 0.108
5 0.109
6 0.104
7 0.126
.....

Interestingly, no error was printed before the crash. The program simply stopped at an iteration far smaller than the number I set, and as Turlan mentioned, the GPU memory kept increasing while the program was running.

btw, we were testing on Linux.

s-laine commented 3 years ago

Based on the log it looks like RasterizeGLContexts are not the issue here. Two GL contexts are created and the internal buffers don't grow after the initial allocations.

I would next check that references to old tensors aren't mistakenly kept around, which would prevent Torch from deallocating them. It wouldn't be surprising if this behaved differently with 1 and 2 GPUs, because with multiple GPUs the gradients need to be aggregated and shared between GPUs, and that may leave some references lying around.
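A common way for such stale references to arise is a logging list that stores the loss tensor itself rather than its Python value. The sketch below uses a hypothetical `FakeTensor` class in place of a real torch tensor and `weakref` to count how many per-iteration objects are still alive. It assumes CPython, where an object without reference cycles is freed as soon as its last strong reference disappears.

```python
import weakref


class FakeTensor:
    """Hypothetical stand-in for a per-iteration tensor such as a loss."""

    def __init__(self, value: float):
        self.value = value

    def item(self) -> float:
        return self.value


def count_alive(keep_tensors: bool, iterations: int = 5) -> int:
    history = []  # e.g. a list of losses kept for logging
    probes = []   # weak references do not keep their targets alive
    loss = None
    for i in range(iterations):
        loss = FakeTensor(float(i))
        probes.append(weakref.ref(loss))
        # Storing the object keeps it alive; storing a plain float does not.
        history.append(loss if keep_tensors else loss.item())
    loss = None  # drop the loop variable's reference to the last object
    return sum(1 for probe in probes if probe() is not None)
```

With `keep_tensors=True` every object survives, which is the pattern that would make memory grow with the iteration count.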

There is also the possibility that despite cleaning up stale references, there remain circular references between the old tensors so that the objects stay in memory. In this case, calling Python's garbage collector may be required to clean them up quickly enough to avoid running out of GPU memory. To see if that is the case, you can try calling gc.collect() after each training iteration and see if the GPU memory consumption stops growing. Nvdiffrast should not create such circular references internally, so if it does then that's a bug on our side.
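The effect of an explicit collection on reference cycles can be seen without any GPU. In this minimal sketch, `Node` is a hypothetical stand-in for objects that hold tensors; each pair points at the other, so CPython's reference counting alone cannot free them and only the cyclic garbage collector can. Automatic collection is disabled during the measurement so that only the explicit `gc.collect()` call matters.

```python
import gc
import weakref
from typing import List


class Node:
    """Hypothetical object standing in for something that holds a tensor."""

    def __init__(self):
        self.partner = None


def make_cycles(n: int) -> List[weakref.ref]:
    """Create n two-object cycles and return weak references to every object."""
    probes = []
    for _ in range(n):
        a, b = Node(), Node()
        # Each object keeps the other alive, so refcounting never frees them.
        a.partner, b.partner = b, a
        probes += [weakref.ref(a), weakref.ref(b)]
    return probes


def alive_after(n: int, collect: bool) -> int:
    was_enabled = gc.isenabled()
    gc.disable()  # keep automatic collection from running during the measurement
    try:
        probes = make_cycles(n)
        if collect:
            gc.collect()  # explicit collection, as suggested after each iteration
        return sum(1 for probe in probes if probe() is not None)
    finally:
        if was_enabled:
            gc.enable()
```

If memory stops growing once `gc.collect()` runs every iteration, cycles like these are the likely cause; if it keeps growing, the leak is elsewhere.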

yangjiaolong commented 3 years ago

Thank you, Samuli, for the prompt reply. I just tried calling gc.collect() after each iteration and it did not resolve the issue: GPU memory was still growing (see the log below). I also printed the GPU memory usage reported by the torch.cuda APIs and by nvidia-smi. According to nvidia-smi, memory usage increases by ~2MB after each iteration, and it keeps increasing until the program crashes at around the 120th iteration.

[I glutil.cpp:322] Creating GL context for Cuda device 0
[I glutil.cpp:370] EGL 5.1 OpenGL context created (disp: 0x0000556f7a8291c0, ctx: 0x0000556f7a832c01)
[I rasterize.cpp:91] OpenGL version reported as 4.6
[I glutil.cpp:322] Creating GL context for Cuda device 1
[I glutil.cpp:370] EGL 5.1 OpenGL context created (disp: 0x0000556f7afbe1f0, ctx: 0x0000556f7af91e11)
[I rasterize.cpp:91] OpenGL version reported as 4.6
[I rasterize.cpp:332] Increasing position buffer size to 3145728 float32
[I rasterize.cpp:343] Increasing triangle buffer size to 786432 int32
[I rasterize.cpp:368] Increasing frame buffer size to (width, height, depth) = (1280, 1280, 10)
[I rasterize.cpp:394] Increasing range array size to 64 elements
[I rasterize.cpp:332] Increasing position buffer size to 3145728 float32
[I rasterize.cpp:343] Increasing triangle buffer size to 786432 int32
[I rasterize.cpp:368] Increasing frame buffer size to (width, height, depth) = (1280, 1280, 10)
[I rasterize.cpp:394] Increasing range array size to 64 elements
gc.get_count() (195, 10, 8)
gc.collect() 0
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2781
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1480
Iter 0 Time 0.515
gc.get_count() (51, 1, 0)
gc.collect() 204
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2783
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1482
Iter 1 Time 0.504
gc.get_count() (53, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2785
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1484
Iter 2 Time 0.639
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2786
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1485
Iter 3 Time 0.473
gc.get_count() (51, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2788
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1487
Iter 4 Time 0.459
gc.get_count() (58, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2790
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1489
Iter 5 Time 0.486
gc.get_count() (54, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2791
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1490
Iter 6 Time 0.598
gc.get_count() (58, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2793
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1492
Iter 7 Time 0.625
gc.get_count() (50, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2795
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1494
Iter 8 Time 0.575
gc.get_count() (50, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2796
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1495
Iter 9 Time 0.544
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2798
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1497
Iter 10 Time 0.616
gc.get_count() (51, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2800
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1499
Iter 11 Time 0.598
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2801
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1500
Iter 12 Time 0.507
gc.get_count() (50, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2803
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1502
Iter 13 Time 0.567
gc.get_count() (59, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2805
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1504
Iter 14 Time 0.496
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2807
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1505
Iter 15 Time 0.571
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2808
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1507
Iter 16 Time 0.537
gc.get_count() (51, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2810
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1509
Iter 17 Time 0.542
gc.get_count() (51, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2812
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1511
Iter 18 Time 0.534
gc.get_count() (51, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2813
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1512
Iter 19 Time 0.521
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2815
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1514
Iter 20 Time 0.527
gc.get_count() (54, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2817
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1516
Iter 21 Time 0.588
gc.get_count() (51, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2818
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1517
Iter 22 Time 0.570
gc.get_count() (51, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2820
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1519
Iter 23 Time 0.529
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2822
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1521
Iter 24 Time 0.488
gc.get_count() (50, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2823
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1522
Iter 25 Time 0.633
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2825
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1524
Iter 26 Time 0.571
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2827
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1526
Iter 27 Time 0.603
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2829
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1527
Iter 28 Time 0.551
gc.get_count() (52, 1, 0)
gc.collect() 196
gc.get_count() (4, 0, 0)
torch.cuda.max_memory_reserved(0) 759169024
torch.cuda.max_memory_allocated(0) 685223936
torch.cuda.memory_allocated(0) 318986240
torch.cuda.memory_reserved(0) 759169024
nvidia-smi GPU-0 memory (MB) 2831
torch.cuda.max_memory_reserved(1) 608174080
torch.cuda.max_memory_allocated(1) 530863104
torch.cuda.memory_allocated(1) 10997248
torch.cuda.memory_reserved(1) 608174080
nvidia-smi GPU-1 memory (MB) 1529
Iter 29 Time 0.531
...

If I use just one GPU by making only one GPU ID visible (but still using nn.DataParallel), everything works normally: memory is stable and the program doesn't crash.
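To put a number on the growth in the log above, the readings can be parsed and fitted with a straight line. A minimal sketch, assuming the exact line formats printed in that log (`nvidia-smi GPU-N memory (MB) ...` and `torch.cuda.memory_reserved(N) ...`); the helper names are hypothetical. For each GPU it reports the least-squares growth per iteration for both sources, converting Torch's byte counts to MiB.

```python
import re
from collections import defaultdict
from typing import Dict, List

_SMI_RE = re.compile(r"nvidia-smi GPU-(\d+) memory \(MB\) (\d+)")
_RESERVED_RE = re.compile(r"torch\.cuda\.memory_reserved\((\d+)\) (\d+)")


def slope(values: List[float]) -> float:
    """Least-squares slope of values against their index (growth per iteration)."""
    n = len(values)
    if n < 2:
        return 0.0
    mean_x = (n - 1) / 2
    mean_y = sum(values) / n
    num = sum((x - mean_x) * (y - mean_y) for x, y in enumerate(values))
    den = sum((x - mean_x) ** 2 for x in range(n))
    return num / den


def growth_per_iteration(log_text: str) -> Dict[int, Dict[str, float]]:
    """Per GPU: growth per iteration seen by nvidia-smi and by torch's allocator."""
    smi = defaultdict(list)
    reserved = defaultdict(list)
    for line in log_text.splitlines():
        if m := _SMI_RE.search(line):
            smi[int(m.group(1))].append(float(m.group(2)))
        elif m := _RESERVED_RE.search(line):
            reserved[int(m.group(1))].append(int(m.group(2)) / 2**20)  # bytes -> MiB
    return {
        gpu: {"nvidia_smi": slope(smi[gpu]), "torch_reserved": slope(reserved[gpu])}
        for gpu in sorted(set(smi) | set(reserved))
    }
```

For the 30 GPU-0 readings above, the nvidia-smi slope comes out at roughly 1.7 MB per iteration while the memory_reserved slope is zero, which is the pattern discussed in the next reply.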

s-laine commented 3 years ago

Very interesting! Torch reports constant memory usage which rules out any problems with stale tensor references, and points to something leaking memory on the OpenGL side. However, the ~1.7MB per iteration doesn't quite match any of the buffers allocated in the rasterization op, and I also have no idea why using two GPUs would leak memory if one GPU doesn't.

Contrary to what I said in my first comment, I realized that I haven't actually tried nn.DataParallel, only nn.DistributedDataParallel, which spawns a separate process per GPU. This way each child process uses only one GPU, which may be why I haven't encountered this problem in my tests. Perhaps this is something you could also consider as a workaround?
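The one-process-per-GPU setup can be sketched with the standard library alone. A minimal sketch, assuming a hypothetical training script (e.g. `train.py`) that calls `torch.distributed.init_process_group(init_method="env://")` and wraps its model in nn.parallel.DistributedDataParallel. Each child is restricted to a single GPU through `CUDA_VISIBLE_DEVICES`, so such a script would only ever need one RasterizeGLContext. The helper names and the port number are arbitrary choices for illustration.

```python
import os
import subprocess
import sys
from typing import Dict


def per_gpu_env(rank: int, world_size: int, master_port: int = 29500) -> Dict[str, str]:
    """Environment for the worker process that owns GPU number `rank`."""
    env = dict(os.environ)
    env.update({
        # The child sees only this one GPU, which it then addresses as cuda:0.
        "CUDA_VISIBLE_DEVICES": str(rank),
        # Variables read by torch.distributed.init_process_group(init_method="env://").
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port),
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
    })
    return env


def launch(script: str, world_size: int) -> int:
    """Start one copy of `script` per GPU and return the number of failed workers."""
    procs = [
        subprocess.Popen([sys.executable, script], env=per_gpu_env(rank, world_size))
        for rank in range(world_size)
    ]
    return sum(1 for p in procs if p.wait() != 0)
```

Calling `launch("train.py", 2)` would start two workers; PyTorch's own launchers can set up the same environment, and this sketch only makes the per-process layout explicit.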

In any case, I would highly appreciate a minimal repro that could be used to root-cause the issue.

yangjiaolong commented 3 years ago

Tried nn.DistributedDataParallel and there's no memory leak issue! We will use it for now as a workaround. Thanks Samuli.

Regarding nn.DataParallel, I'll try to make a minimal repro later to help identify the potential issue.

dariopavllo commented 3 years ago

Hi all,

I was also interested in trying out nvdiffrast with nn.DataParallel and encountered a similar problem. In my case, the GPU memory usage increases slowly after each rasterization call, and the program eventually crashes with a segmentation fault (not memory exhaustion). This usually happens after 100-200 iterations. I also noticed that GPU utilization is extremely low in this setting (around 1-2%) compared to a single-GPU setting, which points to some locking mechanism or internal re-initialization. I eventually managed to find a (hacky) workaround for the issue, so I'm sharing it in case someone finds it useful.

Firstly, I tried to debug the issue and found out that it is limited to the forward pass of dr.rasterize. The following (minimal) code reproduces the issue:

import torch
import torch.nn as nn
import nvdiffrast.torch as dr
import nvdiffrast
from tqdm import tqdm
nvdiffrast.torch.set_log_level(0)

class Rasterizer(nn.Module):
    def __init__(self):
        super().__init__()
        self.ctx = {}

    def forward(self, vertices, faces):
        if vertices.device not in self.ctx:
            self.ctx[vertices.device] = dr.RasterizeGLContext(output_db=False, device=vertices.device)
            print('Created GL context for device', vertices.device)
        ctx = self.ctx[vertices.device]

        rast_out, _ = dr.rasterize(ctx, vertices, faces, resolution=(256, 256))
        return rast_out

gpu_ids = [0, 1]
rasterizer = nn.DataParallel(Rasterizer(), gpu_ids)

bs = 2
nt = 1
vertices = torch.randn(bs, nt*3, 4).cuda()
faces = torch.arange(nt*3).view(-1, 3)
faces_rep = faces.repeat(len(gpu_ids), 1).int().cuda()

with torch.no_grad():
    for i in tqdm(range(100000)):
        rasterizer(vertices, faces_rep)

In the code above, the GL context is lazily initialized in forward, meaning that the initialization happens in a worker thread that is later destroyed. I also tried initializing everything in the constructor (i.e., in the main thread), and nothing changed. Using the GL context in manual mode doesn't help either. Note that I disabled gradient computation, so the issue is definitely not related to the backward pass.

In particular, debugging the crash with cuda-gdb produces the following stack trace:

#0  0x00007fff4f47f3c0 in NvGlEglGetFunctions () from /lib/x86_64-linux-gnu/libnvidia-eglcore.so.450.102.04
#1  0x00007fff4f322259 in NvGlEglGetFunctions () from /lib/x86_64-linux-gnu/libnvidia-eglcore.so.450.102.04
#2  0x00007fff4f3222f0 in NvGlEglGetFunctions () from /lib/x86_64-linux-gnu/libnvidia-eglcore.so.450.102.04
#3  0x00007fff4f2e4ba4 in NvGlEglGetFunctions () from /lib/x86_64-linux-gnu/libnvidia-eglcore.so.450.102.04
#4  0x00007fff4f2aba4a in NvGlEglApiInit () from /lib/x86_64-linux-gnu/libnvidia-eglcore.so.450.102.04
#5  0x00007fff80d01cb2 in NvEglDevtoolsQuery () from /lib/x86_64-linux-gnu/libEGL_nvidia.so.0
#6  0x00007fff8ccda489 in cuEGLApiInit () from /lib/x86_64-linux-gnu/libcuda.so.1
#7  0x00007fff8cbefe96 in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#8  0x00007fff8cdc8e73 in cuGraphicsMapResources () from /lib/x86_64-linux-gnu/libcuda.so.1
#9  0x00007fffde64bc9e in __cudaPopCallConfiguration () from /local/home/user/miniconda3/envs/pt16/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#10 0x00007fffde6939c9 in cudaGraphicsMapResources () from /local/home/user/miniconda3/envs/pt16/lib/python3.7/site-packages/torch/lib/../../../../libcudart.so.10.2
#11 0x00007fff8615c47c in rasterizeRender(int, RasterizeGLState&, CUstream_st*, float const*, int, int, int const*, int, int const*, int, int, int, int) ()
   from /local/home/user/.cache/torch_extensions/nvdiffrast_plugin/nvdiffrast_plugin.so
#12 0x00007fff861b4f92 in rasterize_fwd(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int) ()
   from /local/home/user/.cache/torch_extensions/nvdiffrast_plugin/nvdiffrast_plugin.so
#13 0x00007fff861a37d0 in std::tuple<at::Tensor, at::Tensor> pybind11::detail::argument_loader<RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int>::call_impl<std::tuple<at::Tensor, at::Tensor>, std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), 0ul, 1ul, 2ul, 3ul, 4ul, 5ul, pybind11::detail::void_type>(std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul, 4ul, 5ul>, pybind11::detail::void_type&&) () from /local/home/user/.cache/torch_extensions/nvdiffrast_plugin/nvdiffrast_plugin.so
#14 0x00007fff8619bb17 in std::enable_if<!std::is_void<std::tuple<at::Tensor, at::Tensor> >::value, std::tuple<at::Tensor, at::Tensor> >::type pybind11::detail::argument_loader<RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int>::call<std::tuple<at::Tensor, at::Tensor>, pybind11::detail::void_type, std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int)>(std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int)) && () from /local/home/user/.cache/torch_extensions/nvdiffrast_plugin/nvdiffrast_plugin.so
#15 0x00007fff861911d0 in void pybind11::cpp_function::initialize<std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), std::tuple<at::Tensor, at::Tensor>, RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int, pybind11::name, pybind11::scope, pybind11::sibling, char [21]>(std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), std::tuple<at::Tensor, at::Tensor> (*)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [21])::{lambda(pybind11::detail::function_call&)#3}::operator()(pybind11::detail::function_call&) const () from /local/home/user/.cache/torch_extensions/nvdiffrast_plugin/nvdiffrast_plugin.so
#16 0x00007fff86191626 in void pybind11::cpp_function::initialize<std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), std::tuple<at::Tensor, at::Tensor>, RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int, pybind11::name, pybind11::scope, pybind11::sibling, char [21]>(std::tuple<at::Tensor, at::Tensor> (*&)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), std::tuple<at::Tensor, at::Tensor> (*)(RasterizeGLStateWrapper&, at::Tensor, at::Tensor, std::tuple<int, int>, at::Tensor, int), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [21])::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () from /local/home/user/.cache/torch_extensions/nvdiffrast_plugin/nvdiffrast_plugin.so
#17 0x00007fff8617cc0c in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /local/home/user/.cache/torch_extensions/nvdiffrast_plugin/nvdiffrast_plugin.so
#18 0x00005555556b9914 in _PyMethodDef_RawFastCallKeywords () at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:693
#19 0x00005555556b9a31 in _PyCFunction_FastCallKeywords (func=0x7fff8e6ddbe0, args=<optimized out>, nargs=<optimized out>, kwnames=<optimized out>)
    at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:732
#20 0x0000555555725ebd in call_function (kwnames=0x0, oparg=6, pp_stack=<synthetic pointer>) at /tmp/build/80754af9/python_1598874792229/work/Python/ceval.c:4568
#21 _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1598874792229/work/Python/ceval.c:3093
#22 0x000055555566985b in function_code_fastcall (globals=<optimized out>, nargs=8, args=<optimized out>, co=0x7fff8f8c18a0)
    at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:283
#23 _PyFunction_FastCallDict () at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:322
#24 0x00007fffdb8436f5 in THPFunction_apply(_object*, _object*) () from /local/home/user/miniconda3/envs/pt16/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#25 0x00005555556b9884 in _PyMethodDef_RawFastCallKeywords () at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:697
#26 0x00005555556b9a31 in _PyCFunction_FastCallKeywords (func=0x7fff8e9432d0, args=<optimized out>, nargs=<optimized out>, kwnames=<optimized out>)
    at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:732
#27 0x0000555555725ebd in call_function (kwnames=0x0, oparg=7, pp_stack=<synthetic pointer>) at /tmp/build/80754af9/python_1598874792229/work/Python/ceval.c:4568
#28 _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1598874792229/work/Python/ceval.c:3093
#29 0x0000555555668829 in _PyEval_EvalCodeWithName () at /tmp/build/80754af9/python_1598874792229/work/Python/ceval.c:3930
#30 0x00005555556b9107 in _PyFunction_FastCallKeywords () at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:433
#31 0x0000555555722585 in call_function (kwnames=0x7ffff761c9d0, oparg=<optimized out>, pp_stack=<synthetic pointer>) at /tmp/build/80754af9/python_1598874792229/work/Python/ceval.c:4616
#32 _PyEval_EvalFrameDefault () at /tmp/build/80754af9/python_1598874792229/work/Python/ceval.c:3139
#33 0x000055555566985b in function_code_fastcall (globals=<optimized out>, nargs=3, args=<optimized out>, co=0x7ffff7600150)
    at /tmp/build/80754af9/python_1598874792229/work/Objects/call.c:283

In my case, the crash happens in the forward pass of rasterize (specifically, when calling cuGraphicsMapResources). That function seems to call some internal GL initialization code, even though everything should already be initialized at that point. This made me think that the problem might be related to threading: perhaps the GL library is not thread-safe, or it contains an initialization routine that runs each time a new thread is created.
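The thread-related hypothesis can be illustrated in pure Python. Anything initialized per thread (tracked here with `threading.local`) is redone on every call when each call runs on a brand-new thread, as when nn.DataParallel starts a fresh thread per device for each forward call, but only once per device when threads are reused. This is a minimal sketch: `PerThreadInit.work` is a hypothetical stand-in for whatever per-thread setup the driver might perform, and it does not reproduce the EGL or CUDA behavior itself.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class PerThreadInit:
    """Counts how often a hypothetical per-thread setup step runs."""

    def __init__(self):
        self._local = threading.local()
        self._lock = threading.Lock()
        self.count = 0

    def work(self, device: int) -> int:
        if not getattr(self._local, "ready", False):
            with self._lock:
                self.count += 1  # stands in for expensive per-thread setup
            self._local.ready = True
        return device


def fresh_threads(devices, iterations) -> int:
    # Mimics starting one new thread per device on every call.
    state = PerThreadInit()
    for _ in range(iterations):
        threads = [threading.Thread(target=state.work, args=(d,)) for d in devices]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    return state.count


def persistent_threads(devices, iterations) -> int:
    # One long-lived worker thread per device, reused across calls.
    state = PerThreadInit()
    pools = {d: ThreadPoolExecutor(max_workers=1) for d in devices}
    try:
        for _ in range(iterations):
            futures = [pools[d].submit(state.work, d) for d in devices]
            for f in futures:
                f.result()
    finally:
        for pool in pools.values():
            pool.shutdown()
    return state.count
```

With two devices and three calls, the setup step runs six times with fresh threads but only twice with persistent ones.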

For verification, the following code runs just fine:

import torch
import nvdiffrast.torch as dr
from tqdm import tqdm

bs = 16   # batch size
nt = 512  # triangles per mesh
# Random clip-space positions and a trivial index buffer, one set per GPU.
vertices1 = torch.randn(bs, nt*3, 4).to('cuda:0')
faces1 = torch.arange(nt*3).view(-1, 3).int().to('cuda:0')
ctx1 = dr.RasterizeGLContext(output_db=False, device='cuda:0')

vertices2 = torch.randn(bs, nt*3, 4).to('cuda:1')
faces2 = torch.arange(nt*3).view(-1, 3).int().to('cuda:1')
ctx2 = dr.RasterizeGLContext(output_db=False, device='cuda:1')

# Both contexts are created and used from the main thread only.
with torch.no_grad():
    for i in tqdm(range(1000000)):
        rast_out1, _ = dr.rasterize(ctx1, vertices1, faces1, resolution=(1024, 1024))
        rast_out2, _ = dr.rasterize(ctx2, vertices2, faces2, resolution=(1024, 1024))
        torch.cuda.synchronize('cuda:0')
        torch.cuda.synchronize('cuda:1')

Here, everything is called from the main thread and GPU utilization is 100% on both GPUs, which is good (the calls are asynchronous). There are no memory leaks or crashes. However, being able to use nn.DataParallel would still be more convenient, because it fits most workflows and some operations in PyTorch are blocking.

In the end, I came up with the following workaround/hack: I use a rudimentary thread pool for the calls to dr.rasterize. Each GPU has a dedicated thread that is never destroyed; calls to dr.rasterize are dispatched to that thread, and their results are returned to the calling thread. This is probably not very efficient, but it works and I can still saturate the GPUs. A better option would be to modify nn.DataParallel to use a thread pool instead of spawning new threads each time forward is called.

Sharing the code for completeness, but don't rely too much on it (I haven't tested edge cases):

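import threading

import torch
import torch.nn as nn
import nvdiffrast.torch as dr
from tqdm import tqdm

class Dispatcher:
    """Runs callables on one long-lived thread per GPU, so that each GL
    context is created and used by the same thread for its whole lifetime."""

    def __init__(self, gpu_ids):
        self.threads = {}
        self.events = {}         # set by the caller when a job is ready
        self.funcs = {}          # pending job per device
        self.return_events = {}  # set by the worker when the result is ready
        self.return_values = {}

        for gpu_id in gpu_ids:
            device = torch.device(gpu_id)  # an integer ordinal maps to a CUDA device
            self.events[device] = threading.Event()
            self.return_events[device] = threading.Event()
            self.threads[device] = threading.Thread(target=self.worker, args=(device,), daemon=True)
            self.threads[device].start()

    def worker(self, device):
        # The context is created inside the dedicated thread and never leaves it.
        ctx = dr.RasterizeGLContext(output_db=False, device=device)
        while True:
            self.events[device].wait()
            assert device not in self.return_values
            self.return_values[device] = self.funcs[device](ctx)
            del self.funcs[device]
            self.events[device].clear()
            self.return_events[device].set()

    def __call__(self, device, func):
        # Hand func to the device's worker thread and block until it returns.
        # Note: if func raises, the worker thread dies and this call blocks forever.
        assert device not in self.funcs
        self.funcs[device] = func
        self.events[device].set()
        self.return_events[device].wait()
        ret_val = self.return_values[device]
        del self.return_values[device]
        self.return_events[device].clear()
        return ret_val

gpu_ids = [0, 1]
dispatcher = Dispatcher(gpu_ids)

class Rasterizer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, vertices, faces):
        # nn.DataParallel runs each replica on a short-lived thread; route the
        # actual rasterization to the persistent thread that owns this GPU's context.
        rast_out, _ = dispatcher(vertices.device, lambda ctx: dr.rasterize(ctx, vertices, faces, resolution=(256, 256)))
        return rast_out

rasterizer = nn.DataParallel(Rasterizer(), gpu_ids)
bs = 1024*2
nt = 512
vertices = torch.randn(bs, nt*3, 4).cuda()
faces = torch.arange(nt*3).view(-1, 3)
# nn.DataParallel splits every input along dim 0, so repeat the index buffer
# once per GPU to give each replica the full set of triangles.
faces_rep = faces.repeat(len(gpu_ids), 1).int().cuda()
with torch.no_grad():
    for i in tqdm(range(100000)):
        rast_out = rasterizer(vertices, faces_rep)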

In the example above, the contexts are created in the dedicated threads, but creating them in the main thread works fine as well. It looks like the actual initialization happens during the first call to dr.rasterize.
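For comparison, the same one-thread-per-device idea can be sketched with the standard library's `concurrent.futures.ThreadPoolExecutor`, using one single-worker executor per device. This is an untested alternative under stated assumptions, not the code above and not an nvdiffrast API: `DeviceThreads` is a made-up name, and `make_ctx` is a hypothetical factory (in real use, something like `lambda dev: dr.RasterizeGLContext(output_db=False, device=dev)`). A side benefit is that an exception raised in the worker is re-raised in the caller by `Future.result()`, instead of killing the worker thread and leaving the caller blocked on its event forever.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

class DeviceThreads:
    """Hypothetical helper: one long-lived worker thread per device.

    make_ctx(device) runs once, inside that device's worker thread, so the
    context is created and used by the same thread for its whole lifetime.
    """

    def __init__(self, devices, make_ctx):
        self._ctx = {}
        self._pools = {
            dev: ThreadPoolExecutor(
                max_workers=1,                  # at most one thread per device
                initializer=self._init_worker,  # runs in that thread before any job
                initargs=(dev, make_ctx),
            )
            for dev in devices
        }

    def _init_worker(self, dev, make_ctx):
        self._ctx[dev] = make_ctx(dev)

    def __call__(self, dev, func):
        # Run func(ctx) on the device's thread and block until it finishes.
        # Exceptions raised by func are re-raised here by result().
        return self._pools[dev].submit(lambda: func(self._ctx[dev])).result()

    def shutdown(self):
        for pool in self._pools.values():
            pool.shutdown(wait=True)

# GPU-free demo: the fake "context" records which thread created it.
runner = DeviceThreads(["cuda:0", "cuda:1"], make_ctx=lambda dev: (dev, threading.get_ident()))
ctx, tid = runner("cuda:0", lambda ctx: (ctx, threading.get_ident()))
print(ctx[1] == tid)  # True: created and used on the same thread
runner.shutdown()
```

In the nn.Module above, the call would become something like `runner(vertices.device, lambda ctx: dr.rasterize(ctx, vertices, faces, resolution=(256, 256)))`, with the executors keyed by the same `torch.device` objects that `vertices.device` returns. Like the event-based version, this still serializes one call per GPU at a time.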

Hope this helps, and if the authors have some insight into this, I would like to hear your opinion!

s-laine commented 3 years ago

Hi @dariopavllo, big thanks for posting your analysis and insights here! I hadn't examined the internals of nn.DataParallel before but you're absolutely right: it spawns new threads every time a parallelized operation is invoked (source).

This means that the OpenGL contexts need to be constantly migrated between threads. This appears to be an expensive and error-prone operation, leading to low GPU utilization and memory leaks. My guess is that some driver-level buffers or data structures of the context, probably related to CUDA-OpenGL interoperability, are thread-specific and aren't deallocated until the context is destroyed entirely. This would lead to the observed accumulation of crud. There is a function called eglReleaseThread(), but I cannot see how it could help with the performance issues, so even if it fixed the memory leaks, it wouldn't be an acceptable solution.
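The suspected accumulation can be modeled in a few lines. This is a toy model of the guess above, not how any driver actually works: `ToyContext` and `use_from_fresh_threads` are hypothetical names. The context allocates a small record the first time each distinct thread uses it and frees those records only in `destroy()`. Because the records belong to the context rather than to the thread, they outlive the short-lived threads, so the count grows with every thread that ever touched the context.

```python
import threading

class ToyContext:
    """Toy model of a context that keeps per-thread state until destroyed.

    Hypothetical: real drivers expose nothing like this.
    """

    def __init__(self):
        self._per_thread = {}  # Thread object -> pretend driver buffer
        self._lock = threading.Lock()

    def use(self):
        # The first use from a thread allocates state owned by the context,
        # so it is NOT released when that thread exits.
        me = threading.current_thread()
        with self._lock:
            if me not in self._per_thread:
                self._per_thread[me] = bytearray(1024)

    def live_buffers(self):
        with self._lock:
            return len(self._per_thread)

    def destroy(self):
        # Only destroying the whole context releases the accumulated state.
        with self._lock:
            self._per_thread.clear()

def use_from_fresh_threads(ctx, n_calls):
    # One short-lived thread per call, like a wrapper that spawns threads per forward.
    for _ in range(n_calls):
        t = threading.Thread(target=ctx.use)
        t.start()
        t.join()

ctx = ToyContext()
use_from_fresh_threads(ctx, 50)
print(ctx.live_buffers())  # 50
ctx.destroy()
print(ctx.live_buffers())  # 0
```

Records are keyed by the `Thread` object rather than by `threading.get_ident()`, because identifiers can be reused after a thread exits, which would hide the growth.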

It makes perfect sense that your workaround solves these issues, as each context is always used from a single, dedicated thread. This was the only usage pattern that I had in mind when developing and testing the code, as I didn't use Torch's nn.Module abstraction and hence no nn.DataParallel either. If we one day include nn.Module-compatible classes in nvdiffrast, we'll probably have to do something similar to this. Again, big thanks for your post!

s-laine commented 3 years ago

This is now addressed in the documentation since v0.2.6, with a link to this issue for details. Closing.