allenai / satlas-super-resolution

Apache License 2.0
MuS2 result #17

Closed Laymanpython closed 4 months ago

Laymanpython commented 5 months ago

Hi, Wolters. Did you try running the code on the MuS2 dataset? I found that when I use the RGB bands, the results are very bad.

Laymanpython commented 5 months ago

[Two images attached]

You can see that the first figure is the output and the second one is the ground truth.

I concatenated the RGB channels for both LR and HR, and wrote the dataset class modeled on the probav dataset.

I also calculated the mean and std of MuS2. This is my code:

import os
import cv2
import glob
import torch
import random
import torchvision
import numpy as np
from torch.utils import data as data
from osgeo import gdal
from basicsr.utils.registry import DATASET_REGISTRY

gdal.PushErrorHandler('CPLQuietErrorHandler')

# Per-channel statistics computed on MuS2 (HR and LR have different value ranges).
HR_MEAN = np.array([24.55151043, 31.51246904, 20.59156773])
HR_STD = np.array([4.54098543, 7.59749875, 7.04354197])
LR_MEAN = np.array([439.92772681, 628.69735141, 623.3446229])
LR_STD = np.array([211.87138845, 242.33688323, 322.37929056])


@DATASET_REGISTRY.register()
class MUS2RGBDataset(data.Dataset):
    """
    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            data_root, n_lr_images, use_3d (optional), phase (optional)
    """

    def __init__(self, opt):
        super(MUS2RGBDataset, self).__init__()
        self.opt = opt
        # 'phase' may be missing from the config; fall back to None.
        self.split = opt['phase'] if 'phase' in opt else None
        self.data_root = opt['data_root']

        self.n_lr_images = opt['n_lr_images']
        self.use_3d = opt['use_3d'] if 'use_3d' in opt else False

        # Pick the HR filepaths that belong to the requested split.
        if self.split not in ('train', 'val', 'test'):
            raise ValueError("Invalid split entered: %s" % self.split)
        hr_fps = glob.glob(self.data_root + self.split + '/*/hr_resized/hr_rgb.tiff')

        # Pair every HR file with its LR revisits: LR000.tiff, LR001.tiff, ...
        self.datapoints = []
        for hr_fp in hr_fps:
            lrs = [
                hr_fp.replace('hr_resized/hr_rgb.tiff', 'lr_resized/LR%03d.tiff' % i)
                for i in range(self.n_lr_images)
            ]
            self.datapoints.append([hr_fp, lrs])

        self.data_len = len(self.datapoints)
        print("Loaded ", self.data_len, " data pairs for split ", self.split)

    def __getitem__(self, index):
        hr_path, lr_paths = self.datapoints[index]

        # gdal returns the image as (C, H, W); the stats broadcast over H and W.
        hr_im = gdal.Open(hr_path).ReadAsArray() * 1.0
        # Standardize per channel, then multiply by 255 (the model divides by 255).
        hr_im = (hr_im - HR_MEAN[:, None, None]) / HR_STD[:, None, None] * 255.0
        hr_tensor = torch.tensor(hr_im)

        lr_ims = []
        for lr_path in lr_paths:
            lr_im = gdal.Open(lr_path).ReadAsArray()
            lr_im = (lr_im - LR_MEAN[:, None, None]) / LR_STD[:, None, None] * 255.0
            lr_ims.append(torch.tensor(lr_im))

        if self.use_3d:
            # (n_lr_images, C, H, W)
            img_LR = torch.stack(lr_ims)
        else:
            # (n_lr_images * C, H, W)
            img_LR = torch.cat(lr_ims)

        img_HR = hr_tensor

        return {'hr': img_HR, 'lr': img_LR, 'Index': index}

    def __len__(self):
        return self.data_len

Laymanpython commented 5 months ago

One more thing: I use hr_im = (hr_im - hr_mean) / hr_std * 255.0 because model.py divides hr_im and lr_im by 255.0, so I multiply by 255 in the dataset.
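To make the effect of this normalization concrete, here is a minimal sketch on synthetic data (not real MuS2 patches). It computes per-channel statistics, applies the same standardize-then-multiply-by-255 step as the dataset code, and inverts it. The helper names `channel_stats`, `normalize`, and `denormalize` are hypothetical, and the division by 255 stands in for what the comment above says model.py does.

```python
import numpy as np


def channel_stats(images):
    """Per-channel mean and std over a list of (C, H, W) arrays."""
    flat = np.concatenate([im.reshape(im.shape[0], -1) for im in images], axis=1)
    return flat.mean(axis=1), flat.std(axis=1)


def normalize(im, mean, std):
    # Same operation as the dataset code: standardize, then scale by 255.
    return (im - mean[:, None, None]) / std[:, None, None] * 255.0


def denormalize(im, mean, std):
    # Inverse of normalize(); apply it to a prediction before viewing or saving it.
    return im / 255.0 * std[:, None, None] + mean[:, None, None]


rng = np.random.default_rng(0)
# Synthetic stand-in for 3-band patches with a 16-bit-like value range.
patches = [rng.uniform(0, 1500, size=(3, 8, 8)) for _ in range(4)]
mean, std = channel_stats(patches)

x = normalize(patches[0], mean, std)
model_input = x / 255.0  # assumed division inside the model
print("input range:", model_input.min(), model_input.max())

restored = denormalize(x, mean, std)
print("round trip ok:", np.allclose(restored, patches[0]))
```

After the division, the network sees zero-centred values rather than values in [0, 1], so an output that is saved without the inverse step will not show the original colors. Whether that explains the results in this thread is not established here.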

piperwolters commented 5 months ago

Hello, I have not tried running on MuS2. MuS2 is a bit weird since it is just a validation set, and the imagery differs from any of the datasets I trained models on.

Are you fine-tuning or training a model on MuS2 or using a pretrained model to run inference on it?

Laymanpython commented 5 months ago

Yeah, I split the HR images into 360×360-pixel patches and the LR images into 120×120-pixel patches. Then I split them into train, val, and test sets with a ratio of 0.8/0.1/0.1.

I also gathered the RGB bands into a single file to make reading easier.

Then I trained ESRGAN, HighResNet, SwinIR, and other models on MuS2 without any pretrained model, and I found that the outputs lose their original color.
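The patching and 0.8/0.1/0.1 split described in this comment could be sketched roughly as follows. This is an illustration on empty arrays, not the code that was actually used; `tile` and `split_indices` are hypothetical helpers, and the array sizes are chosen only so that the 3× scale factor between LR and HR lines up.

```python
import random

import numpy as np


def tile(image, size):
    """Cut a (C, H, W) array into non-overlapping size x size patches.

    Rows and columns that do not fill a whole patch are dropped.
    """
    _, h, w = image.shape
    return [
        image[:, r:r + size, c:c + size]
        for r in range(0, h - size + 1, size)
        for c in range(0, w - size + 1, size)
    ]


def split_indices(n, ratios=(0.8, 0.1, 0.1), seed=0):
    """Shuffle range(n) and cut it into train / val / test index lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = int(n * ratios[0])
    n_val = int(n * ratios[1])
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


# Placeholder scene: HR is 3x the LR resolution, as with 360 vs 120 pixel patches.
hr = np.zeros((3, 1800, 1800), dtype=np.uint8)
lr = np.zeros((3, 600, 600), dtype=np.uint16)

hr_patches = tile(hr, 360)
lr_patches = tile(lr, 120)  # same row-major order, so index i pairs with hr_patches[i]

train, val, test = split_indices(len(hr_patches))
print(len(hr_patches), len(train), len(val), len(test))
```

Splitting at the scene level rather than the patch level would keep neighbouring patches of one scene from landing in both train and test; the thread does not say which was done.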

Laymanpython commented 5 months ago

My intuition is that this problem is related to the LR bit depth. When I use OpenCV to read the files, I get a slightly better result.

[Two images attached]

But the result still contains fake textures.
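The bit-depth hunch can be checked without a model. The LR means posted above (roughly 440 to 630) lie outside the 8-bit range, while the HR means (roughly 20 to 32) fit inside it, so the two inputs are on different scales. The sketch below uses synthetic data; `to_unit_range` and the `max_value` of 3000 are illustrative assumptions, not values from MuS2 or from this repository.

```python
import numpy as np


def to_unit_range(image, max_value):
    """Scale all bands by one shared factor and clip to [0, 1].

    A single shared factor keeps the brightness ratios between bands,
    which a separate stretch per band would change.
    """
    return np.clip(image.astype(np.float32) / max_value, 0.0, 1.0)


rng = np.random.default_rng(0)
# Synthetic 16-bit LR patch; real data would come from the TIFF files.
lr = rng.integers(0, 3000, size=(3, 120, 120), dtype=np.uint16)

# Inspect dtype and range first: this shows whether a reader silently changed them.
print(lr.dtype, int(lr.min()), int(lr.max()))

lr_unit = to_unit_range(lr, max_value=3000.0)  # 3000.0 is an assumed ceiling
lr_8bit = (lr_unit * 255.0).round().astype(np.uint8)
print(lr_8bit.dtype, int(lr_8bit.min()), int(lr_8bit.max()))
```

With its default flags, `cv2.imread` returns an 8-bit image, whereas `cv2.IMREAD_UNCHANGED` keeps the stored depth, so printing the dtype and range after each reader is a quick way to see what the network is actually given.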

Laymanpython commented 4 months ago

Also, the results have to be opened with ENVI. I don't know why the cropped images are so green.

piperwolters commented 4 months ago

I would suggest saving the files with torchvision, as that is the way images are loaded in the dataset. There can be weird artifacts when switching between torchvision, skimage, cv2, etc. MuS2 is quite small, so I'm not convinced that training on a subset of MuS2 will lead to good generalization. Is it able to overfit on the training set?

piperwolters commented 4 months ago

Also, I believe cv2 uses a different channel ordering (BGR) than torchvision (RGB)?
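For reference, `cv2.imread` returns channels in BGR order, so an array read that way needs its channel axis reversed before it is mixed with RGB data. A minimal NumPy-only sketch (the helper name `bgr_to_rgb` is hypothetical):

```python
import numpy as np


def bgr_to_rgb(image_hwc):
    """Reverse the channel axis of an (H, W, 3) array, turning BGR into RGB."""
    return image_hwc[..., ::-1].copy()


# A pure-blue image as cv2 would lay it out: channel 0 holds blue.
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255

rgb = bgr_to_rgb(bgr)
print(rgb[0, 0])  # blue now sits in the last channel
```

This only matters for the experiment that read files with OpenCV. The dataset class posted above reads with gdal, which returns bands in the order they are stored in the file.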