huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
135.05k stars 27.02k forks source link

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm) #34695

Open ra-MANUJ-an opened 4 days ago

ra-MANUJ-an commented 4 days ago

Reproduction

I am trying to finetune Qwen2-0.5B model on some training data using a multi-GPU setup. The same code (given further below) seems to work in a single-GPU setting (when i set CUDA_VISIBLE_DEVICES=0):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 4
      2 import torch
      3 torch.autograd.set_detect_anomaly(True)
----> 4 main()

Cell In[14], line 15, in main()
      8 trainer = Trainer(env_params=env_params,
      9                   model_params=model_params,
     10                   optimizer_params=optimizer_params,
     11                   trainer_params=trainer_params)
     13 copy_all_src(trainer.result_folder)
---> 15 trainer.run()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:92, in TSPTrainer.run(self)
     89 self.scheduler.step()
     91 # Train
---> 92 train_score, train_loss = self._train_one_epoch(epoch)
     93 self.result_log.append('train_score', epoch, train_score)
     94 self.result_log.append('train_loss', epoch, train_loss)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:151, in TSPTrainer._train_one_epoch(self, epoch)
    148 remaining = train_num_episode - episode
    149 batch_size = min(self.trainer_params['train_batch_size'], remaining)
--> 151 avg_score, avg_loss = self._train_one_batch(batch_size)
    152 score_AM.update(avg_score, batch_size)
    153 loss_AM.update(avg_loss, batch_size)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:193, in TSPTrainer._train_one_batch(self, batch_size)
    191 state, reward, done = self.env.pre_step()
    192 while not done:
--> 193     selected, prob = self.model.module(state)
    194     # shape: (batch, pomo)
    195     state, reward, done = self.env.step(selected)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:39, in TSPTransformer.forward(self, state)
     37     return self._init_sequence(batch_size, pomo_size)
     38 else:
---> 39     return self._continue_sequence(state, batch_size, pomo_size)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:84, in TSPTransformer._continue_sequence(self, state, batch_size, pomo_size)
     81 state.ninf_mask = state.ninf_mask.to(self.device)
     83 # Get probabilities from decoder
---> 84 probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)
     86 # Select next node
     87 if self.training or self.model_params['eval_type'] == 'softmax':

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:185, in Decoder.forward(self, seq_so_far, inp_mask, ninf_mask)
    182 flat_mask = inp_mask.reshape(batch_size * pomo_size, problem_size)
    184 # Get model outputs
--> 185 outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
    186 logits = outputs.logits
    188 # Get last valid position

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/peft_model.py:1644, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1642     with self._enable_peft_forward_hooks(**kwargs):
   1643         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1644         return self.base_model(
   1645             input_ids=input_ids,
   1646             attention_mask=attention_mask,
   1647             inputs_embeds=inputs_embeds,
   1648             labels=labels,
   1649             output_attentions=output_attentions,
   1650             output_hidden_states=output_hidden_states,
   1651             return_dict=return_dict,
   1652             **kwargs,
   1653         )
   1655 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1656 if attention_mask is not None:
   1657     # concat prompt attention mask

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1170, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
   1167 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1169 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1170 outputs = self.model(
   1171     input_ids=input_ids,
   1172     attention_mask=attention_mask,
   1173     position_ids=position_ids,
   1174     past_key_values=past_key_values,
   1175     inputs_embeds=inputs_embeds,
   1176     use_cache=use_cache,
   1177     output_attentions=output_attentions,
   1178     output_hidden_states=output_hidden_states,
   1179     return_dict=return_dict,
   1180     cache_position=cache_position,
   1181 )
   1183 hidden_states = outputs[0]
   1184 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:901, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    889     layer_outputs = self._gradient_checkpointing_func(
    890         decoder_layer.__call__,
    891         hidden_states,
   (...)
    898         position_embeddings,
    899     )
    900 else:
--> 901     layer_outputs = decoder_layer(
    902         hidden_states,
    903         attention_mask=causal_mask,
    904         position_ids=position_ids,
    905         past_key_value=past_key_values,
    906         output_attentions=output_attentions,
    907         use_cache=use_cache,
    908         cache_position=cache_position,
    909         position_embeddings=position_embeddings,
    910     )
    912 hidden_states = layer_outputs[0]
    914 if use_cache:

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:629, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    626 hidden_states = self.input_layernorm(hidden_states)
    628 # Self Attention
--> 629 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    630     hidden_states=hidden_states,
    631     attention_mask=attention_mask,
    632     position_ids=position_ids,
    633     past_key_value=past_key_value,
    634     output_attentions=output_attentions,
    635     use_cache=use_cache,
    636     cache_position=cache_position,
    637     position_embeddings=position_embeddings,
    638 )
    639 hidden_states = residual + hidden_states
    641 # Fully Connected

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:506, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings)
    495     return super().forward(
    496         hidden_states=hidden_states,
    497         attention_mask=attention_mask,
   (...)
    501         use_cache=use_cache,
    502     )
    504 bsz, q_len, _ = hidden_states.size()
