triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Issue while using Triton to fine-tune a 4-bit model on multiple GPUs #2404

Open 01miaom opened 1 year ago

01miaom commented 1 year ago

Hi, I encountered an issue while using Triton for LoRA fine-tuning of mpt-storywriter-4bit. The problem occurs when the program reaches the following line of code:

self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)

The error message I'm getting is:

ValueError: Pointer argument (at 1) cannot be accessed from Triton (cpu tensor?)

This issue occurs only when fine-tuning on multiple GPUs with a model compressed using the GPTQ algorithm. Fine-tuning the same compressed model on a single GPU works without any problems, and I've also fine-tuned an uncompressed 8-bit model on multiple GPUs without hitting this error.
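
For context, this ValueError is what Triton raises when a kernel receives a pointer argument backed by a tensor that is not on the active CUDA device. A minimal sketch that reproduces the same class of failure is below; the toy kernel is purely illustrative and is not the mpttune kernel from the traceback:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Trivial element-wise copy, just to have a kernel with pointer arguments.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

src = torch.arange(16, dtype=torch.float32)                    # CPU tensor (the problem)
dst = torch.empty(16, dtype=torch.float32, device="cuda")      # CUDA tensor
# Launching with a CPU tensor as a pointer argument raises a ValueError like
# "Pointer argument (...) cannot be accessed from Triton (cpu tensor?)".
copy_kernel[(1,)](src, dst, 16, BLOCK=16)
```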

Environment

Python: 3.9.18
Triton: 2.0.0
PyTorch: 2.0.1+cu118
GPU: A800

Error Traceback

 /home/miao/anaconda3/envs/mpt3/bin/mpttune:33 in <module>                     │
│                                                                              │
│   30                                                                         │
│   31 if __name__ == '__main__':                                              │
│   32 │   sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0])   │
│ ❱ 33 │   sys.exit(load_entry_point('mpttune==0.1.0', 'console_scripts', 'mpt │
│   34                                                                         │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/run.py:87 in main                                              │
│                                                                              │
│   84                                                                         │
│   85 def main():                                                             │
│   86 │   args = get_args()                                                   │
│ ❱ 87 │   args.func(args)                                                     │
│   88                                                                         │
│   89                                                                         │
│   90 if __name__ == '__main__':                                              │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/finetune.py:162 in finetune                                    │
│                                                                              │
│   159 │   │   │   │   set_peft_model_state_dict(model, state_dict_peft)      │
│   160 │   │   │   │   trainer.train(tune_config.resume_checkpoint)           │
│   161 │   │   │   else:                                                      │
│ ❱ 162 │   │   │   │   trainer.train()                                        │
│   163 │   │                                                                  │
│   164 │   │   # Restore old model state dict                                 │
│   165 │   │   model.state_dict = old_state_dict                              │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/transformers/train │
│ er.py:1664 in train                                                          │
│                                                                              │
│   1661 │   │   inner_training_loop = find_executable_batch_size(             │
│   1662 │   │   │   self._inner_training_loop, self._train_batch_size, args.a │
│   1663 │   │   )                                                             │
│ ❱ 1664 │   │   return inner_training_loop(                                   │
│   1665 │   │   │   args=args,                                                │
│   1666 │   │   │   resume_from_checkpoint=resume_from_checkpoint,            │
│   1667 │   │   │   trial=trial,                                              │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/transformers/train │
│ er.py:1940 in _inner_training_loop                                           │
│                                                                              │
│   1937 │   │   │   │   │   with model.no_sync():                             │
│   1938 │   │   │   │   │   │   tr_loss_step = self.training_step(model, inpu │
│   1939 │   │   │   │   else:                                                 │
│ ❱ 1940 │   │   │   │   │   tr_loss_step = self.training_step(model, inputs)  │
│   1941 │   │   │   │                                                         │
│   1942 │   │   │   │   if (                                                  │
│   1943 │   │   │   │   │   args.logging_nan_inf_filter                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/transformers/train │
│ er.py:2735 in training_step                                                  │
│                                                                              │
│   2732 │   │   │   return loss_mb.reduce_mean().detach().to(self.args.device │
│   2733 │   │                                                                 │
│   2734 │   │   with self.compute_loss_context_manager():                     │
│ ❱ 2735 │   │   │   loss = self.compute_loss(model, inputs)                   │
│   2736 │   │                                                                 │
│   2737 │   │   if self.args.n_gpu > 1:                                       │
│   2738 │   │   │   loss = loss.mean()  # mean() to average on multi-gpu para │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/transformers/train │
│ er.py:2767 in compute_loss                                                   │
│                                                                              │
│   2764 │   │   │   labels = inputs.pop("labels")                             │
│   2765 │   │   else:                                                         │
│   2766 │   │   │   labels = None                                             │
│ ❱ 2767 │   │   outputs = model(**inputs)                                     │
│   2768 │   │   # Save past state if it exists                                │
│   2769 │   │   # TODO: this needs to be fixed and made cleaner later.        │
│   2770 │   │   if self.args.past_index >= 0:                                 │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/nn/modules/m │
│ odule.py:1501 in _call_impl                                                  │
│                                                                              │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                      │
│   1502 │   │   # Do not call functions when jit is used                      │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1504 │   │   backward_pre_hooks = []                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/peft/peft_model.py │
│ :678 in forward                                                              │
│                                                                              │
│    675 │   ):                                                                │
│    676 │   │   peft_config = self.active_peft_config                         │
│    677 │   │   if not isinstance(peft_config, PromptLearningConfig):         │
│ ❱  678 │   │   │   return self.base_model(                                   │
│    679 │   │   │   │   input_ids=input_ids,                                  │
│    680 │   │   │   │   attention_mask=attention_mask,                        │
│    681 │   │   │   │   inputs_embeds=inputs_embeds,                          │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/nn/modules/m │
│ odule.py:1501 in _call_impl                                                  │
│                                                                              │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                      │
│   1502 │   │   # Do not call functions when jit is used                      │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1504 │   │   backward_pre_hooks = []                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/accelerate/hooks.p │
│ y:165 in new_forward                                                         │
│                                                                              │
│   162 │   │   │   with torch.no_grad():                                      │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                  │
│   164 │   │   else:                                                          │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                      │
│   166 │   │   return module._hf_hook.post_forward(module, output)            │
│   167 │                                                                      │
│   168 │   module.forward = new_forward                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/model/mpt/model.py:864 in forward                              │
│                                                                              │
│    861 │   │   return_dict = return_dict if return_dict is not None else sel │
│    862 │   │                                                                 │
│    863 │   │   # decoder outputs consists of (dec_features, layer_state, dec │
│ ❱  864 │   │   outputs = self.transformer(                                   │
│    865 │   │   │   input_ids=input_ids,                                      │
│    866 │   │   │   attention_mask=attention_mask,                            │
│    867 │   │   │   position_ids=position_ids,                                │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/nn/modules/m │
│ odule.py:1501 in _call_impl                                                  │
│                                                                              │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                      │
│   1502 │   │   # Do not call functions when jit is used                      │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1504 │   │   backward_pre_hooks = []                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/model/mpt/model.py:772 in forward                              │
│                                                                              │
│    769 │   │   │   │   │   None,                                             │
│    770 │   │   │   │   )                                                     │
│    771 │   │   │   else:                                                     │
│ ❱  772 │   │   │   │   layer_outputs = decoder_layer(                        │
│    773 │   │   │   │   │   hidden_states,                                    │
│    774 │   │   │   │   │   attention_mask=attention_mask,                    │
│    775 │   │   │   │   │   attn_bias=attn_bias,                              │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/nn/modules/m │
│ odule.py:1501 in _call_impl                                                  │
│                                                                              │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                      │
│   1502 │   │   # Do not call functions when jit is used                      │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1504 │   │   backward_pre_hooks = []                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/accelerate/hooks.p │
│ y:165 in new_forward                                                         │
│                                                                              │
│   162 │   │   │   with torch.no_grad():                                      │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                  │
│   164 │   │   else:                                                          │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                      │
│   166 │   │   return module._hf_hook.post_forward(module, output)            │
│   167 │                                                                      │
│   168 │   module.forward = new_forward                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/model/mpt/model.py:443 in forward                              │
│                                                                              │
│    440 │   │                                                                 │
│    441 │   │   a = self.norm_1(hidden_states)                                │
│    442 │   │                                                                 │
│ ❱  443 │   │   (b, self_attn_weights, present_key_value) = self.attn(        │
│    444 │   │   │   hidden_states=a,                                          │
│    445 │   │   │   attention_mask=attention_mask,                            │
│    446 │   │   │   attn_bias=attn_bias,                                      │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/nn/modules/m │
│ odule.py:1501 in _call_impl                                                  │
│                                                                              │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                      │
│   1502 │   │   # Do not call functions when jit is used                      │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1504 │   │   backward_pre_hooks = []                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/accelerate/hooks.p │
│ y:165 in new_forward                                                         │
│                                                                              │
│   162 │   │   │   with torch.no_grad():                                      │
│   163 │   │   │   │   output = old_forward(*args, **kwargs)                  │
│   164 │   │   else:                                                          │
│ ❱ 165 │   │   │   output = old_forward(*args, **kwargs)                      │
│   166 │   │   return module._hf_hook.post_forward(module, output)            │
│   167 │                                                                      │
│   168 │   module.forward = new_forward                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/model/mpt/model.py:373 in forward                              │
│                                                                              │
│    370 │   │   │   use_cache: bool = False,                                  │
│    371 │   ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[t │
│    372 │   │                                                                 │
│ ❱  373 │   │   qkv = self.Wqkv(hidden_states)                                │
│    374 │   │                                                                 │
│    375 │   │   if self.clip_qkv:                                             │
│    376 │   │   │   qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)         │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/nn/modules/m │
│ odule.py:1501 in _call_impl                                                  │
│                                                                              │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                      │
│   1502 │   │   # Do not call functions when jit is used                      │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1504 │   │   backward_pre_hooks = []                                       │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/model/lora.py:59 in forward                                    │
│                                                                              │
│    56 │   │   self.active_adapter = adapter_name                             │
│    57 │                                                                      │
│    58 │   def forward(self, x: torch.Tensor):                                │
│ ❱  59 │   │   result = self.quant_instance.forward(x)                        │
│    60 │   │                                                                  │
│    61 │   │   if self.disable_adapters or self.active_adapter not in self.lo │
│    62 │   │   │   return result                                              │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/backend/triton/quantlinear.py:13 in forward                    │
│                                                                              │
│   10 │                                                                       │
│   11 │   def forward(self, x):                                               │
│   12 │   │   if torch.is_grad_enabled():                                     │
│ ❱ 13 │   │   │   out = AutogradMatmul.apply(                                 │
│   14 │   │   │   │   x, self.qweight, self.scales,                           │
│   15 │   │   │   │   self.qzeros, self.g_idx, self.bits, self.maxq)          │
│   16 │   │   else:                                                           │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/autograd/fun │
│ ction.py:506 in apply                                                        │
│                                                                              │
│   503 │   │   if not torch._C._are_functorch_transforms_active():            │
│   504 │   │   │   # See NOTE: [functorch vjp and autograd interaction]       │
│   505 │   │   │   args = _functorch.utils.unwrap_dead_wrappers(args)         │
│ ❱ 506 │   │   │   return super().apply(*args, **kwargs)  # type: ignore[misc │
│   507 │   │                                                                  │
│   508 │   │   if cls.setup_context == _SingleLevelFunction.setup_context:    │
│   509 │   │   │   raise RuntimeError(                                        │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/torch/cuda/amp/aut │
│ ocast_mode.py:104 in decorate_fwd                                            │
│                                                                              │
│   101 │   │   │   args[0]._fwd_used_autocast = False                         │
│   102 │   │   │   if autocast_context:                                       │
│   103 │   │   │   │   with autocast(enabled=False):                          │
│ ❱ 104 │   │   │   │   │   return fwd(*_cast(args, cast_inputs), **_cast(kwar │
│   105 │   │   │   else:                                                      │
│   106 │   │   │   │   return fwd(*args, **kwargs)                            │
│   107 │   return decorate_fwd                                                │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/backend/triton/autograd.py:11 in forward                       │
│                                                                              │
│    8 │   @staticmethod                                                       │
│    9 │   @custom_fwd(cast_inputs=torch.float16)                              │
│   10 │   def forward(ctx, x, qweight, scales, qzeros, g_idx, bits, maxq):    │
│ ❱ 11 │   │   output = tu.triton_matmul(x, qweight, scales, qzeros, g_idx, bi │
│   12 │   │   ctx.save_for_backward(qweight, scales, qzeros, g_idx)           │
│   13 │   │   ctx.bits, ctx.maxq = bits, maxq                                 │
│   14 │   │   output = output.clone()                                         │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/backend/triton/triton_utils.py:246 in triton_matmul            │
│                                                                              │
│   243 │   output = torch.empty((input.shape[0], qweight.shape[1]), device=sc │
│   244 │   grid = lambda META: (                                              │
│   245 │   triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qw │
│ ❱ 246 │   matmul_248_kernel[grid](input, qweight, output,                    │
│   247 │   │   │   │   │   │   │   scales, qzeros, g_idx,                     │
│   248 │   │   │   │   │   │   │   input.shape[0], qweight.shape[1], input.sh │
│   249 │   │   │   │   │   │   │   input.stride(0), input.stride(1),          │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/backend/triton/custom_autotune.py:98 in run                    │
│                                                                              │
│    95 │   │   │   │   # prune configs                                        │
│    96 │   │   │   │   pruned_configs = self.prune_configs(kwargs)            │
│    97 │   │   │   │   bench_start = time.time()                              │
│ ❱  98 │   │   │   │   timings = {config: self._bench(*args, config=config, * │
│    99 │   │   │   │   │   │      for config in pruned_configs}               │
│   100 │   │   │   │   bench_end = time.time()                                │
│   101 │   │   │   │   self.bench_time = bench_end - bench_start              │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/backend/triton/custom_autotune.py:98 in <dictcomp>             │
│                                                                              │
│    95 │   │   │   │   # prune configs                                        │
│    96 │   │   │   │   pruned_configs = self.prune_configs(kwargs)            │
│    97 │   │   │   │   bench_start = time.time()                              │
│ ❱  98 │   │   │   │   timings = {config: self._bench(*args, config=config, * │
│    99 │   │   │   │   │   │      for config in pruned_configs}               │
│   100 │   │   │   │   bench_end = time.time()                                │
│   101 │   │   │   │   self.bench_time = bench_end - bench_start              │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/backend/triton/custom_autotune.py:80 in _bench                 │
│                                                                              │
│    77 │   │   │   │   bench_kwargs = {"quantiles": None}                     │
│    78 │   │   │   else:                                                      │
│    79 │   │   │   │   bench_kwargs = {"percentiles": None}                   │
│ ❱  80 │   │   │   return triton.testing.do_bench(kernel_call, rep=40, **benc │
│    81 │   │   except triton.compiler.OutOfResources:                         │
│    82 │   │   │   return float('inf')                                        │
│    83                                                                        │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/triton/testing.py: │
│ 143 in do_bench                                                              │
│                                                                              │
│   140 │   """                                                                │
│   141 │                                                                      │
│   142 │   # Estimate the runtime of the function                             │
│ ❱ 143 │   fn()                                                               │
│   144 │   torch.cuda.synchronize()                                           │
│   145 │   start_event = torch.cuda.Event(enable_timing=True)                 │
│   146 │   end_event = torch.cuda.Event(enable_timing=True)                   │
│                                                                              │
│ /home/miao/anaconda3/envs/mpt3/lib/python3.9/site-packages/mpttune-0.1.0-py3. │
│ 9.egg/mpttune/backend/triton/custom_autotune.py:71 in kernel_call            │
│                                                                              │
│    68 │   │   │   │   config.pre_hook(self.nargs)                            │
│    69 │   │   │   self.hook(args)                                            │
│    70 │   │   │   #print(args)                                               │
│ ❱  71 │   │   │   self.fn.run(*args, num_warps=config.num_warps,  num_stages │
│    72 │   │                                                                  │
│    73 │   │   try:                                                           │
│    74 │   │   │   # In testings using only 40 reps seems to be close enough  │
│ in matmul_248_kernel:43                                                      │
╰──────────────────────────────────────────────────────────────────────────────╯
ValueError: Pointer argument (at 1) cannot be accessed from Triton (cpu tensor?)
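
Based on the call path above (lora.py forward → triton/quantlinear.py → triton/autograd.py → triton_utils.triton_matmul → the autotuner's benchmark launch), one thing worth checking is whether the quantized weight tensors (qweight, scales, qzeros, g_idx) end up on a different device than the activation when the model is sharded across GPUs. A possible workaround to experiment with is sketched below; the helper and the attribute handling are assumptions for illustration, not mpttune's actual API:

```python
import torch

def align_quant_tensors(module, x):
    """Hypothetical helper: move the quantized weight buffers onto the
    activation's device before calling the Triton matmul."""
    dev = x.device
    for name in ("qweight", "scales", "qzeros", "g_idx"):
        t = getattr(module, name, None)
        if t is not None and t.device != dev:
            setattr(module, name, t.to(dev))
    # Triton launches kernels on the *current* CUDA device, so select it too.
    torch.cuda.set_device(dev)
    return x
```

This does not explain why the tensors were placed elsewhere in the first place (e.g. an accelerate device_map sharding the model across the A800s), but it makes the failure mode explicit and may unblock the multi-GPU run.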
mahao18cm commented 3 months ago

Did you solve this problem?

ai-nikolai commented 1 month ago

@01miaom any updates on this one?