microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Adding custom dataset for multiclassification #2148

Open Sandipriz opened 6 days ago

Sandipriz commented 6 days ago

Issue

I am trying to use a Maxar dataset with a mask in TorchGeo.

Image: [image]

Mask: [image]

The error I get: [image]

I am new to TorchGeo and am not able to feed the data into models that are not preloaded in TorchGeo.

Fix

No response

isaaccorley commented 6 days ago

Hi @Sandipriz, it seems that filepath and filepath1 are not RasterDatasets but just strings. Could you post some additional code?

Sandipriz commented 3 days ago

Now I see the problem. Here are my file paths:

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'

The resolution of the image is 0.3 and the CRS is EPSG:32617. I am not able to define the image and the mask as raster datasets.

Should it be something like raster = function(filepath, crs=naip.crs, res=0.3)? In the custom raster dataset tutorial I can see Sentinel2/ChesapeakeDE in place of function, but I am not sure what to use for WorldView II.

adamjstewart commented 2 days ago

@Sandipriz if I understand correctly, your current code looks like:

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'
train_dataset = filepath & filepath1

The correct code would be:

from torchgeo.datasets import RasterDataset

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'
image_dataset = RasterDataset(filepath)
mask_dataset = RasterDataset(filepath1)
train_dataset = image_dataset & mask_dataset

There are other steps that may be required, but this should at least get you farther.

Sandipriz commented 1 day ago

I was able to move ahead. I intend to patchify the image and build training and validation datasets for the model.

First I tried this with TorchGeo: [image]

Then I tried to patchify the image manually into 256x256 patches and feed them into the model. For this, I moved the images and masks into train and validation subfolders.

Number of images in validation images directory: 539
Number of masks in validation mask directory: 539
Number of images in training images directory: 1255
Number of masks in training mask directory: 1255

I got stuck again.

adamjstewart commented 9 hours ago

First, to clarify some confusion, len(dataset) is not the number of possible samples, it's the number of overlapping images. len(sampler) or len(dataloader) is the number of possible samples.

If you tell me where you got stuck or share the code (not a picture of the code) I can try to help.

Sandipriz commented 7 hours ago

Thank you very much for your support.

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'

image = RasterDataset(filepath)
mask = RasterDataset(filepath1)
dataset = image & mask

# Check the number of possible samples
num_samples = len(dataset)
print(f"Number of possible samples: {num_samples}")

# Define the size of the patches
patch_size = 256  # Example patch size, you can adjust this

# Create a random sampler to generate patches
sampler = RandomGeoSampler(dataset, size=patch_size, length=20)
print(sampler)

train_loader = DataLoader(dataset, sampler=sampler, batch_size=16, collate_fn=stack_samples)

# let's print a random element of the training dataset
random_element = np.random.randint(0, len(train_loader))
for idx, sample in enumerate(train_loader):
    if idx != random_element:
        continue

    # let's select the first sample from the batch
    image = sample["image"][0]
    target = sample["mask"][0]

    # Create a figure and a 1x2 grid of axes
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    # Plot the first image in the left axis
    # select only first 3 bands and cast to uint8
    rgb_image = np.transpose(image.numpy().squeeze()[0:3], (1, 2, 0)).astype('uint8')
    axes[0].imshow(rgb_image)
    axes[0].set_title('Rgb image')

    # Plot the labels image in the right axis
    target_image = target.numpy().squeeze()
    axes[1].imshow(target_image)
    axes[1].set_title('Mask')

    # Adjust layout to prevent clipping of titles
    plt.tight_layout()

    # Show the plots
    plt.show()

I wanted to check to see if the image and mask are properly arranged in the dataloader.

My image size is 11719x9922, which makes 37 non-overlapping patches of size 256x256. I want to use 20 of those samples for training and the rest for validation.
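A quick arithmetic check of the patch count may help here. The sketch below assumes the 11719x9922 size is in pixels and that patches do not overlap. It uses only the standard library, and the seeded shuffle is a hypothetical way to pick the 20 training patches, not something TorchGeo does for you.

```python
import math
import random

WIDTH, HEIGHT = 11719, 9922  # image size in pixels, taken from the post above
PATCH = 256                  # patch edge length in pixels

# Patches that fit entirely inside the image (no padding at the edges).
full_cols, full_rows = WIDTH // PATCH, HEIGHT // PATCH
# Patches needed to cover the whole image when edge patches are padded.
padded_cols, padded_rows = math.ceil(WIDTH / PATCH), math.ceil(HEIGHT / PATCH)

print(full_cols * full_rows)      # 45 * 38 = 1710
print(padded_cols * padded_rows)  # 46 * 39 = 1794

# Upper-left pixel offset (x, y) of every full patch.
origins = [
    (col * PATCH, row * PATCH)
    for row in range(full_rows)
    for col in range(full_cols)
]

# Hypothetical split: shuffle with a fixed seed, keep 20 patches for training
# and the rest for validation, so the two sets never share a patch.
rng = random.Random(0)
rng.shuffle(origins)
train_origins, val_origins = origins[:20], origins[20:]
```

The padded count of 1794 equals the 1255 + 539 patches reported earlier for the manual patchify, which suggests the figure of 37 is off by a wide margin. Splitting on patch origins rather than on randomly drawn samples keeps training and validation areas disjoint.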