huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cpu! (values = values * mask) #1691

Open · lzl65825 opened this issue 1 month ago

lzl65825 commented 1 month ago

I am using PPOTrainer with Mixtral on 8 GPUs (CUDA 12.4). Would you have any idea how to solve the following issue? (I have also updated all Python packages.)

Here is the error message. I guess the error happens because the masks are on the CPU rather than on the GPU.

  File "xxx.py", line 272, in autoagi_mixtral
    train_stat = ppo_trainer.step([input_ids[0].to('cpu')], [output[0].to('cpu')], reward)
  File ".../anaconda3/envs/test/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File ".../anaconda3/envs/test/lib/python3.9/site-packages/trl/trainer/ppo_trainer.py", line 773, in step
    values, advantages, returns = self.compute_advantages(values, rewards, masks)
  File ".../anaconda3/envs/test/lib/python3.9/site-packages/trl/trainer/ppo_trainer.py", line 1171, in compute_advantages
    values = values * mask
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cpu!
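
To illustrate, here is a minimal reproduction of the failing operation (assumed shapes, for illustration only; the real tensors come from the value head and the attention mask):

    import torch

    values = torch.randn(1, 10, device="cuda")  # value-head output lands on a GPU (cuda:7 in my run)
    mask = torch.ones(1, 10)                    # mask tensor left on the CPU
    values * mask                               # RuntimeError: Expected all tensors to be on the same device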

The key code is shown below:

import torch
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
lora_target_modules = ["q_proj", "v_proj"]

lora_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=lora_target_modules,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    cache_dir=args.cache_dir,
    load_in_4bit=True,
    device_map="auto",
    peft_config=lora_config,
)

model.gradient_checkpointing_enable()
model.config.use_cache = False

config = PPOConfig(
    model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
    learning_rate=1.41e-5,
    batch_size=1,
    mini_batch_size=1,
)

ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    optimizer=optimizer,
)

input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()

generation_kwargs = dict(temperature=1.0, do_sample=True, top_p=0.9, top_k=40, max_new_tokens=1024)
output = ppo_trainer.generate(torch.squeeze(input_ids), return_prompt=False, **generation_kwargs)

reward = -2.0
reward = [torch.tensor(reward)]
train_stat = ppo_trainer.step([input_ids[0].to('cpu')], [output[0].to('cpu')], reward)
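
A workaround I am trying (a sketch that assumes PPOTrainer exposes current_device, which recent trl versions set from the accelerator): keep every input to step() on the trainer's device instead of moving it to the CPU:

    # hypothetical variant of the failing call: keep query, response, and
    # reward on one device so no CPU tensor reaches step()
    device = ppo_trainer.current_device
    reward = [torch.tensor(-2.0, device=device)]
    train_stat = ppo_trainer.step([input_ids[0].to(device)], [output[0].to(device)], reward)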
wernerolaf commented 1 month ago

I have the same issue; I am using H100 GPUs.

wernerolaf commented 3 weeks ago

@PPODecorators.empty_device_cache()
def batched_forward_pass(
    self,
    model: PreTrainedModelWrapper,
    queries: torch.Tensor,
    responses: torch.Tensor,
    model_inputs: dict,
    return_logits: bool = False,
    response_masks: Optional[torch.Tensor] = None,
):
    """
    Calculate model outputs in multiple batches.

    Args:
        queries (`torch.LongTensor`):
            List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
        responses (`torch.LongTensor`):
            List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
        return_logits (`bool`, *optional*, defaults to `False`):
            Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
    Returns:
        (tuple):
            - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                shape (`batch_size`, `response_length`)
            - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                shape (`batch_size`, `response_length`)
            - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
    """
    bs = len(queries)
    fbs = self.config.mini_batch_size
    all_logprobs = []
    all_logits = []
    all_masks = []
    all_values = []

    model.eval()

    for i in range(math.ceil(bs / fbs)):
        input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
        query_batch = queries[i * fbs : (i + 1) * fbs]
        response_batch = responses[i * fbs : (i + 1) * fbs]
        if response_masks is not None:
            response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
        logits, _, values = model(**input_kwargs)

        if self.is_encoder_decoder:
            input_ids = input_kwargs["decoder_input_ids"]
            attention_mask = input_kwargs["decoder_attention_mask"]
        else:
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

        logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
        masks = torch.zeros_like(attention_mask)
        masks[:, :-1] = attention_mask[:, 1:]

        for j in range(len(query_batch)):
            if self.is_encoder_decoder:
                # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                start = 1
                end = attention_mask[j, :].sum() - 1
            else:
                start = len(query_batch[j]) - 1  # logprobs starts from the second query token
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0]
                end = start + len(response_batch[j])
                if response_masks is not None:
                    response_masks_batch[j] = torch.cat(
                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                    )[1:]

            masks[j, :start] = 0
            masks[j, end:] = 0
            if response_masks is not None:
                masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

        if return_logits:
            all_logits.append(logits)
        else:
            del logits
        all_values.append(values)
        all_logprobs.append(logprobs)
        all_masks.append(masks)

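    # proposed fix: move every concatenated tensor onto the trainer's device
    # so that compute_advantages() no longer mixes cuda and cpu tensors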
    return (
        torch.cat(all_logprobs).to(self.current_device),
        torch.cat(all_logits).to(self.current_device)[:, :-1] if return_logits else None,
        torch.cat(all_values).to(self.current_device)[:, :-1],
        torch.cat(all_masks).to(self.current_device)[:, :-1],
    )
wernerolaf commented 3 weeks ago

Put the outputs onto the current device explicitly:

    return (
        torch.cat(all_logprobs).to(self.current_device),
        torch.cat(all_logits).to(self.current_device)[:, :-1] if return_logits else None,
        torch.cat(all_values).to(self.current_device)[:, :-1],
        torch.cat(all_masks).to(self.current_device)[:, :-1],
    )
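
Until a fix is released, a monkey-patch along these lines might work (untested sketch; it assumes batched_forward_pass keeps returning a 4-tuple with None in the logits slot when return_logits=False):

    import functools

    from trl import PPOTrainer

    _orig_bfp = PPOTrainer.batched_forward_pass

    @functools.wraps(_orig_bfp)
    def _bfp_on_current_device(self, *args, **kwargs):
        # move each returned tensor onto the trainer's device; leave None entries as-is
        out = _orig_bfp(self, *args, **kwargs)
        return tuple(t.to(self.current_device) if t is not None else None for t in out)

    PPOTrainer.batched_forward_pass = _bfp_on_current_device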

wernerolaf commented 3 weeks ago

@lzl65825

github-actions[bot] commented 14 hours ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.