Orion-AI-Lab / KuroSiwo

Code and data for Kuro Siwo flood mapping dataset
MIT License

Added initial hubconf.py #4

Open Multihuntr opened 9 months ago

Multihuntr commented 9 months ago

To make the models more accessible, I have added a hubconf.py. Its entry points instantiate the models from the model definitions in the current code base.
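For readers unfamiliar with torch.hub: it looks for a hubconf.py at the repository root, checks that the packages named in its `dependencies` list are importable, and then calls the requested function as an entry point, forwarding any keyword arguments such as `pretrained=True`. The sketch below shows that general shape only. It is not the file from this PR: `_TinyStandIn` is a hypothetical placeholder for the real SNUNet_ECAM definition, and `_SNUNET_URL` stands in for a release URL that does not exist yet.

```python
# hubconf.py: hypothetical sketch, not the file in this PR.
import torch
import torch.nn as nn

# torch.hub verifies these packages are importable before calling an entry point.
dependencies = ["torch", "numpy"]

# Hypothetical: would become the URL of a release asset holding the state_dict.
_SNUNET_URL = None


class _TinyStandIn(nn.Module):
    """Hypothetical placeholder for the real SNUNet_ECAM model."""

    def __init__(self, in_ch: int = 2, n_images: int = 2, out_ch: int = 3):
        super().__init__()
        # 1x1 convolution over the channel-wise concatenation of all images.
        self.head = nn.Conv2d(in_ch * n_images, out_ch, kernel_size=1)

    def forward(self, images: list) -> torch.Tensor:
        return self.head(torch.cat(images, dim=1))


def snunet(pretrained: bool = False, **kwargs) -> nn.Module:
    """Entry point called by torch.hub.load(repo, "snunet", pretrained=...)."""
    model = _TinyStandIn(**kwargs)
    if pretrained:
        if _SNUNET_URL is None:
            raise RuntimeError("pretrained weights have not been released yet")
        state = torch.hub.load_state_dict_from_url(_SNUNET_URL, map_location="cpu")
        model.load_state_dict(state)
    # Switch dropout/batch-norm layers to evaluation mode; eval() returns the model.
    return model.eval()
```

Keeping model construction and weight loading inside the entry point means callers never need to import anything from the repository themselves.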

With this change, I can load a pretrained SNUNet or Flood_ViT model from my fork, from anywhere:

```python
import numpy as np
import torch

# SNUNet: a list of two S1 images, each (batch, 2, H, W), plus the DEM as a NumPy array.
snunet = torch.hub.load("Multihuntr/KuroSiwo", "snunet", pretrained=True)
inps = [torch.randn(8, 2, 224, 224) for _ in range(2)]
dem = np.random.randn(8, 224, 224)
out = snunet(inps, dem=dem)

# Flood_ViT: a list of three S1 images and no DEM.
flood_vit = torch.hub.load("Multihuntr/KuroSiwo", "vit_decoder", pretrained=True)
inps = [torch.randn(8, 2, 224, 224) for _ in range(3)]
out = flood_vit(inps)
```

Before accepting this pull request, you will need to take a few steps (listed at the bottom) if you want it to work with torch.hub.load("Orion-AI-Lab/KuroSiwo", "snunet").


  1. To keep dependencies minimal, I have moved the FineTunerSegmentation model definition from models/model_utilities.py into its own file. Otherwise, hubconf.py would have to import model_utilities, which requires several other libraries.
  2. I have kept the model structure of the original weights provided (as far as I can tell) and added a wrapper around it. The wrapper normalises the input values and standardises the interface: the inputs are plain S1 (Sentinel-1) images as torch.Tensor objects, and the DEM is an np.ndarray.
  3. The wrapper takes a list of images. Rather than having it decide whether to concatenate the inputs, or to split up images that were already concatenated, I decided that you should always provide a list of images. You do have to be careful to give the correct number of images for each model (two for SNUNet and three for Flood_ViT in the examples above).
  4. I saved only the state_dict (with torch.save) for maximum portability.
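A wrapper along the lines of points 2 and 3 might look like the minimal sketch below. The class name `HubWrapper`, the scalar `mean`/`std` placeholders, and the choice to append the DEM as one extra input channel are all assumptions for illustration; the PR does not say how each model consumes the DEM internally or which normalisation statistics it uses.

```python
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn


class HubWrapper(nn.Module):
    """Hypothetical wrapper that normalises inputs and fixes the call signature."""

    def __init__(self, model: nn.Module, n_images: int,
                 mean: float = 0.0, std: float = 1.0):
        super().__init__()
        self.model = model
        self.n_images = n_images  # e.g. 2 for SNUNet, 3 for Flood_ViT
        self.mean = mean  # placeholder statistics, not KuroSiwo's values
        self.std = std

    def forward(self, images: List[torch.Tensor],
                dem: Optional[np.ndarray] = None) -> torch.Tensor:
        # Always expect a list, and check its length for this model.
        if len(images) != self.n_images:
            raise ValueError(f"expected {self.n_images} images, got {len(images)}")
        # Normalise each S1 image of shape (batch, channels, H, W).
        x = [(img - self.mean) / self.std for img in images]
        if dem is not None:
            # Assumption: a DEM of shape (batch, H, W) becomes one extra channel.
            d = torch.as_tensor(dem, dtype=x[0].dtype, device=x[0].device)
            x.append(d.unsqueeze(1))
        return self.model(torch.cat(x, dim=1))
```

Checking the list length up front turns a wrong image count into a clear error instead of a shape mismatch deep inside the model.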


I am not certain about a few aspects. Can you check:

  1. Did I correctly instantiate SNUNet_ECAM?
  2. (EDIT: This point resolved)
  3. (EDIT: This point should be addressed elsewhere)

Actions needed before accepting

To make this work on the main branch, you will need to:

  1. Export only the state_dict for each model.
  2. Upload those state_dicts to a release (as recommended by torch.hub).

Then we can update the URL to point to your original repository, so everything stays contained there.
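Exporting and reloading a state_dict, as in the two steps above, could be done roughly as follows. `export_state_dict` and `load_into` are hypothetical helper names, and the round trip uses an in-memory buffer rather than a real file. Once the file is attached to a release, hubconf.py could fetch it with `torch.hub.load_state_dict_from_url(url, map_location="cpu")` in place of `torch.load`.

```python
import io
from typing import BinaryIO, Union

import torch
import torch.nn as nn


def export_state_dict(model: nn.Module, dest: Union[str, BinaryIO]) -> None:
    """Save only parameters and buffers, not the pickled module class."""
    torch.save(model.state_dict(), dest)


def load_into(model: nn.Module, src: Union[str, BinaryIO]) -> nn.Module:
    """Load a saved state_dict into a freshly constructed model."""
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(src, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    return model


# Round trip through an in-memory buffer instead of a file on disk.
src = nn.Linear(3, 2)
buf = io.BytesIO()
export_state_dict(src, buf)
buf.seek(0)
restored = load_into(nn.Linear(3, 2), buf)
```

Because only tensors are stored, the weights load into any class with matching parameter names, which is what makes the moved FineTunerSegmentation definition safe to use.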

Multihuntr commented 9 months ago

Based on the newly pushed "dem fix" code, I see that the third channel for SNUNet is likely the slope, not vh/vv. I have updated my code accordingly.
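The "dem fix" code is not reproduced in this thread, but a common way to derive slope from a DEM is with finite differences: slope = arctan(sqrt((dz/dx)^2 + (dz/dy)^2)). The sketch below assumes that convention (degrees, central differences via `np.gradient`) and a hypothetical `pixel_size` parameter; the repository's preprocessing may use different units or a different kernel.

```python
import numpy as np


def slope_degrees(dem: np.ndarray, pixel_size: float = 1.0) -> np.ndarray:
    """Slope in degrees over the last two axes of a DEM, shape (H, W) or (B, H, W).

    pixel_size is the ground spacing per pixel in the same unit as the
    elevations (hypothetical default; set it to the real grid spacing).
    """
    # Central differences in the interior, one-sided differences at the borders.
    dz_dy, dz_dx = np.gradient(dem.astype(np.float64), pixel_size, axis=(-2, -1))
    return np.degrees(np.arctan(np.hypot(dz_dx, dz_dy)))
```

Operating on the last two axes lets the same function handle a single tile or a batch such as the (8, 224, 224) DEM in the example above.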

Multihuntr commented 8 months ago

I have now run these models on a subset of the KuroSiwo dataset and got reasonable IoU numbers. I'm still not 100% certain, but I'm about 95% sure I've got it right now.
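For reference, per-class IoU is the size of the intersection of predicted and true pixels for a class divided by the size of their union. The sketch below assumes integer label maps and a hypothetical `num_classes` argument; which class indices KuroSiwo uses and how its evaluation aggregates IoU across classes or events are not specified in this thread.

```python
import numpy as np


def per_class_iou(pred: np.ndarray, target: np.ndarray, num_classes: int) -> list:
    """IoU = |pred AND target| / |pred OR target| for each class index.

    Returns NaN for a class absent from both maps, where IoU is undefined.
    """
    ious = []
    for c in range(num_classes):
        p = pred == c
        t = target == c
        union = np.logical_or(p, t).sum()
        inter = np.logical_and(p, t).sum()
        ious.append(float(inter / union) if union else float("nan"))
    return ious
```

Returning NaN rather than 0 or 1 for absent classes lets a caller skip them with `np.nanmean` when averaging over tiles.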