Closed: jovistos closed this issue 6 months ago
Full error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
[<ipython-input-5-388f70847b39>](https://localhost:8080/#) in <cell line: 24>()
22 )
23 model.config.use_cache = False # silence the warnings. Please re-enable for inference!
---> 24 trainer.train()
37 frames
[/usr/local/lib/python3.10/dist-packages/transformers/trainer.py](https://localhost:8080/#) in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1857 hf_hub_utils.enable_progress_bars()
1858 else:
-> 1859 return inner_training_loop(
1860 args=args,
1861 resume_from_checkpoint=resume_from_checkpoint,
[/usr/local/lib/python3.10/dist-packages/transformers/trainer.py](https://localhost:8080/#) in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2201
2202 with self.accelerator.accumulate(model):
-> 2203 tr_loss_step = self.training_step(model, inputs)
2204
2205 if (
[/usr/local/lib/python3.10/dist-packages/transformers/trainer.py](https://localhost:8080/#) in training_step(self, model, inputs)
3136
3137 with self.compute_loss_context_manager():
-> 3138 loss = self.compute_loss(model, inputs)
3139
3140 if self.args.n_gpu > 1:
[/usr/local/lib/python3.10/dist-packages/transformers/trainer.py](https://localhost:8080/#) in compute_loss(self, model, inputs, return_outputs)
3159 else:
3160 labels = None
-> 3161 outputs = model(**inputs)
3162 # Save past state if it exists
3163 # TODO: this needs to be fixed and made cleaner later.
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:
[/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py](https://localhost:8080/#) in forward(*args, **kwargs)
823
824 def forward(*args, **kwargs):
--> 825 return model_forward(*args, **kwargs)
826
827 # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
[/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
811
812 def __call__(self, *args, **kwargs):
--> 813 return convert_to_fp32(self.model_forward(*args, **kwargs))
814
815 def __getstate__(self):
[/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py](https://localhost:8080/#) in decorate_autocast(*args, **kwargs)
14 def decorate_autocast(*args, **kwargs):
15 with autocast_instance:
---> 16 return func(*args, **kwargs)
17
18 decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined]
[/usr/local/lib/python3.10/dist-packages/peft/peft_model.py](https://localhost:8080/#) in forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1302 with self._enable_peft_forward_hooks(**kwargs):
1303 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1304 return self.base_model(
1305 input_ids=input_ids,
1306 attention_mask=attention_mask,
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:
[/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py](https://localhost:8080/#) in forward(self, *args, **kwargs)
177
178 def forward(self, *args: Any, **kwargs: Any):
--> 179 return self.model.forward(*args, **kwargs)
180
181 def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:
[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py](https://localhost:8080/#) in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)
1357
1358 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1359 outputs = self.model(
1360 input_ids=input_ids,
1361 attention_mask=attention_mask,
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:
[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py](https://localhost:8080/#) in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)
1214
1215 if self.gradient_checkpointing and self.training:
-> 1216 layer_outputs = self._gradient_checkpointing_func(
1217 decoder_layer.__call__,
1218 hidden_states,
[/usr/local/lib/python3.10/dist-packages/torch/_compile.py](https://localhost:8080/#) in inner(*args, **kwargs)
22 import torch._dynamo
23
---> 24 return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
25
26 return inner
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
487 dynamo_config_ctx.__enter__()
488 try:
--> 489 return fn(*args, **kwargs)
490 finally:
491 set_eval_frame(prior)
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py](https://localhost:8080/#) in inner(*args, **kwargs)
15 @functools.wraps(fn)
16 def inner(*args, **kwargs):
---> 17 return fn(*args, **kwargs)
18
19 return inner
[/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py](https://localhost:8080/#) in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
480 "use_reentrant=False."
481 )
--> 482 return CheckpointFunction.apply(function, preserve, *args)
483 else:
484 gen = _checkpoint_without_reentrant_generator(
[/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py](https://localhost:8080/#) in apply(cls, *args, **kwargs)
551 # See NOTE: [functorch vjp and autograd interaction]
552 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553 return super().apply(*args, **kwargs) # type: ignore[misc]
554
555 if not is_setup_ctx_defined:
[/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py](https://localhost:8080/#) in forward(ctx, run_function, preserve_rng_state, *args)
259
260 with torch.no_grad():
--> 261 outputs = run_function(*args)
262 return outputs
263
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:
[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py](https://localhost:8080/#) in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, **kwargs)
943 residual = hidden_states
944 hidden_states = self.post_attention_layernorm(hidden_states)
--> 945 hidden_states, router_logits = self.block_sparse_moe(hidden_states)
946 hidden_states = residual + hidden_states
947
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:
[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py](https://localhost:8080/#) in forward(self, hidden_states)
873 # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
874 current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
--> 875 current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
876
877 # However `index_add_` only support torch tensors for indexing so we'll use
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:
[/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py](https://localhost:8080/#) in forward(self, hidden_states)
801
802 def forward(self, hidden_states):
--> 803 current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
804 current_hidden_states = self.w2(current_hidden_states)
805 return current_hidden_states
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
1512
1513 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1521
1522 try:
[/usr/local/lib/python3.10/dist-packages/aqlm/inference.py](https://localhost:8080/#) in forward(self, input)
71
72 if self.use_gemv_rule(input):
---> 73 return self.gemv_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)
74 else:
75 return self.gemm_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)
[/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py](https://localhost:8080/#) in apply(cls, *args, **kwargs)
551 # See NOTE: [functorch vjp and autograd interaction]
552 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553 return super().apply(*args, **kwargs) # type: ignore[misc]
554
555 if not is_setup_ctx_defined:
[/usr/local/lib/python3.10/dist-packages/aqlm/inference.py](https://localhost:8080/#) in forward(ctx, input, codes, codebooks, scales, bias)
114 bias,
115 )
--> 116 return forward_pass_kernel(
117 input,
118 codes,
[/usr/local/lib/python3.10/dist-packages/torch/_ops.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
753 # We save the function ptr as the `op` attribute on
754 # OpOverloadPacket to access it here.
--> 755 return self._op(*args, **(kwargs or {}))
756
757 # TODO: use this to make a __dir__
RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
I'm almost certain the error is raised here:
For some reason gradient checkpointing ends up passing a tensor with 0 elements through the model, and this particular reshape doesn't handle that case. The fix should be straightforward; I'll try to fix it soon.
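For reference, a minimal sketch of the failure mode and one possible guard (the variable and layer names are illustrative, and this is not necessarily the exact change that landed in aqlm):

```python
import torch
import torch.nn as nn

hidden_dim = 16

# With Mixtral's top-2 routing, an expert can receive zero tokens in a step,
# so its input is an empty (0, hidden_dim) tensor.
empty_input = torch.empty(0, hidden_dim)

# The ambiguous reshape: with 0 elements, the -1 dimension cannot be inferred.
try:
    empty_input.reshape(0, -1)
except RuntimeError as e:
    print(e)  # cannot reshape tensor of 0 elements into shape [0, -1] ...

# One possible guard: skip the quantized kernel entirely for empty inputs
# (illustrative only; assumes the layer exposes out_features like nn.Linear).
def forward_with_empty_guard(layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
    if x.numel() == 0:
        return x.new_empty(0, layer.out_features)
    return layer(x)
```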
Should be fixed in aqlm==1.1.5
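If you are hitting this in the Colab notebook, upgrading aqlm in the install cell and restarting the runtime should pick up the fix (assuming the notebook installs aqlm from PyPI):

```python
# Colab cell; restart the runtime after upgrading.
!pip install -U "aqlm>=1.1.5"
```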
Fine-tune Colab example fails when running trainer.train()
The last cell gives the following output:
max_steps is given, it will override any value given in num_train_epochs
/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
RuntimeError Traceback (most recent call last)
37 frames
/usr/local/lib/python3.10/dist-packages/torch/_ops.py in __call__(self, *args, **kwargs)
753 # We save the function ptr as the `op` attribute on
754 # OpOverloadPacket to access it here.
--> 755 return self._op(*args, **(kwargs or {}))
756
757 # TODO: use this to make a __dir__
RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
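The use_reentrant warnings above are unrelated to the crash itself; they can be silenced by choosing a checkpointing variant explicitly. A minimal sketch, assuming a transformers version recent enough to accept gradient_checkpointing_kwargs (the other arguments are placeholders):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",  # placeholder
    gradient_checkpointing=True,
    # Pass use_reentrant explicitly so torch.utils.checkpoint stops warning.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```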