Closed hsvgbkhgbv closed 11 months ago

Dear authors,

Thank you for your impressive work. When I try to run SMAX on GPU, it reports a memory error which seems to be induced by the huge burden of buffer storage. Have you tested how much GPU memory we need to run the environment?
Could you provide information about your system and which environment/algorithm you were trying to run?
Hi,
Thank you for your response.
Actually, I am now trying to implement my own algorithm and planning to contribute it to this repo. I have run the implementations of VDN and QMIX, and everything runs normally.
My algorithm has passed testing on MPE, and no errors were raised.
However, when I run my algorithm on SMAX with the 2s3z map, the following error is raised:
Error executing job with overrides: ['+alg=shaq_smax', '+env=smax', 'ENTITY=wangjianhong1993', 'env.MAP_NAME=2s3z']
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Buffer 2:
Size: 3.22GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/broadcast_in_dim[shape=(10, 1, 99, 1, 1, 10, 1, 5, 32, 1, 545) broadcast_dimensions=(0, 1, 2, 3, 4, 6, 8, 9, 10)]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=171
XLA Label: fusion
Shape: f32[10,1,99,1,1,10,1,5,32,1,545]
==========================
Buffer 3:
Size: 1.51GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_1/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=66
XLA Label: custom-call
Shape: f32[10,158400,256]
==========================
Buffer 4:
Size: 1.51GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_0/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=66
XLA Label: custom-call
Shape: f32[10,158400,256]
==========================
Buffer 5:
Size: 1.27GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
XLA Label: fusion
Shape: f32[10,3000,100,114]
==========================
Buffer 6:
Size: 1.27GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
XLA Label: fusion
Shape: f32[10,3000,100,114]
==========================
Buffer 7:
Size: 1.27GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
XLA Label: fusion
Shape: f32[10,3000,100,114]
==========================
Buffer 8:
Size: 1.27GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
XLA Label: fusion
Shape: f32[10,3000,100,114]
==========================
Buffer 9:
Size: 1.27GiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jit(add)/while/body/scatter[dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1, 2), inserted_window_dims=(1,), scatter_dims_to_operand_dims=(1,)) indices_are_sorted=True unique_indices=True mode=GatherScatterMode.FILL_OR_DROP update_jaxpr=None update_consts=()]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/buffers/utils.py" source_line=42 deduplicated_name="fusion.765"
XLA Label: fusion
Shape: f32[10,3000,100,114]
==========================
Buffer 10:
Size: 773.44MiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_0/Dense_1/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=68
XLA Label: custom-call
Shape: f32[10,158400,128]
==========================
Buffer 11:
Size: 386.72MiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_2/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=66
XLA Label: custom-call
Shape: f32[10,158400,64]
==========================
Buffer 12:
Size: 386.72MiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/dot_general[dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1))) precision=None preferred_element_type=float32]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=190
XLA Label: fusion
Shape: f32[1,10,158400,64]
==========================
Buffer 13:
Size: 386.72MiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/HyperNetwork_1/Dense_1/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=68
XLA Label: custom-call
Shape: f32[10,158400,64]
==========================
Buffer 14:
Size: 386.72MiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(SHAQMixer)/AlphaEstimate_0/Dense_0/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=178
XLA Label: custom-call
Shape: f32[10,158400,64]
==========================
Buffer 15:
Size: 156.25MiB
Operator: op_name="jit(train)/jit(main)/vmap(while)/body/jvp(AgentRNN)/while[cond_nconsts=0 body_nconsts=13]" source_file="/users/jianhong/JaxMARL/baselines/QLearning/shaq.py" source_line=51
XLA Label: fusion
Shape: f32[10,100,160,256]
==========================
The GPU I used to run the program is an NVIDIA Volta V100 with 32 GB of memory.
With 2s3z the observation vectors can be large, and if I'm not wrong, your maximum number of steps is 100, your buffer size is 3000, and you are vmapping the entire algorithm over 10 seeds.
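As a quick sanity check, the repeated 1.27 GiB buffers in your dump are exactly the product of those settings times the per-step feature dimension (114 here):

```python
# Quick check of one of the repeated 1.27 GiB buffers above:
# f32[10, 3000, 100, 114] = seeds x buffer size x max steps x feature dim
num_bytes = 10 * 3000 * 100 * 114 * 4   # float32 = 4 bytes per element
print(f"{num_bytes / 2**30:.2f} GiB")   # -> 1.27 GiB, matching the dump
```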
Try reducing the number of vmapped seeds or the buffer size, and if that doesn't work, also reduce the maximum number of steps.
We were able to run experiments on 2s3z with 2 vmapped seeds and otherwise the same parameters on a GPU with 12 GB of RAM.
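For concreteness, here is a minimal sketch of the kind of changes we mean. The config keys (NUM_SEEDS, BUFFER_SIZE, MAX_STEPS) and make_train are illustrative stand-ins, not the exact names used in shaq.py:

```python
import jax

# Hypothetical config; adapt the key names to whatever your script uses.
config = {
    "NUM_SEEDS": 2,       # was 10: every buffer in the dump scales linearly with this
    "BUFFER_SIZE": 1500,  # was 3000: trajectories kept in the replay buffer
    "MAX_STEPS": 100,     # reduce this only if the two changes above are not enough
}

def make_train(config):
    # Stand-in for the real training-function factory in your script.
    def train(rng):
        return rng  # the real train would run the full algorithm here
    return train

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, config["NUM_SEEDS"])

# vmapping the whole training function replicates all training state
# (replay buffer included) once per seed, so peak memory grows roughly
# linearly with the number of vmapped seeds.
train_vjit = jax.jit(jax.vmap(make_train(config)))
outs = jax.block_until_ready(train_vjit(rngs))
```

You can also set XLA_PYTHON_CLIENT_PREALLOCATE=false (or lower XLA_PYTHON_CLIENT_MEM_FRACTION) to watch actual usage instead of letting JAX preallocate most of the GPU; that helps with diagnosis, though it doesn't reduce the peak memory the run needs.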
> With 2s3z the observation vectors can be large, and if I'm not wrong, your maximum number of steps is 100, your buffer size is 3000, and you are vmapping the entire algorithm over 10 seeds.
Yes, you are right.
I have addressed the issue by tuning the hyperparameters of my own algorithm. Thank you for your tips.
By the way, may I ask about the detailed procedure for merging my implementation into the repo, e.g., where should I introduce my algorithm?
Which algorithm are you implementing?
The algorithm in the following repo: https://github.com/hsvgbkhgbv/shapley-q-learning.