Thunderbeee / ZSCL

Preventing Zero-Shot Transfer Degradation in Continual Learning of Vision-Language Models

Runtime Error: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward #6

Open YCAyca opened 8 months ago

YCAyca commented 8 months ago

Hello, I'm getting the error mentioned in the title. The documentation for torch.nn.functional.cross_entropy says the target can be either the class indices of the ground-truth classes or class probabilities (which suits this case, since the targets are not strict 0/1 labels but predictions coming from the pretrained model). The code seems to be correct for the second case, but it gives me a runtime error, so I had to replace the targets with ground-truth class indices, which seems to work well. I don't know whether this change would significantly decrease the accuracy, though. Any ideas?

Screenshot from 2024-04-05 18-12-59

Thunderbeee commented 5 months ago

To resolve this error, you need to ensure that the target tensor contains long integers representing the class indices. Here are a few possible solutions:

Convert the target tensor to long integers:

import torch

# Assuming 'target_tensor' holds class indices stored as floats (e.g. 2.0);
# casting probabilities this way would truncate them to 0
target_tensor = target_tensor.long()

Create the target tensor with the correct data type:

import torch

# Assuming your target data is in 'target_data'
target_tensor = torch.tensor(target_data, dtype=torch.long)

Check the data type of your target tensor:

import torch

# Assuming your target tensor is 'target_tensor'
print(target_tensor.dtype)

If the data type is not torch.long (torch.int64), convert it as shown above. Once the target tensor holds class indices as long integers, the "Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward" error should no longer occur.
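If the targets are meant to stay as probabilities (the case described in the question), casting to long is not a substitute, because it discards the distribution. As far as I know, torch.nn.functional.cross_entropy accepts class-probability targets only from PyTorch 1.10 onward, so on older versions the soft-target loss can be written out by hand. The following is a minimal sketch under that assumption, not the repository's actual code: `soft_cross_entropy` and `teacher_probs` are hypothetical names, and the random tensors merely stand in for real logits and pretrained-model predictions.

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy(logits: torch.Tensor, target_probs: torch.Tensor) -> torch.Tensor:
    """Cross-entropy against a per-sample probability distribution.

    logits:       (batch, num_classes) raw scores from the model being trained
    target_probs: (batch, num_classes) float probabilities, each row summing to 1
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Sum over classes, then average over the batch (same reduction as the default 'mean').
    return -(target_probs * log_probs).sum(dim=-1).mean()


torch.manual_seed(0)
logits = torch.randn(4, 3)
# Hypothetical stand-in for the pretrained model's predictions.
teacher_probs = F.softmax(torch.randn(4, 3), dim=-1)

# Soft targets: keeps the full distribution and works without the 1.10+ API.
soft_loss = soft_cross_entropy(logits, teacher_probs)

# Casting probabilities with .long() truncates every value below 1.0 to 0,
# so hard labels have to come from argmax instead.
truncated = teacher_probs.long()
hard_labels = teacher_probs.argmax(dim=-1)  # dtype is torch.int64 (long)
hard_loss = F.cross_entropy(logits, hard_labels)
```

With one-hot targets, `soft_cross_entropy` reduces to the usual class-index loss, so swapping between the two changes only how much of the pretrained model's distribution is kept, not the form of the loss.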