genmoai / mochi

The best OSS video generation models
Apache License 2.0

torch.OutOfMemoryError: CUDA out of memory #82

Open noskill opened 4 days ago

noskill commented 4 days ago

Hi! I have 4x RTX 3090, but generation fails with an out-of-memory error on 256x512 videos. Full stack trace:

$ time python3 ./demos/cli.py --model_dir weights/ --width=512 --height=256                            
running                                                                                                                                  
Launching with 4 GPUs. If you want to force single GPU mode use CUDA_VISIBLE_DEVICES=0.                                                  
Attention mode: sdpa                                                                                                                     
2024-11-18 07:11:28,684 INFO worker.py:1819 -- Started a local Ray instance.                                                             
(MultiGPUContext pid=1171333) Initializing rank 2/4                                                                                      
(MultiGPUContext pid=1171333) Timing init_process_group                                                                                  
(MultiGPUContext pid=1171345) Timing load_text_encoder                                                                                   
(MultiGPUContext pid=1171345) Timing load_dit                                                                                            
(MultiGPUContext pid=1171345) Initializing rank 4/4 [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(MultiGPUContext pid=1171345) Timing init_process_group [repeated 3x across cluster]                                                     
(MultiGPUContext pid=1171339) Timing load_text_encoder [repeated 3x across cluster]                                                      
(MultiGPUContext pid=1171345) Timing load_vae                                                                                            
(MultiGPUContext pid=1171336) Timing load_dit [repeated 3x across cluster]                                                               
(MultiGPUContext pid=1171345) Stage                   Time(s)    Percent                                                                 
(MultiGPUContext pid=1171345) init_process_group         1.25      2.80%                                                                 
(MultiGPUContext pid=1171345) load_text_encoder          9.70     21.74%                                                                 
(MultiGPUContext pid=1171345) load_dit                  30.95     69.37%                                                                 
(MultiGPUContext pid=1171345) load_vae                   2.72      6.09%                                                                 
(MultiGPUContext pid=1171336) Timing load_vae [repeated 3x across cluster]                                                               
(MultiGPUContext pid=1171339) Stage                   Time(s)    Percent [repeated 3x across cluster]                                    
(MultiGPUContext pid=1171339) init_process_group         1.33      2.99% [repeated 3x across cluster]                                    
(MultiGPUContext pid=1171339) load_text_encoder          9.47     21.21% [repeated 3x across cluster]                                    
(MultiGPUContext pid=1171339) load_dit                  30.95     69.35% [repeated 3x across cluster]                                    
(MultiGPUContext pid=1171339) load_vae                   2.88      6.45% [repeated 3x across cluster]                                    
(pid=1171336) Sampling 0: 100%|███████████████████████████████████████████████████████████████████████| 64.0/64.0 [10:37<00:00, 10.1s/it]
Traceback (most recent call last):             
  File "/home/imgen/projects/genmoai/./demos/cli.py", line 155, in <module>                                                                                                                                            
    generate_cli()                                                                                                                                                                                                     
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/click/core.py", line 1157, in __call__                                                                                                           
    return self.main(*args, **kwargs)                                                                                                                                                                                  
           ^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                  
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/click/core.py", line 1078, in main                                                                                                               
    rv = self.invoke(ctx)                                                                                                                                                                                              
         ^^^^^^^^^^^^^^^^                                                                                                                                                                                              
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/click/core.py", line 1434, in invoke                                                                                                             
    return ctx.invoke(self.callback, **ctx.params)                                                                                                                                                                     
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                     
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/click/core.py", line 783, in invoke                                                                                                              
    return __callback(*args, **kwargs)                                                                                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                 
  File "/home/imgen/projects/genmoai/./demos/cli.py", line 141, in generate_cli                                                                                                                                        
    output = generate_video(                                                                                                                                                                                           
             ^^^^^^^^^^^^^^^                                                                                                                                                                                           
  File "/home/imgen/projects/genmoai/./demos/cli.py", line 96, in generate_video                                                                                                                                       
    final_frames = pipeline(**args)                                                                                                                                                                                    
                   ^^^^^^^^^^^^^^^^                                                                                                                                                                                    
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/pipelines.py", line 544, in __call__                                                                                         
    return ray.get([ctx.run.remote(fn=sample, **kwargs, show_progress=i == 0) for i, ctx in enumerate(self.ctxs)])[                                                                                                    
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                     
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper                                                                                   
    return fn(*args, **kwargs)                                                                                                                                                                                         
           ^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                         
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper                                                                                          
    return func(*args, **kwargs)                                                                                                                                                                                       
           ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                       
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/ray/_private/worker.py", line 2753, in get                                                                                                       
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)                                                                                                                                     
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                     
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/ray/_private/worker.py", line 904, in get_objects                                                                                                
    raise value.as_instanceof_cause()                                                                                                                                                                                  
ray.exceptions.RayTaskError(OutOfMemoryError): ray::MultiGPUContext.run() (pid=1171345, ip=192.168.1.66, actor_id=c0ae50f4dcb700a6726dfbc301000000, repr=<genmo.mochi_preview.pipelines.MultiGPUContext object at 0x7f5c3ffe67d0>)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                        
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                             
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/pipelines.py", line 499, in run                                                                                              
    return fn(self, **kwargs)                                                                                                                                                                                          
           ^^^^^^^^^^^^^^^^^^                                                                                                                                                                                          
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/pipelines.py", line 541, in sample                                                                                           
    frames = decode_latents(ctx.decoder, latents)                                                                                                                                                                      
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                      
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context                                        
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^                     
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/vae/models.py", line 1015, in decode_latents                                                                                 
    samples = decoder(z)                             
              ^^^^^^^^^^                             
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                                    
    return self._call_impl(*args, **kwargs)                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                                            
    return forward_call(*args, **kwargs)                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                   
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/vae/models.py", line 598, in forward                                                                                         
    x = block(x)                                     
        ^^^^^^^^                                     
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                                    
    return self._call_impl(*args, **kwargs)                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                                            
    return forward_call(*args, **kwargs)                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                   
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/container.py", line 250, in forward                                                                                             
    input = module(input)                            
            ^^^^^^^^^^^^^                            
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                                    
    return self._call_impl(*args, **kwargs)                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                                            
    return forward_call(*args, **kwargs)                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                   
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/vae/models.py", line 292, in forward                                                                                         
    x = self.stack(x)                                
        ^^^^^^^^^^^^^                                
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                                    
    return self._call_impl(*args, **kwargs)                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                                            
    return forward_call(*args, **kwargs)                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                   
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/container.py", line 250, in forward                                                                                             
    input = module(input)                            
            ^^^^^^^^^^^^^                            
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl                                                                                    
    return self._call_impl(*args, **kwargs)                                                                
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                                                            
    return forward_call(*args, **kwargs)                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                   
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/vae/models.py", line 159, in forward                                                                                         
    return super().forward(x)                        
           ^^^^^^^^^^^^^^^^^^                        
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/genmo/mochi_preview/vae/models.py", line 77, in forward                                                                                          
    return super(SafeConv3d, self).forward(input)                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                          
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 725, in forward                                                                                                  
    return self._conv_forward(input, self.weight, self.bias)                                               
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                               
  File "/home/imgen/miniconda3/envs/py32/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 709, in _conv_forward                                                                                            
    return F.conv3d(                                 
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.31 GiB. GPU 0 has a total capacity of 23.68 GiB of which 419.06 MiB is free. Including non-PyTorch memory, this process has 23.27 GiB memory in use. Of the allocated memory 21.88 GiB is allocated by PyTorch, and 955.94 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
(pid=1171336) Sampling 0: 100%|██████████| 64.0/64.0 [10:54<00:00, 10.2s/it]                               

real    11m53.248s                                   
user    0m39.424s                                    
sys     0m26.755s                                                             
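The figures in the OOM message can be cross-checked with a few lines of arithmetic. The sketch below uses a hypothetical helper (`parse_oom` is not part of mochi or PyTorch); its regexes are written against the exact message text in the traceback above and may not match the wording of other PyTorch versions. It shows that GPU 0's free memory (419.06 MiB) plus PyTorch's reserved-but-unallocated pool (955.94 MiB) totals 1375 MiB, about 1.34 GiB, only slightly above the 1.31 GiB request. So even if the `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` setting suggested in the message removed all fragmentation, this single allocation would leave only about 34 MiB spare on a card that is already more than 98% full.

```python
import re

# OOM message copied from the traceback above (trailing advice text omitted).
OOM_MESSAGE = (
    "CUDA out of memory. Tried to allocate 1.31 GiB. GPU 0 has a total capacity of "
    "23.68 GiB of which 419.06 MiB is free. Including non-PyTorch memory, this process "
    "has 23.27 GiB memory in use. Of the allocated memory 21.88 GiB is allocated by "
    "PyTorch, and 955.94 MiB is reserved by PyTorch but unallocated."
)

# Conversion factors to GiB for the two units that appear in the message.
_TO_GIB = {"MiB": 1 / 1024, "GiB": 1.0}


def parse_oom(message: str) -> dict[str, float]:
    """Extract the key figures (in GiB) from an OOM message shaped like the one above.

    Hypothetical helper: the patterns assume this exact wording.
    """
    patterns = {
        "requested": r"Tried to allocate ([\d.]+) (MiB|GiB)",
        "total": r"total capacity of ([\d.]+) (MiB|GiB)",
        "free": r"of which ([\d.]+) (MiB|GiB) is free",
        "in_use": r"this process has ([\d.]+) (MiB|GiB) memory in use",
        "allocated": r"([\d.]+) (MiB|GiB) is allocated by PyTorch",
        "reserved_unallocated": r"([\d.]+) (MiB|GiB) is reserved by PyTorch but unallocated",
    }
    figures = {}
    for key, pattern in patterns.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"could not find {key!r} in the message")
        value, unit = match.groups()
        figures[key] = float(value) * _TO_GIB[unit]
    return figures


figures = parse_oom(OOM_MESSAGE)
share = figures["in_use"] / figures["total"]
# Best case: every free byte plus the whole reserved-but-unallocated pool is usable.
best_case = figures["free"] + figures["reserved_unallocated"]
headroom_mib = (best_case - figures["requested"]) * 1024

print(f"in use:    {figures['in_use']:.2f} of {figures['total']:.2f} GiB ({share:.1%})")
print(f"requested: {figures['requested']:.2f} GiB")
print(f"best case: {best_case:.2f} GiB")
print(f"headroom:  {headroom_mib:.0f} MiB")
```

For the message above it prints 98.3% in use, a 1.31 GiB request, a best case of 1.34 GiB and a headroom of 34 MiB.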

nvidia-smi

Mon Nov 18 10:22:35 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:19:00.0 Off |                  N/A |
| 41%   69C    P2            237W /  370W |   21033MiB /  24576MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        Off |   00000000:1A:00.0 Off |                  N/A |
| 75%   67C    P2            254W /  370W |   21032MiB /  24576MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 3090        Off |   00000000:67:00.0 Off |                  N/A |
| 72%   65C    P2            244W /  370W |   21032MiB /  24576MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GeForce RTX 3090        Off |   00000000:68:00.0 Off |                  N/A |
| 74%   66C    P2            254W /  370W |   21032MiB /  24576MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   1201659      C   ray::MultiGPUContext.run                    21024MiB |
|    1   N/A  N/A   1201658      C   ray::MultiGPUContext.run                    21024MiB |
|    2   N/A  N/A   1201660      C   ray::MultiGPUContext.run                    21024MiB |
|    3   N/A  N/A   1201667      C   ray::MultiGPUContext.run                    21024MiB |
+-----------------------------------------------------------------------------------------+
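The nvidia-smi snapshot comes from a separate run (different PIDs, taken at 10:22 rather than 07:11) and shows each of the four `ray::MultiGPUContext.run` workers holding about 21 GiB, which leaves roughly 3.5 GiB per card. The hypothetical parser below (`gpu_headroom` is illustrative only, and its regex assumes the text layout of the table pasted above) makes that per-GPU headroom explicit. For scripted monitoring, `nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv` is a more robust way to get the same numbers than scraping the table.

```python
import re

# Per-GPU rows copied from the nvidia-smi table above (one row per RTX 3090).
SMI_ROWS = """
| 41%   69C    P2            237W /  370W |   21033MiB /  24576MiB |    100%      Default |
| 75%   67C    P2            254W /  370W |   21032MiB /  24576MiB |    100%      Default |
| 72%   65C    P2            244W /  370W |   21032MiB /  24576MiB |    100%      Default |
| 74%   66C    P2            254W /  370W |   21032MiB /  24576MiB |    100%      Default |
"""

# Matches the "<used>MiB / <total>MiB" memory column; the power column uses W and is skipped.
MEM_PATTERN = re.compile(r"(\d+)MiB\s*/\s*(\d+)MiB")


def gpu_headroom(table: str) -> list[tuple[int, int, int]]:
    """Return (used_mib, total_mib, free_mib) for every memory cell found in the table text."""
    rows = []
    for used, total in MEM_PATTERN.findall(table):
        used_mib, total_mib = int(used), int(total)
        rows.append((used_mib, total_mib, total_mib - used_mib))
    return rows


for index, (used, total, free) in enumerate(gpu_headroom(SMI_ROWS)):
    print(f"GPU {index}: {used} / {total} MiB used, {free} MiB ({free / 1024:.2f} GiB) free")
```

For this snapshot it reports 3543 MiB free on GPU 0 and 3544 MiB on GPUs 1 to 3, about 3.46 GiB each.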
konkura commented 3 days ago

I have the same problem with 6x RTX 4090 when running with the default settings.