Vahe1994 / AQLM

Official PyTorch repository for "Extreme Compression of Large Language Models via Additive Quantization" (https://arxiv.org/pdf/2401.06118.pdf) and "PV-Tuning: Beyond Straight-Through Estimation for Extreme LLM Compression" (https://arxiv.org/abs/2405.14852)
Apache License 2.0

Fine-tune colab example doesn't work #75

Closed (jovistos closed this issue 6 months ago)

jovistos commented 7 months ago

The fine-tuning Colab example fails when running trainer.train().

The last cell gives this output:

max_steps is given, it will override any value given in num_train_epochs
/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(


RuntimeError                              Traceback (most recent call last)

in <cell line: 24>()
     22 )
     23 model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
---> 24 trainer.train()

37 frames

/usr/local/lib/python3.10/dist-packages/torch/_ops.py in __call__(self, *args, **kwargs)
    753         # We save the function ptr as the `op` attribute on
    754         # OpOverloadPacket to access it here.
--> 755         return self._op(*args, **(kwargs or {}))
    756 
    757     # TODO: use this to make a __dir__

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous

BlackSamorez commented 6 months ago

Full error:

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-5-388f70847b39> in <cell line: 24>()
     22 )
     23 model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
---> 24 trainer.train()

37 frames

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1857                 hf_hub_utils.enable_progress_bars()
   1858         else:
-> 1859             return inner_training_loop(
   1860                 args=args,
   1861                 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2201 
   2202                 with self.accelerator.accumulate(model):
-> 2203                     tr_loss_step = self.training_step(model, inputs)
   2204 
   2205                 if (

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in training_step(self, model, inputs)
   3136 
   3137         with self.compute_loss_context_manager():
-> 3138             loss = self.compute_loss(model, inputs)
   3139 
   3140         if self.args.n_gpu > 1:

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
   3159         else:
   3160             labels = None
-> 3161         outputs = model(**inputs)
   3162         # Save past state if it exists
   3163         # TODO: this needs to be fixed and made cleaner later.

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in forward(*args, **kwargs)
    823 
    824     def forward(*args, **kwargs):
--> 825         return model_forward(*args, **kwargs)
    826 
    827     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`

/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py in __call__(self, *args, **kwargs)
    811 
    812     def __call__(self, *args, **kwargs):
--> 813         return convert_to_fp32(self.model_forward(*args, **kwargs))
    814 
    815     def __getstate__(self):

/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py in decorate_autocast(*args, **kwargs)
     14     def decorate_autocast(*args, **kwargs):
     15         with autocast_instance:
---> 16             return func(*args, **kwargs)
     17 
     18     decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode"  # type: ignore[attr-defined]

/usr/local/lib/python3.10/dist-packages/peft/peft_model.py in forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1302             with self._enable_peft_forward_hooks(**kwargs):
   1303                 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1304                 return self.base_model(
   1305                     input_ids=input_ids,
   1306                     attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py in forward(self, *args, **kwargs)
    177 
    178     def forward(self, *args: Any, **kwargs: Any):
--> 179         return self.model.forward(*args, **kwargs)
    180 
    181     def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:

/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)
   1357 
   1358         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1359         outputs = self.model(
   1360             input_ids=input_ids,
   1361             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)
   1214 
   1215             if self.gradient_checkpointing and self.training:
-> 1216                 layer_outputs = self._gradient_checkpointing_func(
   1217                     decoder_layer.__call__,
   1218                     hidden_states,

/usr/local/lib/python3.10/dist-packages/torch/_compile.py in inner(*args, **kwargs)
     22             import torch._dynamo
     23 
---> 24             return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
     25 
     26         return inner

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
    487                 dynamo_config_ctx.__enter__()
    488             try:
--> 489                 return fn(*args, **kwargs)
    490             finally:
    491                 set_eval_frame(prior)

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py in inner(*args, **kwargs)
     15     @functools.wraps(fn)
     16     def inner(*args, **kwargs):
---> 17         return fn(*args, **kwargs)
     18 
     19     return inner

/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
    480                 "use_reentrant=False."
    481             )
--> 482         return CheckpointFunction.apply(function, preserve, *args)
    483     else:
    484         gen = _checkpoint_without_reentrant_generator(

/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py in apply(cls, *args, **kwargs)
    551             # See NOTE: [functorch vjp and autograd interaction]
    552             args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553             return super().apply(*args, **kwargs)  # type: ignore[misc]
    554 
    555         if not is_setup_ctx_defined:

/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py in forward(ctx, run_function, preserve_rng_state, *args)
    259 
    260         with torch.no_grad():
--> 261             outputs = run_function(*args)
    262         return outputs
    263 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py in forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, **kwargs)
    943         residual = hidden_states
    944         hidden_states = self.post_attention_layernorm(hidden_states)
--> 945         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
    946         hidden_states = residual + hidden_states
    947 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py in forward(self, hidden_states)
    873             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
    874             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
--> 875             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
    876 
    877             # However `index_add_` only support torch tensors for indexing so we'll use

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py in forward(self, hidden_states)
    801 
    802     def forward(self, hidden_states):
--> 803         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
    804         current_hidden_states = self.w2(current_hidden_states)
    805         return current_hidden_states

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

/usr/local/lib/python3.10/dist-packages/aqlm/inference.py in forward(self, input)
     71 
     72         if self.use_gemv_rule(input):
---> 73             return self.gemv_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)
     74         else:
     75             return self.gemm_op.apply(input, self.codes, self.codebooks, self.scales, self.bias)

/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py in apply(cls, *args, **kwargs)
    551             # See NOTE: [functorch vjp and autograd interaction]
    552             args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553             return super().apply(*args, **kwargs)  # type: ignore[misc]
    554 
    555         if not is_setup_ctx_defined:

/usr/local/lib/python3.10/dist-packages/aqlm/inference.py in forward(ctx, input, codes, codebooks, scales, bias)
    114                 bias,
    115             )
--> 116             return forward_pass_kernel(
    117                 input,
    118                 codes,

/usr/local/lib/python3.10/dist-packages/torch/_ops.py in __call__(self, *args, **kwargs)
    753         # We save the function ptr as the `op` attribute on
    754         # OpOverloadPacket to access it here.
--> 755         return self._op(*args, **(kwargs or {}))
    756 
    757     # TODO: use this to make a __dir__

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous

I'm almost certain the error is raised here:

For some reason, checkpointing tries to pass a zero-size tensor through the model, and this particular reshape doesn't handle it. The fix should be straightforward; I'll try to get to it soon.
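
To illustrate why a reshape to [0, -1] cannot succeed on an empty tensor, here is a minimal pure-Python sketch of the shape-inference rule that the error message describes. The helper `infer_shape` and the feature size 4096 are hypothetical and exist only for this illustration; this is not the PyTorch or AQLM code.

```python
from math import prod


def infer_shape(numel, shape):
    """Resolve at most one -1 in `shape` for a tensor holding `numel` elements.

    Hypothetical helper written for illustration; not PyTorch or AQLM code.
    """
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in shape if d != -1)  # product of the explicit sizes
    if -1 not in shape:
        if known != numel:
            raise ValueError(f"shape {shape} is invalid for {numel} elements")
        return shape
    if known == 0:
        # 0 * x == 0 holds for every x, so the missing size has no unique value.
        raise ValueError(f"cannot infer -1 in {shape}: any size would fit")
    if numel % known != 0:
        raise ValueError(f"shape {shape} is invalid for {numel} elements")
    return [numel // known if d == -1 else d for d in shape]


print(infer_shape(12, [3, -1]))   # usual case: -1 resolves to 4
print(infer_shape(0, [0, 4096]))  # explicit sizes still work for an empty batch
try:
    infer_shape(0, [0, -1])       # the shape from the error message
except ValueError as err:
    print("ambiguous:", err)
```

Under this rule, an empty input would need either an explicit size in place of -1 or an early return before the reshape. Which approach the actual fix takes is not stated in this thread.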

BlackSamorez commented 6 months ago

This should be fixed in aqlm==1.1.5.
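
For anyone unsure whether their environment already has that release, the following is a small sketch that compares the installed version against 1.1.5 using only the standard library. The helper names (`parse_release`, `has_fix`, `installed_aqlm_version`) are made up for this example, and the parsing is deliberately simplified: it looks only at the first three numeric components and ignores pre-release or post-release suffixes.

```python
from importlib.metadata import PackageNotFoundError, version
from itertools import takewhile

FIXED_IN = (1, 1, 5)  # release named in the comment above


def parse_release(text):
    """Turn '1.1.5' into (1, 1, 5); simplified, suffixes like 'rc1' are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = "".join(takewhile(str.isdigit, piece))  # leading digits only
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_fix(installed):
    """True if the version string is at least the release that carries the fix."""
    return parse_release(installed) >= FIXED_IN


def installed_aqlm_version():
    """Return the installed aqlm version string, or None if it is not installed."""
    try:
        return version("aqlm")
    except PackageNotFoundError:
        return None


found = installed_aqlm_version()
if found is None:
    print("aqlm is not installed in this environment")
else:
    print(f"aqlm {found}:", "has the fix" if has_fix(found) else "needs an upgrade")
```

In a Colab session, a reported version older than 1.1.5 would mean the package needs upgrading before `trainer.train()` is rerun.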