Open lzl65825 opened 1 month ago
I have the same issue. I am using H100 GPUs.
@PPODecorators.empty_device_cache()
def batched_forward_pass(
    self,
    model: PreTrainedModelWrapper,
    queries: torch.Tensor,
    responses: torch.Tensor,
    model_inputs: dict,
    return_logits: bool = False,
    response_masks: Optional[torch.Tensor] = None,
):
    """
    Calculate model outputs in multiple batches.

    Args:
        queries (`torch.LongTensor`):
            List of tensors containing the encoded queries, shape (`batch_size`, `query_length`)
        responses (`torch.LongTensor`):
            List of tensors containing the encoded responses, shape (`batch_size`, `response_length`)
        return_logits (`bool`, *optional*, defaults to `False`):
            Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption.
    Returns:
        (tuple):
            - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                shape (`batch_size`, `response_length`)
            - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses,
                shape (`batch_size`, `response_length`)
            - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`)
    """
    bs = len(queries)
    fbs = self.config.mini_batch_size
    all_logprobs = []
    all_logits = []
    all_masks = []
    all_values = []

    model.eval()

    for i in range(math.ceil(bs / fbs)):
        input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
        query_batch = queries[i * fbs : (i + 1) * fbs]
        response_batch = responses[i * fbs : (i + 1) * fbs]
        if response_masks is not None:
            response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
        logits, _, values = model(**input_kwargs)

        if self.is_encoder_decoder:
            input_ids = input_kwargs["decoder_input_ids"]
            attention_mask = input_kwargs["decoder_attention_mask"]
        else:
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

        logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
        masks = torch.zeros_like(attention_mask)
        masks[:, :-1] = attention_mask[:, 1:]

        for j in range(len(query_batch)):
            if self.is_encoder_decoder:
                # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                start = 1
                end = attention_mask[j, :].sum() - 1
            else:
                start = len(query_batch[j]) - 1  # logprobs starts from the second query token
                if attention_mask[j, 0] == 0:  # offset left padding
                    start += attention_mask[j, :].nonzero()[0]
                end = start + len(response_batch[j])
                if response_masks is not None:
                    response_masks_batch[j] = torch.cat(
                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                    )[1:]

            masks[j, :start] = 0
            masks[j, end:] = 0
            if response_masks is not None:
                masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

        if return_logits:
            all_logits.append(logits)
        else:
            del logits
        all_values.append(values)
        all_logprobs.append(logprobs)
        all_masks.append(masks)

    # put to current device explicitly
    return (
        torch.cat(all_logprobs).to(self.current_device),
        torch.cat(all_logits).to(self.current_device)[:, :-1] if return_logits else None,
        torch.cat(all_values).to(self.current_device)[:, :-1],
        torch.cat(all_masks).to(self.current_device)[:, :-1],
    )
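For anyone trying to follow how `start` and `end` are derived for the decoder-only case, here is a minimal pure-Python sketch of the same index arithmetic for a single row. Plain lists stand in for tensors, and the helper names `response_window` and `build_mask` are hypothetical (they are not part of TRL). It only illustrates the masking logic and says nothing about device placement.

```python
def response_window(attention_mask, query_len, response_len):
    """Return (start, end) of the response span in the shifted logprob sequence.

    attention_mask: list of 0/1 for one left-padded row (query + response).
    query_len / response_len: unpadded lengths, as len(query_batch[j]) and
    len(response_batch[j]) in the function above.
    """
    start = query_len - 1  # logprobs start from the second query token
    if attention_mask[0] == 0:  # row is left-padded
        start += attention_mask.index(1)  # index of the first real token
    end = start + response_len
    return start, end


def build_mask(attention_mask, query_len, response_len):
    """Shift the attention mask by one and zero everything outside the response."""
    shifted = attention_mask[1:] + [0]  # same as masks[:, :-1] = attention_mask[:, 1:]
    start, end = response_window(attention_mask, query_len, response_len)
    return [m if start <= k < end else 0 for k, m in enumerate(shifted)]


# Two pad tokens, three query tokens, two response tokens.
row = [0, 0, 1, 1, 1, 1, 1]
print(response_window(row, query_len=3, response_len=2))  # (4, 6)
print(build_mask(row, query_len=3, response_len=2))       # [0, 0, 0, 0, 1, 1, 0]
```

The last element of the mask is always zero because of the one-position shift, which is why the function above slices `[:, :-1]` on the concatenated masks before returning them.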
@lzl65825
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
I use the PPOTrainer on Mixtral with 8 GPUs whose CUDA version is 12.4. Would you happen to have any idea about solving the following issue? (Also, I have updated all Python packages.)
Here is the error message. I guess the error happens because `masks` is on the CPU, not on the GPU. The key code is shown below: