Visual-AI / PromptCCD

The official repository for the ECCV 2024 paper "PromptCCD: Learning Gaussian Mixture Prompt Pool for Continual Category Discovery"
https://visual-ai.github.io/promptccd/

How to solve this problem, thanks a lot! #2

Closed: gangxu822 closed this issue 2 weeks ago

gangxu822 commented 2 weeks ago

Traceback (most recent call last):
  File "/opt/nas/n/code/PromptCCD/main.py", line 72, in <module>
    main()
  File "/opt/nas/n/code/PromptCCD/main.py", line 59, in main
    RunContinualTrainer(args, datasets_train, datasets_val, datasets_test)
  File "/opt/nas/n/code/PromptCCD/tools/trainer.py", line 46, in RunContinualTrainer
    ).run()
  File "/opt/nas/n/code/PromptCCD/tools/trainer.py", line 26, in run
    model = self.ccd_model.fit(self.train_dataloader_i, self.val_dataloader)
  File "/opt/nas/n/code/PromptCCD/model.py", line 31, in fit
    model = self.contrastive_model.fit(train_dataloader, val_dataloader)
  File "/opt/nas/n/code/PromptCCD/models/promptccd_training_w_gmp.py", line 69, in fit
    self.gmm_prompt.fit(self.args, self.model, train_loader['default'], self.stage_i)
  File "/opt/nas/n/code/PromptCCD/models/promptccd_utils/gmp_known_K.py", line 52, in fit
    print('type(model(data.cuda())[x]):',model(data.cuda())['x'].dtype)
TypeError: new(): invalid data type 'str'

fcendra commented 2 weeks ago

Hi there,

Thank you for raising this issue. I’ve reviewed the code and it seems to be functioning correctly on my end. The error you’re encountering appears to be an invalid data type error, raised at the line shown below:

    print('type(model(data.cuda())[x]):',model(data.cuda())['x'].dtype)
TypeError: new(): invalid data type 'str'

The code above should print torch.float32.

To help me better understand and resolve the problem, could you please provide more details on the following:

  1. Dataset Information: What dataset are you using?
  2. Dataset Setup: How did you construct and set up the dataset?
  3. Output tensor: Could you share the output of the following code: print(model(data.cuda())['x'].dtype)
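
For item 3, it can help to look at the whole forward output before indexing it, because the lookup ['x'] only works if the model returns a dict-like object. The following is a minimal sketch, not code from the repository: describe_output is a hypothetical helper name, and it relies only on the dtype and shape attributes that PyTorch tensors expose.

```python
def _summarize(value):
    """Summarize one object by type name plus dtype/shape when it has them."""
    dtype = getattr(value, "dtype", None)
    shape = getattr(value, "shape", None)
    shape = tuple(shape) if shape is not None else None
    return f"{type(value).__name__}, dtype={dtype}, shape={shape}"


def describe_output(out):
    """Return a short description of a model's forward output.

    Hypothetical helper for debugging; it never indexes `out` with a string
    unless `out` is a dict, so it cannot trigger the TypeError itself.
    """
    if isinstance(out, dict):
        parts = [f"{key!r}: {_summarize(value)}" for key, value in out.items()]
        return "dict{" + "; ".join(parts) + "}"
    return _summarize(out)


# Possible usage inside the training code (assumes `model` and `data` exist):
# with torch.no_grad():
#     print(describe_output(model(data.cuda())))
```

If this prints a description that starts with a tensor type rather than "dict{", the model is returning a bare tensor and the ['x'] lookup is expected to fail.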

Your input will be useful in fixing the issue. I'm happy to assist you further once I have more information.

Best regards, Fernando

gangxu822 commented 2 weeks ago

  1. Dataset: Aircraft.
  2. I downloaded it from the website (https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) and it loaded. I printed it to check.
  3. The error comes from this code: https://github.com/Visual-AI/PromptCCD/blob/main/models/promptccd_utils/gmp_known_K.py (lines 46-47), as follows:

    with torch.no_grad():
        feats = model(data.cuda())['x'][:, 0]

[screenshot]

fcendra commented 2 weeks ago

Thank you for the update. It seems the issue might be due to data formatting differences from the dataset you downloaded from an alternative link.

Could you please download the Aircraft dataset from the link we provided (here) and follow the instructions in our documentation (here) to properly download and set up the Aircraft dataset?

This should ensure compatibility and hopefully resolve the error you’re encountering.

Best regards, Fernando

gangxu822 commented 2 weeks ago

[screenshot] The data I downloaded myself is the same as the data I downloaded following your suggestion, but I still get this error. [screenshot]

gangxu822 commented 2 weeks ago

[screenshot] I changed the code here. Could this be the cause?

fcendra commented 2 weeks ago

Thank you for the additional information. It appears that changing the model module can indeed cause the problem you’re experiencing. In our method, we have modified the model output’s format, which is likely why you’re encountering this issue when you changed the code.

Moreover, to ensure compatibility, please use the _create_vision_transformer module. This module has been further modified to make the model compatible with prompting, which should also resolve the issue above.
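
To illustrate why swapping the model module breaks the failing line: the code indexes the forward output with the string 'x', which suggests the modified model returns a dict-like object, whereas a stock vision transformer typically returns a plain tensor. Indexing a tensor with a string is a likely source of the TypeError above, although the exact message may vary by PyTorch version. The sketch below shows one way to fail fast with a clearer message. It is an assumption-laden illustration, not the repository's code: get_output_field is a hypothetical helper, and the dict layout is inferred only from the traceback.

```python
from collections.abc import Mapping


def get_output_field(out, key="x"):
    """Return out[key], raising a descriptive error if `out` is not dict-like.

    Hypothetical guard: the key 'x' is taken from the failing line in the
    traceback; nothing else about the output layout is assumed.
    """
    if not isinstance(out, Mapping):
        raise TypeError(
            f"expected the model to return a dict-like output containing {key!r}, "
            f"got {type(out).__name__}; was the model-creation module replaced?"
        )
    if key not in out:
        raise KeyError(f"model output has keys {sorted(out)!r} but no {key!r}")
    return out[key]


# Possible usage in place of the direct lookup (assumes `model` and `data` exist):
# with torch.no_grad():
#     feats = get_output_field(model(data.cuda()))[:, 0]
```

With a guard like this, a replaced model module would surface as an explicit message about the output type instead of an error about an invalid data type.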

gangxu822 commented 2 weeks ago

This works, thank you!