RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm) #34695
I am trying to finetune the Qwen2-0.5B model on some training data using a multi-GPU setup. The same code (given further below) works in a single-GPU setting (when I set CUDA_VISIBLE_DEVICES=0).
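For context, a minimal sketch of how I switch between the two settings (the runs are launched from a notebook; the only assumption below is that the variable is set before anything initialises CUDA):

import os

# single-GPU run (works):
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# multi-GPU run (fails): leave both GPUs visible, e.g.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In the multi-GPU setting the training loop then fails with the traceback below: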
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[18], line 4
2 import torch
3 torch.autograd.set_detect_anomaly(True)
----> 4 main()
Cell In[14], line 15, in main()
8 trainer = Trainer(env_params=env_params,
9 model_params=model_params,
10 optimizer_params=optimizer_params,
11 trainer_params=trainer_params)
13 copy_all_src(trainer.result_folder)
---> 15 trainer.run()
File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:92, in TSPTrainer.run(self)
89 self.scheduler.step()
91 # Train
---> 92 train_score, train_loss = self._train_one_epoch(epoch)
93 self.result_log.append('train_score', epoch, train_score)
94 self.result_log.append('train_loss', epoch, train_loss)
File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:151, in TSPTrainer._train_one_epoch(self, epoch)
148 remaining = train_num_episode - episode
149 batch_size = min(self.trainer_params['train_batch_size'], remaining)
--> 151 avg_score, avg_loss = self._train_one_batch(batch_size)
152 score_AM.update(avg_score, batch_size)
153 loss_AM.update(avg_loss, batch_size)
File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:193, in TSPTrainer._train_one_batch(self, batch_size)
191 state, reward, done = self.env.pre_step()
192 while not done:
--> 193 selected, prob = self.model.module(state)
194 # shape: (batch, pomo)
195 state, reward, done = self.env.step(selected)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:39, in TSPTransformer.forward(self, state)
37 return self._init_sequence(batch_size, pomo_size)
38 else:
---> 39 return self._continue_sequence(state, batch_size, pomo_size)
File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:84, in TSPTransformer._continue_sequence(self, state, batch_size, pomo_size)
81 state.ninf_mask = state.ninf_mask.to(self.device)
83 # Get probabilities from decoder
---> 84 probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)
86 # Select next node
87 if self.training or self.model_params['eval_type'] == 'softmax':
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:185, in Decoder.forward(self, seq_so_far, inp_mask, ninf_mask)
182 flat_mask = inp_mask.reshape(batch_size * pomo_size, problem_size)
184 # Get model outputs
--> 185 outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
186 logits = outputs.logits
188 # Get last valid position
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/second/lib/python3.10/site-packages/peft/peft_model.py:1644, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1642 with self._enable_peft_forward_hooks(**kwargs):
1643 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1644 return self.base_model(
1645 input_ids=input_ids,
1646 attention_mask=attention_mask,
1647 inputs_embeds=inputs_embeds,
1648 labels=labels,
1649 output_attentions=output_attentions,
1650 output_hidden_states=output_hidden_states,
1651 return_dict=return_dict,
1652 **kwargs,
1653 )
1655 batch_size = _get_batch_size(input_ids, inputs_embeds)
1656 if attention_mask is not None:
1657 # concat prompt attention mask
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/second/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
196 def forward(self, *args: Any, **kwargs: Any):
--> 197 return self.model.forward(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1170, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
1167 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1169 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1170 outputs = self.model(
1171 input_ids=input_ids,
1172 attention_mask=attention_mask,
1173 position_ids=position_ids,
1174 past_key_values=past_key_values,
1175 inputs_embeds=inputs_embeds,
1176 use_cache=use_cache,
1177 output_attentions=output_attentions,
1178 output_hidden_states=output_hidden_states,
1179 return_dict=return_dict,
1180 cache_position=cache_position,
1181 )
1183 hidden_states = outputs[0]
1184 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:901, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
889 layer_outputs = self._gradient_checkpointing_func(
890 decoder_layer.__call__,
891 hidden_states,
(...)
898 position_embeddings,
899 )
900 else:
--> 901 layer_outputs = decoder_layer(
902 hidden_states,
903 attention_mask=causal_mask,
904 position_ids=position_ids,
905 past_key_value=past_key_values,
906 output_attentions=output_attentions,
907 use_cache=use_cache,
908 cache_position=cache_position,
909 position_embeddings=position_embeddings,
910 )
912 hidden_states = layer_outputs[0]
914 if use_cache:
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:629, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
626 hidden_states = self.input_layernorm(hidden_states)
628 # Self Attention
--> 629 hidden_states, self_attn_weights, present_key_value = self.self_attn(
630 hidden_states=hidden_states,
631 attention_mask=attention_mask,
632 position_ids=position_ids,
633 past_key_value=past_key_value,
634 output_attentions=output_attentions,
635 use_cache=use_cache,
636 cache_position=cache_position,
637 position_embeddings=position_embeddings,
638 )
639 hidden_states = residual + hidden_states
641 # Fully Connected
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:506, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings)
495 return super().forward(
496 hidden_states=hidden_states,
497 attention_mask=attention_mask,
(...)
501 use_cache=use_cache,
502 )
504 bsz, q_len, _ = hidden_states.size()
--> 506 query_states = self.q_proj(hidden_states)
507 key_states = self.k_proj(hidden_states)
508 value_states = self.v_proj(hidden_states)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/second/lib/python3.10/site-packages/peft/tuners/lora/layer.py:572, in Linear.forward(self, x, *args, **kwargs)
570 result = self.base_layer(x, *args, **kwargs)
571 else:
--> 572 result = self.base_layer(x, *args, **kwargs)
573 torch_result_dtype = result.dtype
574 for active_adapter in self.active_adapters:
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/second/lib/python3.10/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
124 def forward(self, input: Tensor) -> Tensor:
--> 125 return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
The code that produces the above error is given below:
Trainer.py
import torch
from logging import getLogger
from torch.nn.parallel import DataParallel
from TSPEnvQuant import TSPEnv as Env
from TSPTransformerModelQuant_b import TSPTransformer as Model
from torch.optim import Adam as Optimizer
from torch.optim.lr_scheduler import MultiStepLR as Scheduler
from utils.utils import *
class TSPTrainer:
    def __init__(self,
                 env_params,
                 model_params,
                 optimizer_params,
                 trainer_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.optimizer_params = optimizer_params
        self.trainer_params = trainer_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()
        self.result_log = LogData()

        # cuda
        USE_CUDA = self.trainer_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.trainer_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')

        # Main Components
        self.model = Model(**self.model_params)
        if USE_CUDA and torch.cuda.device_count() > 1:
            self.logger.info(f"Using {torch.cuda.device_count()} GPUs!")
            self.model = DataParallel(self.model)
        self.model = self.model.to(device)
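        # Note: DataParallel replicates the wrapped module onto every visible GPU on each
        # forward call and expects all of its parameters to start on the source device
        # (cuda:0 here). The Qwen2 base model inside TSPTransformer, however, is loaded
        # separately with device_map="auto" (see load_model in Model.py), so some of its
        # layers may already live on cuda:1.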
        self.env = Env(**self.env_params)
        self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
        self.scheduler = Scheduler(self.optimizer, **self.optimizer_params['scheduler'])

        # Restore
        self.start_epoch = 1
        model_load = trainer_params['model_load']
        if model_load['enable']:
            checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
            checkpoint = torch.load(checkpoint_fullname, map_location=device)

            # Handle loading state dict for DataParallel
            if isinstance(self.model, DataParallel):
                # If saved model wasn't using DataParallel but current model is
                if not any(key.startswith('module.') for key in checkpoint['model_state_dict'].keys()):
                    new_state_dict = {'module.' + k: v for k, v in checkpoint['model_state_dict'].items()}
                    self.model.load_state_dict(new_state_dict)
                else:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If saved model was using DataParallel but current model isn't
                if any(key.startswith('module.') for key in checkpoint['model_state_dict'].keys()):
                    new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
                    self.model.load_state_dict(new_state_dict)
                else:
                    self.model.load_state_dict(checkpoint['model_state_dict'])

            self.start_epoch = 1 + model_load['epoch']
            self.result_log.set_raw_data(checkpoint['result_log'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.last_epoch = model_load['epoch'] - 1
            self.logger.info('Saved Model Loaded !!')

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        self.time_estimator.reset(self.start_epoch)
        for epoch in range(self.start_epoch, self.trainer_params['epochs'] + 1):
            self.logger.info('=================================================================')

            # LR Decay
            self.scheduler.step()

            # Train
            train_score, train_loss = self._train_one_epoch(epoch)
            self.result_log.append('train_score', epoch, train_score)
            self.result_log.append('train_loss', epoch, train_loss)

            ############################
            # Logs & Checkpoint
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
            self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
                epoch, self.trainer_params['epochs'], elapsed_time_str, remain_time_str))

            all_done = (epoch == self.trainer_params['epochs'])
            model_save_interval = self.trainer_params['logging']['model_save_interval']
            img_save_interval = self.trainer_params['logging']['img_save_interval']

            if epoch > 1:  # save latest images, every epoch
                self.logger.info("Saving log_image")
                image_prefix = '{}/latest'.format(self.result_folder)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                               self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                               self.result_log, labels=['train_loss'])

            if all_done or (epoch % model_save_interval) == 0:
                self.logger.info("Saving trained_model")
                checkpoint_dict = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'result_log': self.result_log.get_raw_data()
                }
                torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))

            if all_done or (epoch % img_save_interval) == 0:
                image_prefix = '{}/img/checkpoint-{}'.format(self.result_folder, epoch)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                               self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                               self.result_log, labels=['train_loss'])

            if all_done:
                self.logger.info(" *** Training Done *** ")
                self.logger.info("Now, printing log array...")
                util_print_log_array(self.logger, self.result_log)

    def _train_one_epoch(self, epoch):
        score_AM = AverageMeter()
        loss_AM = AverageMeter()

        train_num_episode = self.trainer_params['train_episodes']
        episode = 0
        loop_cnt = 0
        while episode < train_num_episode:
            remaining = train_num_episode - episode
            batch_size = min(self.trainer_params['train_batch_size'], remaining)

            avg_score, avg_loss = self._train_one_batch(batch_size)
            score_AM.update(avg_score, batch_size)
            loss_AM.update(avg_loss, batch_size)

            episode += batch_size

            # Log First 10 Batch, only at the first epoch
            if epoch == self.start_epoch:
                loop_cnt += 1
                if loop_cnt <= 10:
                    self.logger.info('Epoch {:3d}: Train {:3d}/{:3d}({:1.1f}%) Score: {:.4f}, Loss: {:.4f}'
                                     .format(epoch, episode, train_num_episode, 100. * episode / train_num_episode,
                                             score_AM.avg, loss_AM.avg))

        # Log Once, for each epoch
        self.logger.info('Epoch {:3d}: Train ({:3.0f}%) Score: {:.4f}, Loss: {:.4f}'
                         .format(epoch, 100. * episode / train_num_episode,
                                 score_AM.avg, loss_AM.avg))

        return score_AM.avg, loss_AM.avg

    def _train_one_batch(self, batch_size):
        # Prep
        ###############################################
        self.model.train()
        self.env.load_problems(batch_size)
        reset_state, _, _ = self.env.reset()

        # Handle pre_forward for DataParallel
        if isinstance(self.model, DataParallel):
            print("Is DataParallel")
            self.model.module.pre_forward(reset_state)
        else:
            self.model.pre_forward(reset_state)

        prob_list = torch.zeros(size=(batch_size, self.env.pomo_size, 0))
        # shape: (batch, pomo, 0~problem)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
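            # Note: self.model.module(...) below calls the underlying TSPTransformer
            # directly, so DataParallel's usual scatter of the inputs across the GPUs
            # (and the gather of outputs back onto cuda:0) is bypassed for this forward.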
            selected, prob = self.model.module(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)
            prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

        # Loss
        ###############################################
        advantage = reward - reward.float().mean(dim=1, keepdims=True)
        # shape: (batch, pomo)
        log_prob = prob_list.log().sum(dim=2)
        # size = (batch, pomo)
        loss = -advantage * log_prob  # Minus Sign: To Increase REWARD
        # shape: (batch, pomo)
        loss_mean = loss.mean()

        # Score
        ###############################################
        max_pomo_reward, _ = reward.max(dim=1)  # get best results from pomo
        score_mean = -max_pomo_reward.float().mean()  # negative sign to make positive value

        # Step & Return
        ###############################################
        self.model.zero_grad()
        loss_mean.backward()
        self.optimizer.step()
        return score_mean.item(), loss_mean.item()

Model.py
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, TaskType
from typing import Optional, Dict, Any, Tuple
class TSPTransformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.model_params = kwargs
        self.encoder = Encoder(**kwargs)
        self.embedding_size = kwargs.get('embedding_dim', 896)

        # Load the model with LoRA and 4-bit quantization if needed
        self.model = load_model(kwargs)
        self.decoder = Decoder(self.model, **kwargs)

        # Initialize state storage
        self.encoded_nodes = None
        self.seq_so_far = None
        self.input_mask = None
        self.t = None

        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def pre_forward(self, reset_state):
        """Initialize model state for new sequence"""
        self.encoded_nodes = self.encoder(reset_state.problems)
        self.problem_size = reset_state.problems.size(1)
        self.batch_size = reset_state.problems.size(0)

    def forward(self, state) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.current_node is None:
            return self._init_sequence(batch_size, pomo_size)
        else:
            return self._continue_sequence(state, batch_size, pomo_size)

    def _init_sequence(self, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize sequence state"""
        self.t = 0  # Start at 0 instead of -1

        # Create new tensors instead of modifying in place
        selected = torch.arange(pomo_size, device=self.device).expand(batch_size, pomo_size)
        prob = torch.ones(size=(batch_size, pomo_size), device=self.device)

        # Initialize sequence storage with proper dimensions
        self.seq_so_far = torch.zeros(
            (batch_size, pomo_size, self.problem_size, self.embedding_size),
            device=self.device
        )
        self.input_mask = torch.zeros(
            (batch_size, pomo_size, self.problem_size),
            dtype=torch.bool,
            device=self.device
        )

        return selected, prob

    def _continue_sequence(self, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Continue sequence generation"""
        # Get encoded representation of current node
        encoded_current = self._get_encoded_node(state.current_node)

        # Create new tensor for updated sequence
        new_seq = self.seq_so_far.clone()
        new_seq[:, :, self.t, :] = encoded_current
        self.seq_so_far = new_seq

        # Create new tensor for updated mask
        new_mask = self.input_mask.clone()
        new_mask[:, :, self.t] = True
        self.input_mask = new_mask

        # Move tensors to correct device
        self.seq_so_far = self.seq_so_far.to(self.device)
        self.input_mask = self.input_mask.to(self.device)
        state.ninf_mask = state.ninf_mask.to(self.device)

        # Get probabilities from decoder
        probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)

        # Select next node
        if self.training or self.model_params['eval_type'] == 'softmax':
            selected, prob = self._sample_node(probs, state, batch_size, pomo_size)
        else:
            selected = probs.argmax(dim=2)
            prob = None

        self.t += 1
        return selected, prob

    def _get_encoded_node(self, node_indices: torch.Tensor) -> torch.Tensor:
        """Get encoded representation of nodes safely"""
        batch_size, pomo_size = node_indices.shape
        embedding_dim = self.encoded_nodes.size(2)

        # Create gathering indices
        gather_idx = node_indices[:, :, None].expand(batch_size, pomo_size, embedding_dim)
        gather_idx = gather_idx.to(self.encoded_nodes.device)

        # Gather encoded representations
        return self.encoded_nodes.gather(dim=1, index=gather_idx)

    def _sample_node(self, probs: torch.Tensor, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample next node with retry logic"""
        max_attempts = 100
        for _ in range(max_attempts):
            # Reshape for sampling
            flat_probs = probs.reshape(batch_size * pomo_size, -1)

            # Sample indices
            selected = flat_probs.multinomial(1, replacement=True)
            selected = selected.reshape(batch_size, pomo_size)

            # Calculate probabilities
            prob = probs[state.BATCH_IDX, state.POMO_IDX, selected]
            prob = prob.reshape(batch_size, pomo_size)

            if (prob > 0).all():
                return selected, prob

        raise RuntimeError(f"Failed to sample valid nodes after {max_attempts} attempts")

class Encoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.embedding_dim = kwargs.get('embedding_dim', 896) - 1
        self.embed_layer = nn.Linear(2, self.embedding_dim)
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def forward(self, problems: torch.Tensor) -> torch.Tensor:
        batch_size, problem_size = problems.shape[:2]

        # Create position encodings
        ids = torch.arange(problem_size, device=self.device).expand(batch_size, problem_size)

        # Embed coordinates
        embedded = self.embed_layer(problems.reshape(-1, 2))
        embedded = embedded.reshape(batch_size, problem_size, self.embedding_dim)

        # Concatenate position encodings
        return torch.cat([ids.unsqueeze(-1).float(), embedded], dim=-1)

class Decoder(nn.Module):
    def __init__(self, model: nn.Module, **kwargs):
        super().__init__()
        self.model = model
        self.problem_size = kwargs.get('problem_size', 20)
        self.use_lora = kwargs.get('use_lora', True)
        self._setup_model()

    def _setup_model(self):
        """Configure model architecture"""
        # Modify output size
        self.model.lm_head = nn.Linear(
            self.model.config.hidden_size,
            self.problem_size
        ).to(self.model.device)

        # Apply LoRA if requested
        if self.use_lora:
            lora_config = LoraConfig(
                r=4,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.1,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            self.model = get_peft_model(self.model, lora_config)

    def forward(self, seq_so_far: torch.Tensor, inp_mask: torch.Tensor, ninf_mask: torch.Tensor) -> torch.Tensor:
        batch_size, pomo_size, problem_size, embedding_dim = seq_so_far.shape

        # Reshape inputs
        flat_seq = seq_so_far.reshape(batch_size * pomo_size, problem_size, embedding_dim)
        flat_mask = inp_mask.reshape(batch_size * pomo_size, problem_size)
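        # Note: passing inputs_embeds skips the base model's own token embedding layer,
        # so flat_seq (created on self.device) is handed straight to the first Qwen2
        # decoder block, wherever the loading step placed it.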
        # Get model outputs
        outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
        logits = outputs.logits

        # Get last valid position
        last_positions = flat_mask.sum(dim=1).long() - 1

        # Gather logits for last positions
        batch_indices = torch.arange(batch_size * pomo_size, device=logits.device)
        gathered_logits = logits[batch_indices, last_positions]

        # Reshape and apply mask
        logits = gathered_logits.reshape(batch_size, pomo_size, problem_size)
        masked_logits = logits + ninf_mask.float()

        # Return probabilities
        return torch.softmax(masked_logits, dim=2)

def load_model(config: Dict[str, Any]) -> nn.Module:
    """Load model with proper configuration"""
    # print(config)
    device = config.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    if config.get('checkpoint_path'):
        # print('checkpoint_path')
        try:
            return PeftModel.from_pretrained(
                config['model_name'],
                config['checkpoint_path'],
                is_trainable=True
            ).to(device)
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Falling back to base model...")

    print(config)
    # print(config['use_4bit'])
    if config.get('use_4bit', True):
        print('use_4bit')
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            bnb_4bit_use_double_quant=True,
        )
        # print(config['model_name'])
        # print(type(config['model_name']))
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config
        )
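        # Note: device_map="auto" hands placement to accelerate, which can spread the
        # transformer layers over every visible GPU (that is why accelerate hooks show
        # up in the traceback above); with CUDA_VISIBLE_DEVICES=0 only one device is
        # visible, so everything ends up on cuda:0.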
        model = prepare_model_for_kbit_training(model)
        model.config.use_cache = False
    else:
        # print('else')
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            torch_dtype=torch.float32,
            trust_remote_code=True,
            device_map="auto",
        ).to(device)

    return model

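For what it is worth, a quick way to see how the base model ended up placed across the GPUs (a minimal sketch; tsp_model stands for the TSPTransformer instance built by the Trainer, and hf_device_map is only populated when the model was loaded with a device_map):

import torch

# Hypothetical inspection snippet: print where accelerate put each Qwen2 submodule
# and where the rest of the TSPTransformer parameters live.
def print_device_placement(tsp_model: torch.nn.Module) -> None:
    base = tsp_model.model  # the AutoModelForCausalLM returned by load_model
    print(getattr(base, "hf_device_map", "no hf_device_map (model not dispatched)"))
    devices = {p.device for p in tsp_model.parameters()}
    print("parameter devices:", devices)
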
Expected behavior
The model should train in the multi-GPU setting without throwing any errors. The same script works in the single-GPU setting but throws the above error in the multi-GPU setting.