astramind-ai / Mixture-of-depths

Unofficial implementation for the paper "Mixture-of-Depths: Dynamically allocating compute in transformer-based language models"

Training error #9

Open Mi5sssss opened 4 months ago

Mi5sssss commented 4 months ago

Would it be possible to provide an example training script? I encounter a tensor size mismatch when I increase the training batch size beyond 1, as follows (batch size 4):

```
Original Traceback (most recent call last):
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1164, in forward
    outputs = self.model(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 968, in forward
    layer_outputs = decoder_layer(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/new_mod/Mixture-of-depths/MoD/MoD.py", line 70, in forward
    block_output = self.block(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 713, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 649, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (1022) must match the existing size (511) at non-singleton dimension 3. Target sizes: [1, 32, 511, 1022]. Tensor sizes: [1, 1, 511, 511]
```
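For context, the `RuntimeError` above is the kind of failure you get when the attention mask keeps its original sequence length while the hidden states passed into the block have been reduced to the router-selected tokens. A minimal sketch of slicing a causal mask down to the selected tokens — `slice_mask_for_selected` is a hypothetical helper for illustration, not part of this repo:

```python
import torch

def slice_mask_for_selected(attn_mask, selected_idx):
    """Slice a [batch, 1, seq, seq] causal mask to the routed tokens.

    attn_mask:    [batch, 1, seq_len, seq_len] additive or boolean mask
    selected_idx: [batch, k] sorted indices of tokens routed through the block
    returns:      [batch, 1, k, k] mask matching the reduced sequence
    """
    b, _, s, _ = attn_mask.shape
    k = selected_idx.shape[1]
    # Gather the rows (query positions) of the selected tokens...
    rows = attn_mask.gather(2, selected_idx.view(b, 1, k, 1).expand(b, 1, k, s))
    # ...then the columns (key positions), so both dims match k.
    return rows.gather(3, selected_idx.view(b, 1, 1, k).expand(b, 1, k, k))
```

With a mask shaped like this, `scaled_dot_product_attention` receives query, key, and mask tensors of consistent length instead of the mismatched `[1, 1, 511, 511]` vs. `[1, 32, 511, 1022]` shapes in the traceback.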

If I keep the training batch size at 1, I get a tuple index error instead:

```
Original Traceback (most recent call last):
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 1164, in forward
    outputs = self.model(
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xier2/miniconda3/envs/mod/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 981, in forward
    next_decoder_cache = layer_outputs[2 if output_attentions else 1]
IndexError: tuple index out of range
```
JAVI897 commented 3 months ago

Have you been able to solve this? I'm encountering the same error.

pharaohcaptain commented 3 months ago

The "tuple index out of range" issue may arise because multiplying the embedding by the router's weights can produce values exceeding the representable range of float16, resulting in `inf`. A direct workaround is to normalize the weights:

```python
def forward(self, x):
    original_type = x.dtype
    # Run the router in float32 so the logits cannot overflow float16.
    self.weight_predictor.to(torch.float32)
    weights = self.weight_predictor(
        x.to(self.weight_predictor.weight.dtype)
    ).squeeze(-1)  # [batch_size, seq_len]
    # Normalize along the sequence dimension so the weights stay bounded.
    weights = weights / torch.sum(weights, dim=-1, keepdim=True)
    return weights.to(original_type)
```