FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Memory error when running SMAX on GPU #45

Closed hsvgbkhgbv closed 11 months ago

hsvgbkhgbv commented 12 months ago

Dear author,

Thank you for this impressive work. When I try to run SMAX on GPU, it reports a memory error that seems to be caused by the large buffer storage. Have you tested how much GPU memory is needed to run the environment?

mttga commented 12 months ago

Could you provide information about your system and which environment/algorithm you were trying to run?

hsvgbkhgbv commented 12 months ago

Hi,

Thank you for your response.

Actually, I am currently implementing my own algorithm and planning to contribute it to this repo. I have run the VDN and QMIX implementations, and everything runs normally.

My algorithm passes the tests on MPE, and no errors were raised.

However, when I run my algorithm on SMAX with the 2s3z map, the following error is raised:

Error executing job with overrides: ['+alg=shaq_smax', '+env=smax', 'ENTITY=wangjianhong1993', 'env.MAP_NAME=2s3z']
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/users/jianhong/JaxMARL/baselines/QLearning/shaq.py", line 703, in main
    outs = jax.block_until_ready(train_vjit(rngs))
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 12267176800 bytes.

BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 80B
    constant allocation: 2.3KiB
    maybe_live_out allocation: 12.91GiB
    preallocated temp allocation: 11.42GiB
    preallocated temp fragmentation: 834.23MiB (7.13%)
    total allocation: 24.34GiB

Peak buffers:

Buffer 1:
    Size: 6.09GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42
    XLA Label: fusion
    Shape: f32[10,3000,100,545]
    ==========================

Buffer 2:
    Size: 3.22GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/broadcast_in_dim[shape=(10, 1, 99, 1, 1, 10, 1, 5, 32, 1, 545) broadcast_dimensions=(0, 1, 2, 3, 4, 6, 8, 9, 10)]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=171
    XLA Label: fusion
    Shape: f32[10,1,99,1,1,10,1,5,32,1,545]
    ==========================

Buffer 3:
    Size: 1.51GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_1/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=66
    XLA Label: custom-call
    Shape: f32[10,158400,256]
    ==========================

Buffer 4:
    Size: 1.51GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_0/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=66
    XLA Label: custom-call
    Shape: f32[10,158400,256]
    ==========================

Buffer 5:
    Size: 1.27GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
    XLA Label: fusion
    Shape: f32[10,3000,100,114]
    ==========================

Buffer 6:
    Size: 1.27GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
    XLA Label: fusion
    Shape: f32[10,3000,100,114]
    ==========================

Buffer 7:
    Size: 1.27GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
    XLA Label: fusion
    Shape: f32[10,3000,100,114]
    ==========================

Buffer 8:
    Size: 1.27GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
    XLA Label: fusion
    Shape: f32[10,3000,100,114]
    ==========================

Buffer 9:
    Size: 1.27GiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
    XLA Label: fusion
    Shape: f32[10,3000,100,114]
    ==========================

Buffer 10:
    Size: 773.44MiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_0/Dense_1/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=68
    XLA Label: custom-call
    Shape: f32[10,158400,128]
    ==========================

Buffer 11:
    Size: 386.72MiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_2/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=66
    XLA Label: custom-call
    Shape: f32[10,158400,64]
    ==========================

Buffer 12:
    Size: 386.72MiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/dot_general[dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1))) precision=None preferred_element_type=float32]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=190
    XLA Label: fusion
    Shape: f32[1,10,158400,64]
    ==========================

Buffer 13:
    Size: 386.72MiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_1/Dense_1/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=68
    XLA Label: custom-call
    Shape: f32[10,158400,64]
    ==========================

Buffer 14:
    Size: 386.72MiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=178
    XLA Label: custom-call
    Shape: f32[10,158400,64]
    ==========================

Buffer 15:
    Size: 156.25MiB
    Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(AgentRNN)/while[cond_nconsts=0 body_nconsts=13]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=51
    XLA Label: fusion
    Shape: f32[10,100,160,256]
    ==========================

The GPU I used to run the program is an NVIDIA Volta V100 with 32 GB of memory.

mttga commented 12 months ago

With 2s3z the observation vectors can be large, and if I'm not wrong your maximum number of steps is 100, your buffer size is 3000, and you are vmapping the entire algorithm over 10 seeds.
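
For reference, those numbers line up with Buffer 1 in the log (f32[10,3000,100,545]). A quick back-of-the-envelope check in Python; the variable names below are illustrative, not the actual config keys:

    # Sanity check of Buffer 1 (f32[10,3000,100,545]) from the OOM log.
    # Names are illustrative stand-ins for the corresponding config values.
    NUM_SEEDS = 10      # experiments vmapped in parallel
    BUFFER_SIZE = 3000  # replay buffer capacity
    MAX_STEPS = 100     # maximum number of steps per episode
    FEATURE_DIM = 545   # per-step feature width on 2s3z

    n_bytes = NUM_SEEDS * BUFFER_SIZE * MAX_STEPS * FEATURE_DIM * 4  # 4 bytes per float32
    print(f"{n_bytes / 2**30:.2f} GiB")  # -> 6.09 GiB, matching Buffer 1

Halving any one of those factors halves that buffer, which is why reducing the seeds, buffer size, or step count helps.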

Try reducing the number of vmapped experiments or the buffer size; if that doesn't work, also reduce the maximum number of steps.

We were able to run experiments on 2s3z with these same parameters and 2 vmapped seeds on a GPU with 12 GB of RAM.
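
Independently of the hyperparameters, it can also be worth checking JAX's GPU memory behavior. A minimal sketch using documented XLA environment variables (these are standard JAX flags, not JaxMARL-specific settings, and they must be set before jax is imported):

    import os

    # Option A: allocate GPU memory on demand instead of preallocating
    # a large fraction of the device up front (JAX preallocates by default).
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    # Option B (alternative): keep preallocation but cap the fraction
    # of device memory JAX may claim.
    # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

    import jax  # import only after the flags are set

These flags can reduce preallocation overhead and fragmentation, but they won't help if the computation genuinely needs more memory than the device has.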

hsvgbkhgbv commented 12 months ago

With 2s3z the observation vectors can be large, and if I'm not wrong your maximum number of steps is 100, your buffer size is 3000, and you are vmapping the entire algorithm over 10 seeds.

Yes, you are right.

I have addressed the issue by tuning the hyperparameters of my own algorithm. Thank you for the tips.

By the way, may I ask about the procedure for merging my implementation into the repo, e.g., where should I add my algorithm?

mttga commented 12 months ago

Which algorithm are you implementing?

hsvgbkhgbv commented 12 months ago

The algorithm in the following repo: https://github.com/hsvgbkhgbv/shapley-q-learning.