YijinHuang / SSiT

SSiT: Saliency-guided Self-supervised Image Transformer for Diabetic Retinopathy Grading

Attention visualization #14

Closed: Wadha-Almattar closed this issue 3 months ago

Wadha-Almattar commented 4 months ago

Hello, I ran into a couple of errors when running attn_visualize.py. Below are the errors and how I fixed them.

Issue 1: Lines 115 and 116:

115 checkpoint = torch.load(args.checkpoint)
116 load_checkpoint(model, checkpoint, checkpoint_key, linear_key)

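Fix: delete line 115 and pass args.checkpoint (the checkpoint path) to load_checkpoint. load_checkpoint already loads the file itself, so calling torch.load beforehand hands it an already-loaded dictionary instead of a path. The corrected line 116 reads: load_checkpoint(model, args.checkpoint, checkpoint_key, linear_key)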
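As a sketch of a more forgiving alternative, the helper below (the name resolve_checkpoint is hypothetical; SSiT's real load_checkpoint may work differently) accepts either a checkpoint path or a dictionary that was already loaded, so either version of lines 115 and 116 would work. It assumes the weights may be nested under checkpoint_key. The loader argument is an illustrative addition that lets the helper run without torch installed.

```python
import os
from collections.abc import Mapping
from typing import Any, Callable, Optional, Union


def resolve_checkpoint(
    checkpoint: Union[str, os.PathLike, Mapping],
    checkpoint_key: Optional[str] = None,
    loader: Optional[Callable[[Any], Mapping]] = None,
) -> Mapping:
    """Return the (possibly nested) weights from a path or an already-loaded dict.

    Hypothetical helper for illustration, not SSiT's actual load_checkpoint.
    """
    if isinstance(checkpoint, Mapping):
        # Already loaded, e.g. by an earlier torch.load call: do not load again.
        ckpt = checkpoint
    else:
        if loader is None:
            import torch  # imported lazily so the helper runs without torch when a loader is given

            def torch_loader(path):
                # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines
                return torch.load(path, map_location="cpu")

            loader = torch_loader
        ckpt = loader(checkpoint)
    # Unwrap weights stored under checkpoint_key, if that key is present.
    if checkpoint_key is not None and checkpoint_key in ckpt:
        ckpt = ckpt[checkpoint_key]
    return ckpt
```

In attn_visualize.py this would amount to something like resolve_checkpoint(args.checkpoint, checkpoint_key) before applying the weights to the model. The sketch does not replicate any key-prefix cleanup or linear_key handling the repository's function may perform.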

Issue 2: From line 126:

126 transform = pth_transforms.Compose([
        pth_transforms.Resize((args.image_size, args.image_size)),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize(mean, std),
    ])

For some reason, args.image_size arrives as a one-element list rather than an integer, so Resize((args.image_size, args.image_size)) receives lists instead of integers. The following check, added before line 126, validates the type and converts the value:

if isinstance(args.image_size, list):
    if len(args.image_size) == 1:
        args.image_size = args.image_size[0]
    elif len(args.image_size) == 2:
        # Note: with an (h, w) tuple, line 126 must call
        # pth_transforms.Resize(args.image_size) instead of
        # pth_transforms.Resize((args.image_size, args.image_size)).
        args.image_size = tuple(args.image_size)
    else:
        raise ValueError("image_size list must have one or two elements")

# Ensure correct types
assert isinstance(args.image_size, (int, tuple)), "image_size must be an integer or a tuple of integers"
assert all(isinstance(m, float) for m in mean), "mean values must be floats"
assert all(isinstance(s, float) for s in std), "std values must be floats"
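One plausible cause is that the option is declared with nargs, which makes argparse always return a list. This is an assumption, since the script's argument parser is not shown in this issue. The sketch below reproduces that behavior with a hypothetical --image_size flag and default. It then normalizes the value into an (h, w) tuple with a hypothetical normalize_image_size helper. torchvision's Resize accepts an (h, w) sequence directly, which also avoids passing a tuple of tuples when two sizes are given.

```python
import argparse
from typing import Sequence, Tuple, Union


def normalize_image_size(value: Union[int, Sequence[int]]) -> Tuple[int, int]:
    """Turn an int, [n], or [h, w] into an (h, w) tuple suitable for Resize."""
    if isinstance(value, int):
        size = (value, value)
    elif isinstance(value, (list, tuple)) and len(value) == 1:
        size = (value[0], value[0])
    elif isinstance(value, (list, tuple)) and len(value) == 2:
        size = (value[0], value[1])
    else:
        raise ValueError(f"image_size must be an int or a list of one or two ints, got {value!r}")
    # Reject bools, non-integers, and non-positive sizes.
    if not all(isinstance(s, int) and not isinstance(s, bool) and s > 0 for s in size):
        raise ValueError(f"image_size entries must be positive ints, got {value!r}")
    return size


if __name__ == "__main__":
    # Hypothetical flag: declaring it with nargs="+" makes argparse return a list.
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, nargs="+", default=[224])
    args = parser.parse_args(["--image_size", "512"])
    print(args.image_size)                        # [512]
    print(normalize_image_size(args.image_size))  # (512, 512)
```

Line 126 would then use pth_transforms.Resize(normalize_image_size(args.image_size)), which handles both the one-value and the two-value case.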
YijinHuang commented 3 months ago

Thank you for bringing these issues to our attention. We have updated the corresponding code accordingly.