LargeWorldModel / LWM


out of memory error #37

Closed lucky2046 closed 7 months ago

lucky2046 commented 7 months ago

I ran bash scripts/run_vision_chat.sh with the --mesh_dim param removed. The model is LWM-Chat-32K-Jax, and I get an out-of-memory error. How can I solve it?

My card is an NVIDIA 2080 Super with 8 GB of memory.

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1708500656.672727   10871 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
I0221 15:30:57.202437 140383335174272 xla_bridge.py:513] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0221 15:30:57.202921 140383335174272 xla_bridge.py:513] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-02-21 15:36:18.340692: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.00GiB (rounded to 2147483648)requested by op 
2024-02-21 15:36:18.340908: W external/tsl/tsl/framework/bfc_allocator.cc:497] *________**********************************************************************_____________________
2024-02-21 15:36:18.340944: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2644] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.00GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.00GiB
     preallocated temp allocation:         0B
                 total allocation:    3.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 2.00GiB
                Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
                XLA Label: fusion
                Shape: f32[32,4096,4096]
                ==========================

        Buffer 2:
                Size: 1.00GiB
                Entry Parameter Subshape: bf16[32,4096,4096]
                ==========================

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 254, in <module>
    run(main)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 249, in main
    sampler = Sampler()
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 51, in __init__
    self._load_model()
  File "/mnt/data/test/LWM/lwm/vision_chat.py", line 199, in _load_model
    self.params = tree_apply(shard_fns, self.params)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in tree_apply
    return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/jax/_src/tree_util.py", line 244, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/jax_utils.py", line 148, in <lambda>
    return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
  File "/home/test/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/distributed.py", line 95, in shard_fn
    return jax_shard_function(tensor).block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2147483648 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.00GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.00GiB
     preallocated temp allocation:         0B
                 total allocation:    3.00GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
        Buffer 1:
                Size: 2.00GiB
                Operator: op_name="pjit(to_dtype)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/mnt/data/test/LWM/lwm/vision_chat.py" source_line=199
                XLA Label: fusion
                Shape: f32[32,4096,4096]
                ==========================

        Buffer 2:
                Size: 1.00GiB
                Entry Parameter Subshape: bf16[32,4096,4096]
                ==========================

I0000 00:00:1708500978.900009   10871 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
(lwm) test@test-3:/mnt/data/test/LWM$ nvidia-smi
Wed Feb 21 15:47:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2080 S...    Off| 00000000:01:00.0 Off |                  N/A |
|  0%   40C    P0               23W / 250W|      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
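
For reference, the failed 2.00GiB allocation matches the peak buffer shape in the report exactly; a quick shell check (shape f32[32,4096,4096] and dtype sizes taken from the log above) confirms the arithmetic:

# 4 bytes per float32 element, 2 bytes per bfloat16 element
echo $((32 * 4096 * 4096 * 4))   # 2147483648 bytes = 2.00GiB, the allocation that failed (Buffer 1)
echo $((32 * 4096 * 4096 * 2))   # 1073741824 bytes = 1.00GiB, the bf16 entry parameter (Buffer 2)
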
jackyin68 commented 7 months ago

Can you share your modified requirements.txt with me?

lucky2046 commented 7 months ago

Can you share your modified requirements.txt with me?

I did not modify requirements.txt. I modified run_vision_chat.sh; here it is for your reference:

#!/bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

# MODEL_NAME='LWM-Chat-1M-Jax'
# MODEL_NAME='LWM-Chat-128K-Jax'
MODEL_NAME='LWM-Chat-32K-Jax'

export llama_tokenizer_path="/mnt/data/test/LWM/models/${MODEL_NAME}/tokenizer.model"
export vqgan_checkpoint="/mnt/data/test/LWM/models/${MODEL_NAME}/vqgan"
export lwm_checkpoint="/mnt/data/test/LWM/models/${MODEL_NAME}/params"
export input_file="/mnt/data/test/2020-07-30_pose_test_006.mp4"

python3 -u -m lwm.vision_chat \
    --prompt="What is the video about?" \
    --input_file="$input_file" \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --max_n_frames=8 \
    --update_llama_config="dict(sample_mode='text',theta=50000000,max_sequence_length=131072,use_flash_attention=False,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,remat_attention='',scan_mlp=False,scan_mlp_chunk_size=2048,remat_mlp='',remat_block='',scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
2>&1 | tee ~/output.log
read
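
One more thing I noticed in the log: the checkpoint parameters arrive as bf16, but the script asks for --dtype='fp32', so each tensor is upcast on the GPU while loading (the convert_element_type fusion in the OOM report). As an untested sketch on my side (I have not verified that --dtype accepts 'bf16', and JAX's allocator settings only change allocation behavior, not the total memory the model needs), one could try:

# Don't let JAX preallocate most of the GPU memory up front; allocate on demand instead
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# Or cap the fraction of GPU memory JAX may claim
# export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
# (Hypothetical) keep the params in bf16 instead of upcasting, if the flag accepts it:
# python3 -u -m lwm.vision_chat --dtype='bf16' ...   # plus the remaining flags from the script above
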
wilson1yan commented 7 months ago

I don't think your GPU has enough memory; by itself, a 7B model in fp32 is about 28 GB of parameters.
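
Rough back-of-the-envelope, assuming ~7e9 parameters at 4 bytes each for fp32 and 2 bytes for bf16:

echo $((7000000000 * 4))   # 28000000000 bytes ~ 28 GB just for the fp32 parameters
echo $((7000000000 * 2))   # 14000000000 bytes ~ 14 GB even in bf16

Either way that is well beyond an 8 GB card, so allocator tweaks alone won't make it fit; you would need a larger GPU, or multiple devices to shard the parameters across.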