allenai / satlas-super-resolution

Apache License 2.0

Inference on Single Image #36

Closed: lokeshsk closed this issue 2 months ago

lokeshsk commented 3 months ago

Hi Piper, it's great to see such amazing work made publicly available. When I run inference on the provided test_set, it works perfectly. However, when I download a single Sentinel-2 L1C image and run inference on it using the weights provided for a single image, I get: Running inference on 0 images.

infer_example.yml.txt

piperwolters commented 3 months ago

Hello! Thanks for your interest in our work. If the code is not finding any images, that likely means your filepath structure is not the same as the provided test set. Could you add some print statements to your code so we can see what your filepath is versus what the code expects?
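
One way to add those print statements is sketched below: print the full search pattern, what it matches, and, if nothing matches, what is actually on disk. The function name, the directory and the "*/*.png" pattern are placeholders (assumptions), not the repository's actual glob; substitute whatever your inference script searches for.

```python
import glob
import os


def report_matches(data_dir, pattern):
    """Print a glob pattern, the files it matches, and the directory contents if nothing matched.

    data_dir and pattern are hypothetical placeholders: set them to whatever
    your inference script actually globs for.
    """
    full_pattern = os.path.join(data_dir, pattern)
    matches = sorted(glob.glob(full_pattern))
    print(f"Pattern: {full_pattern}")
    print(f"Matched {len(matches)} file(s)")
    for path in matches[:10]:
        print(f"  {path}")
    if not matches and os.path.isdir(data_dir):
        # Show what is actually present so the mismatch with the pattern is visible
        print("Directory contents:")
        for root, _dirs, files in os.walk(data_dir):
            for name in sorted(files):
                print(f"  {os.path.join(root, name)}")
    return matches


if __name__ == "__main__":
    # Hypothetical path and pattern; replace with your own
    report_matches("./my_s2_data", "*/*.png")
```
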

lokeshsk commented 3 months ago

Hi, please find the screenshot of the output below. Are there any other pre-processing steps I may have missed that could cause this issue?

Screenshot from 2024-07-06 09-41-59

My input file after processing it as described in https://github.com/allenai/satlas-super-resolution#how-to-process-raw-sentinel-2-data: delhi_test_rep

However, since that section is pseudocode, I modified it; I am attaching my version here for reference:

  import numpy as np
  import rasterio
  from rasterio.crs import CRS
  from rasterio.warp import Resampling, calculate_default_transform, reproject

  # Path to the Sentinel-2 TCI (true-color) JP2 file
  tci_jp2_path = './delhi_test.jp2'

  # Target CRS (Web-Mercator) and pixel size in projected metres
  dst_crs = CRS.from_epsg(3857)
  dst_resolution = (9.555, 9.555)

  with rasterio.open(tci_jp2_path) as src:
      # Compute the output grid in the target CRS. Reusing src.transform (UTM)
      # as the destination transform would place pixels at the wrong coordinates.
      dst_transform, dst_width, dst_height = calculate_default_transform(
          src.crs, dst_crs, src.width, src.height, *src.bounds,
          resolution=dst_resolution,
      )

      # Output metadata: new CRS, transform and size; written as GeoTIFF
      # to avoid lossy JPEG2000 re-encoding
      meta = src.meta.copy()
      meta.update({
          'driver': 'GTiff',
          'crs': dst_crs,
          'transform': dst_transform,
          'width': dst_width,
          'height': dst_height,
      })

      # Reproject each band into the new grid
      img_rep = np.zeros((src.count, dst_height, dst_width), dtype=src.dtypes[0])
      for band_idx in range(src.count):
          reproject(
              source=rasterio.band(src, band_idx + 1),
              destination=img_rep[band_idx],
              src_transform=src.transform,
              src_crs=src.crs,
              dst_transform=dst_transform,
              dst_crs=dst_crs,
              resampling=Resampling.bilinear,
          )

  # Write the reprojected image
  output_path = tci_jp2_path.replace('.jp2', '_rep.tif')
  with rasterio.open(output_path, 'w', **meta) as dst:
      dst.write(img_rep)

  print(f"Reprojected image saved: {output_path}")

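
For reference, the 9.555 value in dst_resolution is, to three decimals, the pixel size in EPSG:3857 projected metres of a 256-pixel Web-Mercator tile grid at zoom level 14 (equivalently, a 512-pixel grid at zoom 13). This is a property of the projection, not a statement about how the repository chooses its grid. Because Mercator stretches distances by 1/cos(latitude), the true ground distance per pixel is smaller away from the equator. A small check of the arithmetic:

```python
import math

# WGS84 semi-major axis in metres (the sphere radius used by EPSG:3857)
EARTH_RADIUS_M = 6378137.0
# Width of the whole Web-Mercator plane in projected metres
WORLD_WIDTH_M = 2 * math.pi * EARTH_RADIUS_M


def web_mercator_pixel_size(zoom, tile_size=256):
    """Pixel size (EPSG:3857 metres) of a square tile grid at the given zoom level."""
    return WORLD_WIDTH_M / (tile_size * 2 ** zoom)


def ground_pixel_size(zoom, latitude_deg, tile_size=256):
    """Approximate true ground distance covered by one pixel at a given latitude."""
    return web_mercator_pixel_size(zoom, tile_size) * math.cos(math.radians(latitude_deg))


print(round(web_mercator_pixel_size(14), 3))                 # 9.555
print(round(web_mercator_pixel_size(13, tile_size=512), 3))  # 9.555
```

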
piperwolters commented 2 months ago

Ah, I see. That section of code is only meant to document how to project raw Sentinel-2 images to the Web-Mercator projection. After that, you need to format the imagery as described here: the imagery should be saved as PNGs of shape [num_s2_images * 32, 32, 3].
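
A minimal sketch of that layout, assuming you already have several co-registered RGB Sentinel-2 captures of the same area as (H, W, 3) arrays in Web-Mercator: the hypothetical helper below (not part of the repository) cuts each capture into 32x32 chunks and stacks the captures of each chunk one above another, giving one (num_s2_images * 32, 32, 3) array per chunk location. The file naming and folder layout the inference script expects are not covered here.

```python
import numpy as np


def stack_chunks(captures, chunk=32):
    """Cut co-registered captures into chunk x chunk tiles and stack them per tile.

    captures: list of (H, W, 3) arrays of the same area from different dates.
    Returns a dict mapping (row, col) tile indices to a
    (len(captures) * chunk, chunk, 3) array, captures stacked top to bottom.
    Partial tiles at the right/bottom edges are dropped.
    """
    stack = np.stack(captures, axis=0)  # (N, H, W, 3)
    n, h, w, c = stack.shape
    tiles = {}
    for row in range(h // chunk):
        for col in range(w // chunk):
            tile = stack[:, row * chunk:(row + 1) * chunk, col * chunk:(col + 1) * chunk, :]
            # (N, chunk, chunk, 3) -> (N * chunk, chunk, 3): capture 0 on top
            tiles[(row, col)] = tile.reshape(n * chunk, chunk, c)
    return tiles
```

Each resulting array can then be written out with skimage.io.imsave.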

I will make the documentation clearer.

lokeshsk commented 2 months ago

Hi, this is my code to format the image as you described:

  import glob
  import os

  import numpy as np
  import skimage.io

  # Directory containing one 32x32 RGB PNG per Sentinel-2 capture (replace with your folder path)
  image_dir = '/home/rd/Documents/projects/SATLAS/testing/'

  # Sorted so the order of captures in the stack is deterministic
  image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))

  images = []
  for image_path in image_paths:
      image = skimage.io.imread(image_path)
      image = image[:, :, :3]  # drop an alpha channel if present
      # The reshape below only works if every capture is exactly 32x32x3
      if image.shape != (32, 32, 3):
          raise ValueError(f"{image_path} has shape {image.shape}, expected (32, 32, 3)")
      images.append(image)

  # Stack all captures along a new axis: (num_images, 32, 32, 3)
  images_array = np.stack(images, axis=0)
  print(f"Combined images array shape: {images_array.shape}")

  # Place the captures one above another: (num_images * 32, 32, 3), saveable as a PNG
  reshape_arr = images_array.reshape(images_array.shape[0] * 32, 32, 3)
  skimage.io.imsave('testing.png', reshape_arr)

I have also tried with 2 images, and the stacked output generated by the code is as follows: testing
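
To sanity-check a stacked PNG like this one, the small hypothetical helper below (numpy only) splits a (num_s2_images * 32, 32, 3) array back into its individual 32x32 captures, which makes it easy to view or count them, e.g. unstack_chunks(skimage.io.imread('testing.png')).shape[0].

```python
import numpy as np


def unstack_chunks(stacked, chunk=32):
    """Split a (N * chunk, chunk, C) stacked array back into (N, chunk, chunk, C)."""
    height, width = stacked.shape[:2]
    if width != chunk or height % chunk != 0:
        raise ValueError(f"expected shape (N * {chunk}, {chunk}, C), got {stacked.shape}")
    return stacked.reshape(height // chunk, chunk, chunk, -1)
```
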

I also tested all combinations of N_LR, N_CH, and model weights in infer_example.yml: N_LR: 1 or 2; N_CH: 3 or 6; model weights: esrgan_1S2.pth or esrgan_2S2.pth.
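
The pairings above (1 image with 3 channels, 2 images with 6) suggest that each Sentinel-2 capture contributes its three RGB channels, so N_CH should be three times N_LR, and the input PNG should hold at least N_LR 32-pixel-tall captures. The hypothetical check below encodes those assumptions only; confirm them against the repository's configs before relying on it.

```python
def check_inference_inputs(n_lr, n_ch, stacked_shape, chunk=32, channels_per_image=3):
    """Return a list of problems with an N_LR / N_CH / input-PNG combination.

    Assumes (hypothetically) that each Sentinel-2 capture contributes
    channels_per_image RGB channels and that captures are stacked
    vertically in chunk-tall strips.
    """
    problems = []
    if n_ch != channels_per_image * n_lr:
        problems.append(f"N_CH={n_ch} but {channels_per_image} * N_LR = {channels_per_image * n_lr}")
    height, width = stacked_shape[:2]
    if width != chunk or height % chunk != 0:
        problems.append(f"input shape {stacked_shape} is not (N * {chunk}, {chunk}, 3)")
    elif height // chunk < n_lr:
        problems.append(f"input holds {height // chunk} capture(s) but N_LR={n_lr}")
    return problems


print(check_inference_inputs(2, 6, (64, 32, 3)))  # []
```
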

However, when I run inference, I still get: Running inference on 0 images

piperwolters commented 2 months ago

The reason the code prints "Running inference on 0 images" is the glob statement here: it means your directory structure does not match what this glob statement expects and/or your image filepaths do not end in the expected ".png".

You want to edit the code or edit your data directory structure so that pngs is a list of filepaths of your Sentinel-2 images. Does that make sense?
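
If your files live in a different layout, one option is to build that list yourself. The sketch below is an assumption about how you might do it, not the repository's code: it walks a hypothetical root directory recursively and matches the extension case-insensitively, since glob patterns are case-sensitive on Linux and would miss files saved as .PNG.

```python
import os


def find_pngs(root):
    """Recursively collect .png/.PNG file paths under root, in sorted order."""
    paths = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith(".png"):
                paths.append(os.path.join(dirpath, name))
    return sorted(paths)


# Hypothetical usage: replace the script's glob result with this list
# pngs = find_pngs("/path/to/your/sentinel2_pngs")
```
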

lokeshsk commented 2 months ago

Hi, yes, I figured it out; sorry, I forgot to update this thread. The code is working now. Thanks for your quick responses :). However, when I run inference on urban areas, the model does not produce satisfactory results. Here is an example (LR: lr, SR: sr).

Linear features, however, are reconstructed well. Here is an area of Lodz where the model worked great (LR: lr, SR: sr, GT from Google Maps high-res imagery: gmap_GT).

However, when I tried another area of Lodz, the model again hallucinated somewhat (LR: lr, SR: sr).

Could you shed some light on this? Also, what steps would be required to re-train the model on Sentinel-2 vs. SKYSAT pairs?

piperwolters commented 2 months ago

How many Sentinel-2 images did you use as input for these outputs? I have seen plenty of hallucinations in places that look very different from the USA and/or when I use only 1-4 images as input. There is certainly room for improvement in locations the model has never seen before.

To re-train the model on SKYSAT, you will want to:
1. structure a dataset in the same way as our S2NAIP;
2. write a dataset file to replace s2-naip_dataset.py that loads your intended inputs and outputs, ideally in the same format as S2NAIP;
3. write a config file with updated paths to your data and any other variables that need to be changed;
4. start training with python -m ssr.train -config your_config.

I found that starting from the pretrained satlas-super-res weights improved performance when transferring to other datasets, so it might be worth initializing your training from them.
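
As a rough illustration of the dataset step only, the sketch below pairs low-resolution Sentinel-2 inputs with high-resolution targets by matching relative file paths across two folders. The folder names, the one-file-per-tile naming, and matching on the relative path are all assumptions for illustration; the actual S2NAIP layout and s2-naip_dataset.py may differ, so check them before mirroring this.

```python
import os


def pair_tiles(lr_root, hr_root, ext=".png"):
    """Pair files that share the same path relative to lr_root and hr_root.

    Returns a sorted list of (lr_path, hr_path) tuples plus the sorted
    relative paths that exist on only one side, so gaps are easy to spot.
    """
    def index(root):
        found = {}
        for dirpath, _dirnames, filenames in os.walk(root):
            for name in filenames:
                if name.lower().endswith(ext):
                    full = os.path.join(dirpath, name)
                    found[os.path.relpath(full, root)] = full
        return found

    lr, hr = index(lr_root), index(hr_root)
    pairs = [(lr[key], hr[key]) for key in sorted(lr.keys() & hr.keys())]
    unmatched = sorted(lr.keys() ^ hr.keys())
    return pairs, unmatched
```

A dataset class would then load each pair (for example with skimage.io.imread) and convert it to whatever format your training config expects.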