Shanka123 / STSN

Slot Transformer Scoring Network
MIT License

Inquiry about Model Accuracy #2

Status: open. xiaohuahuaw opened this issue 4 months ago.

xiaohuahuaw commented 4 months ago

Thanks for providing the model.

I'm a beginner and recently used your released model to run an evaluation on the I-RAVEN dataset. I got an accuracy of 73.47%, whereas your paper reports 95.7%. Could there be an issue with my code? Could you share the correct testing code?

Thanks for your assistance!

Here is a snippet of the code I used:


import torch
from torch.utils.data import DataLoader

# SlotAttentionAutoEncoder, scoring_model, RAVENdataset, opt, log, device,
# configurations and figure_configuration_names come from the STSN repository
# code (not shown here). opt.img_size should be 80, the size used in training.

# Loading model
slot_model = SlotAttentionAutoEncoder((opt.img_size, opt.img_size), opt.num_slots, opt.num_iterations, opt.hid_dim).to(device)
transformer_scoring_model = scoring_model(opt, opt.hid_dim, opt.depth, opt.heads, opt.mlp_dim, opt.num_slots).to(device)
# map_location lets a GPU-saved checkpoint load on whatever device is in use
checkpoint = torch.load(opt.model_checkpoint, map_location=device)
slot_model.load_state_dict(checkpoint['slot_model_state_dict'])
transformer_scoring_model.load_state_dict(checkpoint['transformer_scoring_model_state_dict'])
slot_model.eval()
transformer_scoring_model.eval()

print("Batch size:", opt.batch_size)

# Testing the model
log.info('Testing begins...')
total_correct, total_count = 0, 0
for config_idx in configurations:
    test_data = RAVENdataset(opt.path, "test", [config_idx], opt.img_size, opt.num_slots)
    test_dataloader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)

    log.info("Testing dataset: {}".format(figure_configuration_names[config_idx]))
    config_correct, config_count = 0, 0
    with torch.no_grad():
        for img, target in test_dataloader:
            img = img.to(device).float()
            target = target.to(device)
            # Encode every panel (8 context panels, then 8 answer candidates) into slots
            slots_seq = []
            for idx in range(img.shape[1]):
                _, _, _, slots, _ = slot_model(img[:, idx], device)
                slots_seq.append(slots)
            all_slots = torch.stack(slots_seq, dim=1)
            given_panels = all_slots[:, :8]
            answer_panels = all_slots[:, 8:]
            scores = transformer_scoring_model(given_panels, answer_panels, device)
            pred = scores.argmax(1)
            # Count correct answers instead of averaging per-batch accuracies,
            # so a smaller final batch does not get the weight of a full batch
            config_correct += torch.eq(pred, target).sum().item()
            config_count += target.size(0)

    # Output accuracy for this configuration
    average_test_accuracy = 100.0 * config_correct / config_count
    log.info("Average test accuracy for {}: {:.2f}%".format(figure_configuration_names[config_idx], average_test_accuracy))
    total_correct += config_correct
    total_count += config_count

# Output accuracy over all configurations
average_test_accuracy_all = 100.0 * total_correct / total_count
log.info("Average test accuracy for all datasets: {:.2f}%".format(average_test_accuracy_all))

Shanka123 commented 4 months ago

The code snippet looks correct, but did you set img_size to 80?

xiaohuahuaw commented 4 months ago

Thank you very much for your response!

After setting img_size to 80, the accuracy improved to the level I expected. Could you explain why img_size has such a large impact on accuracy?

Also, I noticed that the keys in the dataset you provided differ from those in the dataset produced by the I-RAVEN generation code I used. Could you help me resolve this?

Thank you!

Shanka123 commented 4 months ago

The model was trained with img_size set to 80, so using any other value will hurt accuracy. I am not sure what you mean by inconsistent keys in the dataset. Could you explain the issue in more detail?

xiaohuahuaw commented 4 months ago

If I train the model on a dataset generated with the official I-RAVEN code, it throws errors because the keys differ.

Keys in your dataset: ['image', 'mask', 'obj_class', 'brightness', 'size', 'LR', 'UD', 'pos', 'N_pos', 'targ', 'format', 'any_arithmetic']

Keys in the dataset generated with the official code: ['target', 'predict', 'image', 'meta_matrix', 'meta_structure', 'meta_target', 'structure']

Shanka123 commented 4 months ago

The only keys needed for training the model are "image" and "target" (stored as "targ" in the dataset I shared). To train on a dataset generated by the official code, you can drop the "mask" key, which the model does not use: comment out line 74 of train_slot_transformer_raven.py and every reference to it in the RAVENdataset class. Also replace data['targ'] with data['target'] on line 73.
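If you would rather have one loader that accepts both key layouts, here is a minimal sketch of a hypothetical helper (get_image_and_target is not part of the STSN code) that could stand in for the direct data['targ'] lookup. It assumes each problem is loaded with numpy.load (or is any other dict-like object keyed as in the two key lists above). It only handles the label key; the "mask" references in RAVENdataset still need to be removed as described above.

```python
def get_image_and_target(data):
    """Return (image, target) from one loaded RAVEN / I-RAVEN problem.

    Hypothetical helper: accepts the label under either "target" (official
    I-RAVEN generator) or "targ" (the dataset shared in this issue).
    All other keys, including "mask", are ignored.
    """
    if "target" in data:
        target = data["target"]
    elif "targ" in data:
        target = data["targ"]
    else:
        raise KeyError(
            "expected a 'target' or 'targ' key, found: {}".format(list(data.keys()))
        )
    return data["image"], target
```

For example, inside the dataset's __getitem__ one could write `with np.load(path) as data: image, target = get_image_and_target(data)`, where path is the file path the loader already builds.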

xiaohuahuaw commented 1 month ago

I recently noticed that the results vary from run to run with this code, ranging from 95.14% to 95.85%. Is this normal, and what could be causing it?

Shanka123 commented 1 month ago

I think this variability is normal. It is probably due to the slots being randomly sampled from a learned distribution at inference time, before slot attention is applied.
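To illustrate why two runs can differ, here is a minimal, framework-free sketch of sampling initial slots from a learned distribution. It assumes a diagonal Gaussian with a learned mean mu and standard deviation sigma, as in the original Slot Attention formulation; the function name and numbers are hypothetical and not taken from the STSN code. Each call draws fresh noise, so unseeded passes start slot attention from different slots, while re-creating the generator with the same seed reproduces the draw.

```python
import random


def sample_initial_slots(mu, sigma, num_slots, rng):
    """Draw num_slots initial slots as mu + sigma * eps, with eps ~ N(0, 1).

    mu and sigma stand in for the learned per-dimension mean and standard
    deviation (hypothetical values here); rng supplies the noise.
    """
    return [
        [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
        for _ in range(num_slots)
    ]


mu = [0.1, -0.2, 0.3]     # hypothetical learned means (slot dimension 3)
sigma = [0.5, 0.5, 0.5]   # hypothetical learned standard deviations

# Unseeded generator: two inference passes start from different slots.
rng = random.Random()
first = sample_initial_slots(mu, sigma, num_slots=4, rng=rng)
second = sample_initial_slots(mu, sigma, num_slots=4, rng=rng)

# Same seed: the initial slots are reproduced exactly.
a = sample_initial_slots(mu, sigma, 4, random.Random(0))
b = sample_initial_slots(mu, sigma, 4, random.Random(0))
```

Because slot attention starts from these sampled slots, different draws can lead to slightly different slot representations, which can flip the prediction on a few borderline problems.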

xiaohuahuaw commented 3 weeks ago

Did you report the average over several runs in the paper? If I fix the slot initialization, will this variation go away?

Shanka123 commented 3 weeks ago

I evaluated only once, using the checkpoint with the best validation accuracy. You can try fixing the slot initialization and check whether the results become consistent; however, depending on the initialization, the accuracy might be slightly higher or lower.
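Since a fixed seed only selects one particular initialization, another option is to evaluate under several seeds and report the mean and spread. A minimal sketch, assuming a hypothetical evaluate_once(seed) callable that wraps the test loop from the snippet at the top of this issue and returns accuracy in percent:

```python
import statistics


def summarize_over_seeds(evaluate_once, seeds):
    """Run evaluate_once(seed) for each seed; return (mean, stdev, accuracies).

    evaluate_once is a hypothetical callable: it should seed the RNG used
    for slot initialization, run the full test loop, and return accuracy.
    """
    accuracies = [evaluate_once(seed) for seed in seeds]
    mean = statistics.mean(accuracies)
    # Sample standard deviation needs at least two runs; report 0.0 otherwise.
    stdev = statistics.stdev(accuracies) if len(accuracies) > 1 else 0.0
    return mean, stdev, accuracies
```

Inside evaluate_once, calling torch.manual_seed(seed) before the loop would fix the slot noise, assuming the model draws it from PyTorch's default generator. summarize_over_seeds(evaluate_once, range(5)) then returns the mean, the standard deviation and the per-seed accuracies.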