allenai / satlas-super-resolution


Inference #18

Closed: aiden200 closed this issue 4 months ago

aiden200 commented 4 months ago

Hi! I'm trying to perform inference on Sentinel-2 imagery. I am using the esrgan_8S2.pth weights and the basic infer_example.yml as my config file. I'm obtaining my imagery from Sentinel Hub, and I've combined the images into a [number_sentinel2_images * 32, 32, 3] array; number_sentinel2_images in my case is 8. I'm getting some very strange results.

Input: (image: output_lr)

LR: (image: lr)

Results/HR: (image: sr)

This is my first GitHub issue. Let me know if you need any additional information from me!

piperwolters commented 4 months ago

Hi, thanks for your interest in the project! I responded to your email with a similar question, but make sure you are running inference on 32x32 chunks of the larger Sentinel-2 images, rather than resizing the whole images down to 32x32. The model does not expect resized inputs.

aiden200 commented 4 months ago

Hi Piper, thanks for the response. I got it working, so I will close this issue. Thank you very much for your support! (result image: high_res_480)

AdnanA21 commented 3 months ago

Hi Aiden. Can you please share how you made it work? I am facing the same issue. My super-resolution output looks like this: (image: sr)

aiden200 commented 3 months ago

Hi, this is what Piper (the code owner) sent me:

You will want to run inference over 32x32 pixel chunks of the resulting image, rather than resizing the whole image to 32x32.

So if you have 8 images from SentinelHub where each image is 1024x1024 pixels, do something like:

# your_images = tensor of shape [8, 3, 1024, 1024]
for x in range(0, 1024, 32):
    for y in range(0, 1024, 32):
        x_y_output = run_inference(your_images[:, :, x:x+32, y:y+32])

Following her method, I wrote my code to chunk the images into 8x32x32xC chunks (assuming you have the 8-image pretrained model) and stitch them back together at the end:

import numpy as np
import skimage.io
from PIL import Image

def combine_images_to_png(image_paths, output_path):
    # One entry per 32x32 grid cell: [stack of shape (8, 32, 32, C), (row, col)]
    combined_image = []

    for i, img_path in enumerate(image_paths):
        img = Image.open(img_path)
        img_array = np.array(img, dtype=np.uint8)
        H, W, C = img_array.shape

        # Center-crop to the largest square whose side is a multiple of 32
        sq_l = min(H, W)
        sq_l_32_m = sq_l - (sq_l % 32)
        margin = (sq_l - sq_l_32_m) // 2
        cropped_img = img.crop((margin, margin, margin + sq_l_32_m, margin + sq_l_32_m))
        img_array = np.array(cropped_img, dtype=np.uint8)

        # All input images are assumed to crop to the same size
        L = img_array.shape[0]
        process_count = L // 32
        if len(combined_image) == 0:
            for _ in range(process_count ** 2):
                combined_image.append([np.zeros((8, 32, 32, C), dtype=np.uint8), (0, 0)])

        # Slot this image's 32x32 chunks into the per-cell stacks
        for j in range(process_count):
            for k in range(process_count):
                cell = combined_image[j * process_count + k]
                cell[0][i] = img_array[j * 32:(j + 1) * 32, k * 32:(k + 1) * 32, :]
                cell[1] = (j, k)

    # Stack each cell's 8 chunks vertically into the (8*32, 32, C) layout the model expects
    for stack, (j, k) in combined_image:
        stacked = np.reshape(stack, (8 * 32, 32, C))
        skimage.io.imsave(f"{output_path}_{j}-{k}.png", stacked)
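For illustration, here is a hypothetical call (the paths are made up, assuming eight co-registered Sentinel-2 PNGs of the same tile):

# Hypothetical paths: eight aligned Sentinel-2 captures of the same tile
image_paths = [f"tile/sentinel2_{t}.png" for t in range(8)]

# Writes one (8*32)x32 stacked PNG per 32x32 grid cell, named like "chunks/tile_0-0.png"
combine_images_to_png(image_paths, "chunks/tile")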

Stitching back:

import math
import os

import imageio
import numpy as np
import skimage.io

def stitch_back(new_dir, output_dir):
    sr_images = []  # [super-resolved chunk, directory name encoding its grid position]
    directories = next(os.walk(new_dir))[1]

    # Collect the super-resolved chunk from each inference output directory
    for directory in directories:
        file_path = os.path.join(new_dir, directory, 'sr.png')
        if os.path.isfile(file_path):
            img_array = imageio.imread(file_path)
            sr_images.append([img_array, directory])

    if len(sr_images) == 0:
        print("No images found.")
        return

    # The chunks must tile a square grid
    image_count = int(math.sqrt(len(sr_images)))
    if image_count ** 2 != len(sr_images):
        print("Error: The number of images does not form a perfect square.")
        return

    len_portion = sr_images[0][0].shape[0]
    final_img = np.zeros((len_portion * image_count, len_portion * image_count, 3), dtype=np.uint8)

    for block, name in sr_images:
        # Directory names are assumed to be the grid position, e.g. "0-1"
        j, k = (int(v) for v in name.split("-"))
        start_row, start_col = j * len_portion, k * len_portion

        # Place the block in the correct position in the reconstructed image
        final_img[start_row:start_row + len_portion, start_col:start_col + len_portion, :] = block

    skimage.io.imsave(output_dir, final_img)
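And a matching hypothetical call, assuming the model wrote each chunk's result to inference_outputs/<j>-<k>/sr.png (this layout is illustrative; adjust to however your inference run names its output directories):

# Hypothetical layout: one inference output directory per chunk, named "j-k"
stitch_back("inference_outputs/", "stitched_sr.png")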

You will need to modify my code to fit your purpose. Hopefully this helps!

AdnanA21 commented 3 months ago

Thanks for your quick response. I have another query: how did you normalize the satellite data?

In my case, I downloaded raw Sentinel-2 L1C data, extracted the RGB image (band 2, band 3, band 4) from the raw data, normalized by dividing by 10000, and then tested with the given model weights. But I do not get proper output; for example, houses are sometimes rendered as fields or something else entirely.

Can you share how you normalized the data to get correct output?

aiden200 commented 3 months ago

The inference code normalizes the image, so no extra steps are needed if you downloaded the images following the instructions.

simon-donike commented 3 months ago

> Can you share how you normalized the data to get correct output?

In my case, I used the raw L1C bands and recreated the TCI properties like this: ((img - 1000) / 3558) * 255.

  • minus 1000 to account for the change in dynamic range since 2022
  • /3558 is the factor used for the L1C TCI composition; this brings you to the 0..1 range. /10000 will not do the trick, since the model expects the TCI composition. See the Sen2 TCI info.
  • *255 to get to uint8 and save as .png
  • Then the dataloader just does /255, no further normalization

With this preprocessing I got okay results.
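For anyone reproducing this, here is a minimal sketch of that normalization (my own illustration, not code from this thread), assuming `raw` holds L1C digital numbers for B4, B3, B2 stacked as an (H, W, 3) array; the clipping step is an added safeguard, not part of the formula above:

import numpy as np

def l1c_to_tci_uint8(raw: np.ndarray) -> np.ndarray:
    # ((img - 1000) / 3558) * 255: offset for the post-2022 dynamic range,
    # then the L1C TCI scale factor, then conversion to uint8 for .png output
    img = (raw.astype(np.float32) - 1000.0) / 3558.0
    img = np.clip(img, 0.0, 1.0)  # added to guard against out-of-range values
    return (img * 255.0).astype(np.uint8)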

AdnanA21 commented 3 months ago

What did you mean by raw L1C bands? Did you use only the TCI band, or (B2, B3, B4)? Can you share the data processing steps before and after normalization?

simon-donike commented 3 months ago

> What did you mean by raw L1C bands? Did you use only the TCI band, or (B2, B3, B4)?

The same thing you mean by raw data: using the bands directly instead of the TCI. And the steps I described are what's needed to bring the image into the same normalization domain as the TCI, ergo needed for the inference.