dhx20150812 opened this issue 4 months ago
Thanks for your interest and appreciation!
The code and collected data have now been uploaded. Feel free to use them, and let us know if you have any questions or feedback.
Thanks for open-sourcing the code!
I ran into an OutOfMemoryError during training, even though I was running the bash script from the README and per_device_train_batch_size was already set to 1, which seems strange.
I noticed that the released code does not use half-precision fine-tuning. Could that be the reason?
It might not be for this reason.
Although the batch_size in the script is set to 1, due to the special training method of RAAT, the actual batch_size during training is 4 because one query corresponds to 4 augmented samples. This might lead to OOM issues (you can check tuner/trainer/raat_processor.py). In the experimental setup described in our paper, we used 4 A100 GPUs with 80G memory each and deepspeed zero2. Here, we suggest that you modify the source code to reduce the types of retrieval noise. For example, one query corresponds to only 2 augmented samples: one Golden retrieval and one Golden retrieval + Random retrieval noise.
Thanks for your prompt reply.
I also used 4 A100 GPUs, each with 80G memory.
I know that 4 augmented samples will cause the batch size to increase by 4 times, which may be the cause of OOM. However, I did not find where you used deepspeed zero2. Am I missing something?
The source code does not provide a deepspeed YAML script. You can use the `accelerate config` command or write a YAML yourself to enable ZeRO-2.
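For reference, a rough sketch of what such an accelerate config could look like is below. The concrete values (num_processes: 4, bf16 mixed precision, no offloading) are assumptions for a 4x A100 setup; regenerate the file with `accelerate config` for your own environment and accelerate version.

```yaml
# Rough sketch of an accelerate config enabling DeepSpeed ZeRO-2; adjust to your setup.
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 2
  offload_optimizer_device: none
  offload_param_device: none
  gradient_accumulation_steps: 1
mixed_precision: bf16
num_machines: 1
num_processes: 4
machine_rank: 0
main_training_function: main
use_cpu: false
```

You would then launch training with `accelerate launch --config_file <your_config>.yaml ...` instead of invoking the script directly.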
Thanks. I will give it a try.
Hi @calubkk, I am particularly interested in understanding the logic behind the calculation of the cross-entropy loss in the `compute_loss` function of the `adaProcess` class in `tuner/trainer/raat_processor.py`.
Here is the code snippet I am referring to:
```python
score_list = []
suffix_mask_list = []
embeds = []
logits_list = []
mask_list = []
cls_list = []
for batch_index, sub_batch in enumerate(sub_batches):
    # local_outputs = model(**sub_batch, output_hidden_states=True, return_dict=True)
    lm_logits, loss, value = model(**sub_batch)
    cls_logits = value[:, -1]
    cls_list.append(cls_logits)
    # embeds.append(embedding)
    local_logits = lm_logits
    local_mask = sub_batch["attention_mask"] & (
        ~batch["prefix_mask"][:, batch_index, :]
    )  # [batch, seq_len]
    local_labels = batch["labels"][:, batch_index, :]
    shift_logits = local_logits[
        ..., :-1, :
    ].contiguous()  # [batch, seq_len-1, token_num]
    # logits_list.append(shift_logits)
    shift_logits = F.log_softmax(
        shift_logits, dim=2
    )  # [batch, seq_len-1, token_num]
    shift_masks = local_mask[..., :-1]  # [batch, seq_len-1]
    # mask_list.append(shift_masks)
    shift_labels = local_labels[..., 1:].view(
        batch_size, -1, 1
    )  # [batch, seq_len-1, 1]
    selected_logits = torch.gather(
        input=shift_logits, dim=2, index=shift_labels
    ).view(
        batch_size, -1
    )  # [batch, seq_len-1]
    selected_logits[shift_masks != 1] = 0.0  # [batch, seq_len-1]
    sentence_logits = torch.sum(selected_logits, dim=1)  # [batch]
    sentence_logits = sentence_logits.view(batch_size, 1)
    score_list.append(sentence_logits)
    suffix_mask_list.append(torch.sum(shift_masks, dim=1).view(batch_size, 1))

sum_scores = torch.cat(score_list, dim=1)  # [batch, training_stage]
total_loss = 0
ada_score = sum_scores
_, min_index = ada_score[:, :].min(dim=1)
_, max_index = ada_score[:, :].max(dim=1)
# Model prioritizes the selection of the largest loss to guide subsequent parameter update
min_scores = sum_scores[:, min_index]
sft_scores = min_scores.view(batch_size * 1)
sft_loss = torch.mean(-sft_scores).to(local_logits.dtype)
```
From my understanding, `score_list` collects the scores for each sentence in the sub-batches. These scores are obtained by processing the model's logits, selecting the logits corresponding to the labels, and summing them up. However, I noticed that the actual calculation of the cross-entropy loss is not explicitly shown in this code snippet.
Could you please clarify if the cross entropy loss is calculated elsewhere in the code? Or is there a specific reason for not including the cross entropy calculation in this part of the code?
Here is a simple demo of cross-entropy implemented in Python.
```python
import torch
import torch.nn.functional as F

torch.manual_seed(1)

def softmax(tensor, dim=1):
    tensor = torch.exp(tensor)
    sum = torch.sum(tensor, dim).reshape(-1, 1)
    tensor = tensor / sum
    return tensor

def log_softmax(tensor, dim=1):
    return torch.log(softmax(tensor, 1))

def nll_loss(tensor, label):
    one_hot = torch.nn.functional.one_hot(label, 10)
    loss = -torch.sum(one_hot * tensor) / tensor.shape[0]
    return loss

def cross_entropy(tensor, label):
    tensor = log_softmax(tensor, 1)
    tensor = nll_loss(tensor, label)
    return tensor

out = torch.randn([2, 10])
label = torch.randint(0, 10, [2])

# pytorch impl
print(F.cross_entropy(out, label).item())  # 3.927539587020874
# our impl
print(cross_entropy(out, label).item())  # 3.927539825439453
```
cross-entropy = log + softmax + nll_loss. You can refer to the demo code and compare it with the `compute_loss` function, which includes the implementation of the cross-entropy function. You might have overlooked the implementation details of the cross-entropy function within `compute_loss`. If you still have questions, feel free to contact me via email.
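To make the correspondence concrete, here is a small self-contained check (my own sketch, not code from the repo) showing that gathering the label positions from `log_softmax` and negating the sum is exactly a cross-entropy with sum reduction, which is what the `compute_loss` snippet computes per augmented sample:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(1, 5, 10)           # [batch, seq_len, vocab]
labels = torch.randint(0, 10, (1, 5))    # [batch, seq_len]

# Same steps as the snippet: log_softmax, gather at label indices, sum, negate.
log_probs = F.log_softmax(logits, dim=2)
picked = torch.gather(log_probs, dim=2, index=labels.unsqueeze(-1)).squeeze(-1)
manual = -picked.sum()

# Reference implementation with the same (sum) reduction.
reference = F.cross_entropy(logits.view(-1, 10), labels.view(-1), reduction="sum")

print(torch.allclose(manual, reference))  # True
```

In `compute_loss`, zeroing out positions with `shift_masks` simply restricts this sum to the answer tokens, and `-sft_scores` is then averaged over the batch to obtain the loss.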
Hi @calubkk, thanks for your nice work!
BTW, when will the code and collected data be released?