allenai / satlas-super-resolution

Apache License 2.0

Super Resolution Issue: Visible Tile Boundaries #38

Closed MdRanaSarkar closed 3 months ago

MdRanaSarkar commented 4 months ago

I've been experiencing a consistent issue with super-resolution results: the boundaries between separate tiles are clearly visible. Despite trying various preprocessing steps and different model weights, the problem persists. Below are the details of my preprocessing steps and normalization methods.

For TCI Bands (Using 8-S2 Weights)

Normalization Steps 1:

  1. Take TCI bands.
  2. Divide TCI band by 10,000.
  3. Clip the data to [0, 1], then multiply by 255.
  4. Stack the data 8 times.
  5. Perform inference with the model.

Normalization Steps 2:

  1. Take TCI bands.
  2. Divide TCI band by 4,095 and clip to [0, 1].
  3. Multiply the data by 255 to convert to [0, 255].
  4. Stack the data 8 times.
  5. Perform inference with the model.
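Either recipe can be sketched in NumPy as follows; the function name, the `divisor` parameter, and the (bands, H, W) array layout are my assumptions, not part of the original report:

```python
import numpy as np

def preprocess_tci(tci, divisor=10_000.0):
    """Scale raw TCI bands, clip to [0, 1], rescale to [0, 255],
    and stack 8 copies along the band axis (hypothetical helper)."""
    x = np.clip(tci.astype(np.float32) / divisor, 0.0, 1.0)
    x = x * 255.0
    return np.concatenate([x] * 8, axis=0)  # (3, H, W) -> (24, H, W)

# Steps 1 use divisor=10_000; steps 2 use divisor=4_095.
batch = preprocess_tci(np.full((3, 32, 32), 5_000, dtype=np.uint16))
print(batch.shape)  # (24, 32, 32)
```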

For TCI + Non-TCI Bands (Using 8-S2 Weights)

Normalization Steps 1:

  1. Take TCI bands.
  2. Divide TCI band by 10,000 and clip to [0, 1].
  3. Divide non-TCI band by 8,160 and clip to [0, 1].
  4. Concatenate TCI and non-TCI band data.
  5. Multiply normalized data by 255.
  6. Stack the data 8 times.
  7. Perform inference with the 10m-S2-bands model.

Normalization Steps 2:

  1. Take TCI bands.
  2. Divide TCI band by 4,095 and clip to [0, 1].
  3. Divide non-TCI band by 8,160 and clip to [0, 1].
  4. Concatenate TCI and non-TCI band data.
  5. Multiply the data by 255 to convert to [0, 255].
  6. Stack the data 8 times.
  7. Perform inference with the model.
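The combined-bands recipe can be sketched similarly (again, the function name, array layout, and band counts are assumptions for illustration):

```python
import numpy as np

def preprocess_all_bands(tci, non_tci, tci_div=4_095.0, non_tci_div=8_160.0):
    """Normalize TCI and non-TCI bands separately, concatenate,
    rescale to [0, 255], and stack 8 times (hypothetical helper)."""
    t = np.clip(tci.astype(np.float32) / tci_div, 0.0, 1.0)
    n = np.clip(non_tci.astype(np.float32) / non_tci_div, 0.0, 1.0)
    x = np.concatenate([t, n], axis=0) * 255.0
    return np.concatenate([x] * 8, axis=0)

# e.g. 3 TCI bands plus 7 further bands per timestep (band count assumed):
out = preprocess_all_bands(np.zeros((3, 16, 16), dtype=np.uint16),
                           np.zeros((7, 16, 16), dtype=np.uint16))
print(out.shape)  # (80, 16, 16)
```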

Note: To subset & resample, I’ve used SNAP tools.

For resampling:

  • Upsampling method: nearest
  • Downsampling method: first
  • Flag downsampling method: first
  • Reference band: B02

Additional Attempt: I've also tried using min-max normalization but still faced the same issue.

Is there a recommended, ideal preprocessing pipeline? Can someone help me understand why the tile boundaries are so visible and how I can mitigate this problem? Any suggestions or insights would be greatly appreciated. Thank you!

[Screenshots: boundary_issue2, boundary_issue]

piperwolters commented 4 months ago

Hello,

This is also an apparent issue on the Satlas website; see the screenshot below.

[Screenshot: Satlas website, 2024-07-11]

We experimented with some post-processing to smooth out the boundaries of each tile, but found that this often over-smoothed the outputs and interfered with the semantic accuracy of the super-res results. You could try adding a loss during training that enforces some similarity across tile boundaries, but we did not get that to work well.

patriksabol commented 3 months ago

I managed to resolve this issue by running inference on the tile with overlap. For instance, with a patch size of 256, I used a stride of 128. I then convert the RGB output of each patch to LAB color space and compute a weighted average color.

To elaborate on the weighting process:

  1. Overlap Count: This tracks the number of times each pixel in the large image has been covered by a patch. As we stitch patches back together, each pixel's value is accumulated, and the overlap count is incremented.

  2. Distance Transform Weights: These weights are created based on the distance from the patch edges, using the Euclidean Distance Transform (EDT). The process is as follows:

    • We initialize a weight matrix with the center region set to 1 (covering the overlap area).
    • The scipy.ndimage.distance_transform_edt function is applied to this matrix to compute the distance of each pixel from the nearest zero (edge of the patch).
    • The resulting weights matrix gives higher weights to pixels near the center of the patch and lower weights near the edges. This ensures a smooth transition between overlapping patches.

The final color values are computed by accumulating the weighted LAB color values for each pixel and dividing by the total weight for that pixel. This technique blends the overlapping regions seamlessly and avoids visible seams between patches.

Here is a snippet of the relevant code for applying the weights and stitching the patches together:

import numpy as np
import scipy.ndimage
from skimage import color

def apply_weights(chunk_size, overlap):
    # Build a taper: set the central overlap x overlap block to 1, then take the
    # Euclidean distance transform so that weights fall from 1 at the patch
    # center to 0 at (and beyond) the edge of that block.
    weights = np.zeros((chunk_size, chunk_size), dtype=np.float32)
    center = (chunk_size - overlap) // 2
    weights[center:center + overlap, center:center + overlap] = 1
    weights = scipy.ndimage.distance_transform_edt(weights)
    weights /= weights.max()
    return weights

def stitch_in_memory(chunks, img_height, img_width, chunk_size=256, overlap=128):
    # Note: the taper is non-zero only in the central overlap x overlap block, so
    # full coverage requires overlap == chunk_size // 2 (i.e. stride == overlap,
    # e.g. patch 256 with stride 128); border pixels that fall outside every
    # tapered center should be handled by padding or cropping.
    stitched_image = np.zeros((img_height, img_width, 3), dtype=np.float32)
    overlap_count = np.zeros((img_height, img_width), dtype=np.float32)
    weights = apply_weights(chunk_size, overlap)

    grid_size_y = (img_height - overlap) // (chunk_size - overlap)
    grid_size_x = (img_width - overlap) // (chunk_size - overlap)

    idx = 0
    for i in range(grid_size_y + 1):
        for j in range(grid_size_x + 1):
            load = chunks[idx]
            idx += 1
            y_start = i * (chunk_size - overlap)
            x_start = j * (chunk_size - overlap)
            y_end = min(y_start + chunk_size, img_height)
            x_end = min(x_start + chunk_size, img_width)

            # Crop partial patches at the right/bottom image borders.
            load = load[:y_end - y_start, :x_end - x_start, :]

            # Blend in LAB space to avoid hue shifts when averaging.
            lab_load = color.rgb2lab(load)

            for c in range(3):
                stitched_image[y_start:y_end, x_start:x_end, c] += lab_load[..., c] * weights[:y_end - y_start, :x_end - x_start]
            overlap_count[y_start:y_end, x_start:x_end] += weights[:y_end - y_start, :x_end - x_start]

    # Avoid division by zero where no weighted patch contributed.
    overlap_count[overlap_count == 0] = 1

    # Weighted average: accumulated LAB values divided by total weight per pixel.
    averaged_lab = stitched_image / overlap_count[..., None]

    stitched_image_rgb = (color.lab2rgb(averaged_lab) * 255).astype(np.uint8)
    return stitched_image_rgb
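As a standalone sanity check of the EDT weighting (assuming patch size 256 with stride 128, i.e. overlap 128, as described above):

```python
import numpy as np
import scipy.ndimage

chunk_size, overlap = 256, 128
weights = np.zeros((chunk_size, chunk_size), dtype=np.float32)
center = (chunk_size - overlap) // 2
weights[center:center + overlap, center:center + overlap] = 1
weights = scipy.ndimage.distance_transform_edt(weights)
weights /= weights.max()

print(weights[0, 0])      # 0.0: no weight at the patch border
print(weights[128, 128])  # 1.0: full weight at the patch center
```

Pixels outside the central overlap x overlap block receive zero weight, which is why the stride must equal the overlap for the tapered centers to tile the image without gaps.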

By following this method, the final super-resolved image is free of noticeable seams, providing a smooth and visually consistent output.

[Screenshots from 2024-07-31]

Note: I trained the model to super-resolve SPOT satellite imagery to aerial resolution.

tfriedel commented 6 days ago

Thanks @patriksabol. I implemented your method and it worked well. After that I noticed I can also just process a larger mosaic of S2 tiles directly; there is no need to process an image in chunks of 32x32 when there is still plenty of GPU memory available.
