fudan-zvg / SETR

[CVPR 2021] Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers
MIT License

Questions about feature visualization (Fig. 5 & Fig. 9) #32

Closed BebDong closed 2 years ago

BebDong commented 3 years ago

Given the ViT-Large encoder with 16x16 patches and an input size of 3x480x480, the output feature map of any layer Z should be 1024x30x30 after reshaping. How are these 1024 channels mapped to RGB space for visualization? Are the feature maps upsampled directly to the original image size during visualization?

I couldn't find any related code in this repo.
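
The shape arithmetic in the question can be checked with a short sketch. It assumes the layer output is a sequence of 900 patch tokens of width 1024 with any class token already removed; `tokens` is a random stand-in, not real SETR output.

```python
import numpy as np

patch, embed_dim = 16, 1024      # ViT-Large with 16x16 patches, 1024-dim tokens
height, width = 480, 480         # input image size from the question

grid_h, grid_w = height // patch, width // patch   # 30 x 30 patch grid
num_tokens = grid_h * grid_w                        # 900 tokens

# Hypothetical stand-in for a layer output Z of shape (num_tokens, embed_dim).
tokens = np.random.default_rng(0).standard_normal((num_tokens, embed_dim))

# (L, C) -> (C, H/16, W/16): put channels first, then unfold the token axis
# row by row into the 2D patch grid.
feature_map = tokens.T.reshape(embed_dim, grid_h, grid_w)
print(feature_map.shape)  # (1024, 30, 30)
```

Token `k` lands at grid position `(k // 30, k % 30)`, so each spatial location of the reshaped map holds one 1024-dim token.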

sixiaozheng commented 2 years ago

We use PCA to compress the 1024-dim feature map to 3 dimensions. Part of the code is given below:

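import os

import matplotlib.pyplot as plt
import numpy as np
import sklearn.decomposition as dec  # assumed source of `dec.PCA`
import torch


def test(dataloader, model, log):
    # `args` (providing `args.savepath`) is expected to be defined at module level.
    model.eval()

    outputs = []  # one (H*W, C) feature matrix per image
    indices = []  # (sequence index, frame index) for each feature map
    HWs = []      # spatial size (H, W) of each feature map
    for b_i, (images_rgb, annotations) in enumerate(dataloader):
        log.info("Start generating pca for sequence {}.".format(b_i))
        images_rgb = [r.cuda() for r in images_rgb]

        for i, rgb in enumerate(images_rgb):
            with torch.no_grad():
                output = model.module.feature_extraction(rgb)
            B, C, H, W = output.size()  # the reshape below assumes B == 1
            # (1, C, H, W) -> (H*W, C): one C-dim sample per spatial location
            output = output.cpu().numpy().reshape(C, H * W).T
            outputs.append(output)
            indices.append((b_i, i))
            HWs.append((H, W))

    N = len(outputs)
    log.info("Total %d feature maps. Generating PCA images." % N)
    # Fit a single PCA over all locations of all images, so that colors
    # are comparable across the saved images.
    pca = dec.PCA(3)
    outputs = np.concatenate(outputs, 0)
    outputs_pca = pca.fit_transform(outputs)
    # Jointly min-max normalize to [0, 1] so the 3 components can be shown as RGB.
    outputs_pca = (outputs_pca - outputs_pca.min()) / (outputs_pca.max() - outputs_pca.min())

    for i in range(N):
        H, W = HWs[i]
        # Take the first H*W rows as the current image, then drop them.
        image = outputs_pca[:H * W, :].reshape(H, W, 3)
        save_dir = args.savepath + '/S{:02d}'.format(indices[i][0])
        os.makedirs(save_dir, exist_ok=True)
        # Explicit extension so the image format is unambiguous.
        plt.imsave(save_dir + "/F{:03d}.png".format(indices[i][1]), image)
        outputs_pca = outputs_pca[H * W:]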
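
The snippet above depends on the repository's model and data loader, so here is a minimal self-contained sketch of the same idea on a random 1024x30x30 array. It is not the authors' code: it uses a plain NumPy SVD in place of scikit-learn's PCA, fits on a single feature map rather than on all maps jointly, and the helper names `pca_to_rgb` and `upsample_nearest` are made up for illustration. The thread does not say whether the figures were upsampled, so the nearest-neighbour upsampling step is an assumption included only to show one way of reaching the 480x480 input size.

```python
import numpy as np


def pca_to_rgb(feature_map):
    """Project a (C, H, W) feature map onto its top 3 principal components.

    Returns an (H, W, 3) array min-max scaled to [0, 1].
    """
    c, h, w = feature_map.shape
    x = feature_map.reshape(c, h * w).T           # (H*W, C), one row per location
    x = x - x.mean(axis=0, keepdims=True)         # PCA operates on centered data
    # Rows of vt are the principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    proj = x @ vt[:3].T                           # (H*W, 3)
    proj = (proj - proj.min()) / (proj.max() - proj.min())
    return proj.reshape(h, w, 3)


def upsample_nearest(image, factor):
    """Nearest-neighbour upsampling of an (H, W, 3) image by an integer factor."""
    return image.repeat(factor, axis=0).repeat(factor, axis=1)


# Hypothetical stand-in for a ViT-Large feature map (random values).
features = np.random.default_rng(0).standard_normal((1024, 30, 30))

rgb_small = pca_to_rgb(features)            # 30 x 30 RGB image
rgb_full = upsample_nearest(rgb_small, 16)  # 480 x 480, matching the input size
print(rgb_small.shape, rgb_full.shape)      # (30, 30, 3) (480, 480, 3)
```

Fitting per image, as here, lets colors vary freely between images; fitting once over all feature maps, as in the snippet above, keeps the color coding consistent across them.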