--> 506 query_states = self.q_proj(hidden_states)
    507 key_states = self.k_proj(hidden_states)
    508 value_states = self.v_proj(hidden_states)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/tuners/lora/layer.py:572, in Linear.forward(self, x, *args, **kwargs)
    570     result = self.base_layer(x, *args, **kwargs)
    571 else:
--> 572     result = self.base_layer(x, *args, **kwargs)
    573     torch_result_dtype = result.dtype
    574     for active_adapter in self.active_adapters:

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
    124 def forward(self, input: Tensor) -> Tensor:
--> 125     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

Code for the above error is given below:

Trainer.py

import torch
from logging import getLogger
from torch.nn.parallel import DataParallel
from TSPEnvQuant import TSPEnv as Env
from TSPTransformerModelQuant_b import TSPTransformer as Model

from torch.optim import Adam as Optimizer
from torch.optim.lr_scheduler import MultiStepLR as Scheduler

from utils.utils import *

class TSPTrainer:
    def __init__(self,
                 env_params,
                 model_params,
                 optimizer_params,
                 trainer_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.optimizer_params = optimizer_params
        self.trainer_params = trainer_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()
        self.result_log = LogData()

        # cuda
        USE_CUDA = self.trainer_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.trainer_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')

        # Main Components
        self.model = Model(**self.model_params)
        if USE_CUDA and torch.cuda.device_count() > 1:
            self.logger.info(f"Using {torch.cuda.device_count()} GPUs!")
            self.model = DataParallel(self.model)
        self.model = self.model.to(device)

        self.env = Env(**self.env_params)
        self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
        self.scheduler = Scheduler(self.optimizer, **self.optimizer_params['scheduler'])

        # Restore
        self.start_epoch = 1
        model_load = trainer_params['model_load']
        if model_load['enable']:
            checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
            checkpoint = torch.load(checkpoint_fullname, map_location=device)
            # Handle loading state dict for DataParallel
            if isinstance(self.model, DataParallel):
                # If saved model wasn't using DataParallel but current model is
                if not any(key.startswith('module.') for key in checkpoint['model_state_dict'].keys()):
                    new_state_dict = {'module.' + k: v for k, v in checkpoint['model_state_dict'].items()}
                    self.model.load_state_dict(new_state_dict)
                else:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If saved model was using DataParallel but current model isn't
                if any(key.startswith('module.') for key in checkpoint['model_state_dict'].keys()):
                    new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
                    self.model.load_state_dict(new_state_dict)
                else:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
            self.start_epoch = 1 + model_load['epoch']
            self.result_log.set_raw_data(checkpoint['result_log'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.last_epoch = model_load['epoch']-1
            self.logger.info('Saved Model Loaded !!')

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        self.time_estimator.reset(self.start_epoch)
        for epoch in range(self.start_epoch, self.trainer_params['epochs']+1):
            self.logger.info('=================================================================')

            # LR Decay
            self.scheduler.step()

            # Train
            train_score, train_loss = self._train_one_epoch(epoch)
            self.result_log.append('train_score', epoch, train_score)
            self.result_log.append('train_loss', epoch, train_loss)

            ############################
            # Logs & Checkpoint
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
            self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
                epoch, self.trainer_params['epochs'], elapsed_time_str, remain_time_str))

            all_done = (epoch == self.trainer_params['epochs'])
            model_save_interval = self.trainer_params['logging']['model_save_interval']
            img_save_interval = self.trainer_params['logging']['img_save_interval']

            if epoch > 1:  # save latest images, every epoch
                self.logger.info("Saving log_image")
                image_prefix = '{}/latest'.format(self.result_folder)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                    self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                    self.result_log, labels=['train_loss'])

            if all_done or (epoch % model_save_interval) == 0:
                self.logger.info("Saving trained_model")
                checkpoint_dict = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'result_log': self.result_log.get_raw_data()
                }
                torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))

            if all_done or (epoch % img_save_interval) == 0:
                image_prefix = '{}/img/checkpoint-{}'.format(self.result_folder, epoch)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                    self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                    self.result_log, labels=['train_loss'])

            if all_done:
                self.logger.info(" *** Training Done *** ")
                self.logger.info("Now, printing log array...")
                util_print_log_array(self.logger, self.result_log)

    def _train_one_epoch(self, epoch):

        score_AM = AverageMeter()
        loss_AM = AverageMeter()

        train_num_episode = self.trainer_params['train_episodes']
        episode = 0
        loop_cnt = 0
        while episode < train_num_episode:

            remaining = train_num_episode - episode
            batch_size = min(self.trainer_params['train_batch_size'], remaining)

            avg_score, avg_loss = self._train_one_batch(batch_size)
            score_AM.update(avg_score, batch_size)
            loss_AM.update(avg_loss, batch_size)

            episode += batch_size

            # Log First 10 Batch, only at the first epoch
            if epoch == self.start_epoch:
                loop_cnt += 1
                if loop_cnt <= 10:
                    self.logger.info('Epoch {:3d}: Train {:3d}/{:3d}({:1.1f}%)  Score: {:.4f},  Loss: {:.4f}'
                                     .format(epoch, episode, train_num_episode, 100. * episode / train_num_episode,
                                             score_AM.avg, loss_AM.avg))

        # Log Once, for each epoch
        self.logger.info('Epoch {:3d}: Train ({:3.0f}%)  Score: {:.4f},  Loss: {:.4f}'
                         .format(epoch, 100. * episode / train_num_episode,
                                 score_AM.avg, loss_AM.avg))

        return score_AM.avg, loss_AM.avg

    def _train_one_batch(self, batch_size):

        # Prep
        ###############################################
        self.model.train()
        self.env.load_problems(batch_size)
        reset_state, _, _ = self.env.reset()
        # Handle pre_forward for DataParallel
        if isinstance(self.model, DataParallel):
            print("Is DataParallel")
            self.model.module.pre_forward(reset_state)
        else:
            self.model.pre_forward(reset_state)

        prob_list = torch.zeros(size=(batch_size, self.env.pomo_size, 0))
        # shape: (batch, pomo, 0~problem)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, prob = self.model.module(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)
            prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

        # Loss
        ###############################################
        advantage = reward - reward.float().mean(dim=1, keepdims=True)
        # shape: (batch, pomo)
        log_prob = prob_list.log().sum(dim=2)
        # size = (batch, pomo)
        loss = -advantage * log_prob  # Minus Sign: To Increase REWARD
        # shape: (batch, pomo)
        loss_mean = loss.mean()

        # Score
        ###############################################
        max_pomo_reward, _ = reward.max(dim=1)  # get best results from pomo
        score_mean = -max_pomo_reward.float().mean()  # negative sign to make positive value

        # Step & Return
        ###############################################
        self.model.zero_grad()
        loss_mean.backward()
        self.optimizer.step()
        return score_mean.item(), loss_mean.item()
Model.py

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, TaskType
from typing import Optional, Dict, Any, Tuple

class TSPTransformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.model_params = kwargs
        self.encoder = Encoder(**kwargs)
        self.embedding_size = kwargs.get('embedding_dim', 896)

        # Load the model with LoRA and 4-bit quantization if needed
        self.model = load_model(kwargs)
        self.decoder = Decoder(self.model, **kwargs)

        # Initialize state storage
        self.encoded_nodes = None
        self.seq_so_far = None
        self.input_mask = None
        self.t = None
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def pre_forward(self, reset_state):
        """Initialize model state for new sequence"""
        self.encoded_nodes = self.encoder(reset_state.problems)
        self.problem_size = reset_state.problems.size(1)
        self.batch_size = reset_state.problems.size(0)

    def forward(self, state) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.current_node is None:
            return self._init_sequence(batch_size, pomo_size)
        else:
            return self._continue_sequence(state, batch_size, pomo_size)

    def _init_sequence(self, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize sequence state"""
        self.t = 0  # Start at 0 instead of -1

        # Create new tensors instead of modifying in place
        selected = torch.arange(pomo_size, device=self.device).expand(batch_size, pomo_size)
        prob = torch.ones(size=(batch_size, pomo_size), device=self.device)

        # Initialize sequence storage with proper dimensions
        self.seq_so_far = torch.zeros(
            (batch_size, pomo_size, self.problem_size, self.embedding_size),
            device=self.device
        )

        self.input_mask = torch.zeros(
            (batch_size, pomo_size, self.problem_size),
            dtype=torch.bool,
            device=self.device
        )

        return selected, prob

    def _continue_sequence(self, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Continue sequence generation"""
        # Get encoded representation of current node
        encoded_current = self._get_encoded_node(state.current_node)

        # Create new tensor for updated sequence
        new_seq = self.seq_so_far.clone()
        new_seq[:, :, self.t, :] = encoded_current
        self.seq_so_far = new_seq

        # Create new tensor for updated mask
        new_mask = self.input_mask.clone()
        new_mask[:, :, self.t] = True
        self.input_mask = new_mask

        # Move tensors to correct device
        self.seq_so_far = self.seq_so_far.to(self.device)
        self.input_mask = self.input_mask.to(self.device)
        state.ninf_mask = state.ninf_mask.to(self.device)

        # Get probabilities from decoder
        probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)

        # Select next node
        if self.training or self.model_params['eval_type'] == 'softmax':
            selected, prob = self._sample_node(probs, state, batch_size, pomo_size)
        else:
            selected = probs.argmax(dim=2)
            prob = None

        self.t += 1
        return selected, prob

    def _get_encoded_node(self, node_indices: torch.Tensor) -> torch.Tensor:
        """Get encoded representation of nodes safely"""
        batch_size, pomo_size = node_indices.shape
        embedding_dim = self.encoded_nodes.size(2)

        # Create gathering indices
        gather_idx = node_indices[:, :, None].expand(batch_size, pomo_size, embedding_dim)
        gather_idx = gather_idx.to(self.encoded_nodes.device)

        # Gather encoded representations
        return self.encoded_nodes.gather(dim=1, index=gather_idx)

    def _sample_node(self, probs: torch.Tensor, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample next node with retry logic"""
        max_attempts = 100
        for _ in range(max_attempts):
            # Reshape for sampling
            flat_probs = probs.reshape(batch_size * pomo_size, -1)

            # Sample indices
            selected = flat_probs.multinomial(1, replacement=True)
            selected = selected.reshape(batch_size, pomo_size)

            # Calculate probabilities
            prob = probs[state.BATCH_IDX, state.POMO_IDX, selected]
            prob = prob.reshape(batch_size, pomo_size)

            if (prob > 0).all():
                return selected, prob

        raise RuntimeError(f"Failed to sample valid nodes after {max_attempts} attempts")

class Encoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.embedding_dim = kwargs.get('embedding_dim', 896) - 1
        self.embed_layer = nn.Linear(2, self.embedding_dim)
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def forward(self, problems: torch.Tensor) -> torch.Tensor:
        batch_size, problem_size = problems.shape[:2]

        # Create position encodings
        ids = torch.arange(problem_size, device=self.device).expand(batch_size, problem_size)

        # Embed coordinates
        embedded = self.embed_layer(problems.reshape(-1, 2))
        embedded = embedded.reshape(batch_size, problem_size, self.embedding_dim)

        # Concatenate position encodings
        return torch.cat([ids.unsqueeze(-1).float(), embedded], dim=-1)

class Decoder(nn.Module):
    def __init__(self, model: nn.Module, **kwargs):
        super().__init__()
        self.model = model
        self.problem_size = kwargs.get('problem_size', 20)
        self.use_lora = kwargs.get('use_lora', True)

        self._setup_model()

    def _setup_model(self):
        """Configure model architecture"""
        # Modify output size
        self.model.lm_head = nn.Linear(
            self.model.config.hidden_size,
            self.problem_size
        ).to(self.model.device)

        # Apply LoRA if requested
        if self.use_lora:
            lora_config = LoraConfig(
                r=4,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.1,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            self.model = get_peft_model(self.model, lora_config)

    def forward(self, seq_so_far: torch.Tensor, inp_mask: torch.Tensor, ninf_mask: torch.Tensor) -> torch.Tensor:
        batch_size, pomo_size, problem_size, embedding_dim = seq_so_far.shape

        # Reshape inputs
        flat_seq = seq_so_far.reshape(batch_size * pomo_size, problem_size, embedding_dim)
        flat_mask = inp_mask.reshape(batch_size * pomo_size, problem_size)

        # Get model outputs
        outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
        logits = outputs.logits

        # Get last valid position
        last_positions = flat_mask.sum(dim=1).long() - 1

        # Gather logits for last positions
        batch_indices = torch.arange(batch_size * pomo_size, device=logits.device)
        gathered_logits = logits[batch_indices, last_positions]

        # Reshape and apply mask
        logits = gathered_logits.reshape(batch_size, pomo_size, problem_size)
        masked_logits = logits + ninf_mask.float()

        # Return probabilities
        return torch.softmax(masked_logits, dim=2)

def load_model(config: Dict[str, Any]) -> nn.Module:
    """Load model with proper configuration"""
    # print(config)
    device = config.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    if config.get('checkpoint_path'):
        # print('checkpoint_path')
        try:
            return PeftModel.from_pretrained(
                config['model_name'],
                config['checkpoint_path'],
                is_trainable=True
            ).to(device)
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Falling back to base model...")

    print(config)
    # print(config['use_4bit'])
    if config.get('use_4bit', True):
        print('use_4bit')
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            bnb_4bit_use_double_quant=True,
        )
        # print(config['model_name'])
        # print(type(config['model_name']))
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config
        )
        model = prepare_model_for_kbit_training(model)
        model.config.use_cache = False
    else:
        # print('else')
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            torch_dtype=torch.float32,
            trust_remote_code=True,
            device_map="auto",
        ).to(device)

    return model

Expected behavior Expected behavior is that the model should train in a multi-GPU setting without throwing any errors. The same script works in single-GPU setting but throws the above error in a multi-GPU setting

LysandreJik commented 1 day ago

cc @SunMarc and @MekkCyber