microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Adding custom dataset for multiclassification #2148

Open Sandipriz opened 4 months ago

Sandipriz commented 4 months ago

Issue

I am trying to use a Maxar dataset with a mask in TorchGeo.

Image: (screenshot)

Mask: (screenshot)

The error I get: (screenshot)

I am new to TorchGeo and am not yet able to feed my own data into models that are not preloaded in TorchGeo.

Fix

No response

isaaccorley commented 4 months ago

Hi @Sandipriz, it seems that filepath and filepath1 are not RasterDatasets but just strings. Could you post some additional code?

Sandipriz commented 4 months ago

Now I see the problem. Here are my file paths:

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'

The resolution of the image is 0.3 and the CRS is EPSG:32617. I am not able to define it as a raster dataset and mask.

Should it be something like raster = function(filepath, crs=naip.crs, res=0.3)? I can see Sentinel2/ChesapeakeDE in place of function in the custom raster dataset tutorial, but I am not sure what to use for WorldView-2.

adamjstewart commented 4 months ago

@Sandipriz if I understand correctly, your current code looks like:

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'
train_dataset = filepath & filepath1

The correct code would be:

from torchgeo.datasets import RasterDataset

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'
image_dataset = RasterDataset(filepath)
mask_dataset = RasterDataset(filepath1)
train_dataset = image_dataset & mask_dataset

There are other steps that may be required, but this should at least get you farther.
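If it helps, the usual next steps are to draw patches with a sampler and batch them with TorchGeo's stack_samples collate function; a minimal sketch (the patch size and length are arbitrary examples):

from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler

# Draw 100 random 256x256 patches from the intersection dataset
sampler = RandomGeoSampler(train_dataset, size=256, length=100)
dataloader = DataLoader(train_dataset, sampler=sampler, batch_size=16, collate_fn=stack_samples)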

Sandipriz commented 4 months ago

I was able to move ahead. My intention is to patchify the image and make train and validation datasets for the model.

First I tried this with TorchGeo: (screenshot)

I also tried patchifying the image manually into 256x256 patches and feeding those to the model. For this, I moved the images and masks into train and validation subfolders.

Number of images in validation images directory: 539
Number of masks in validation mask directory: 539
Number of images in training images directory: 1255
Number of masks in training mask directory: 1255

I got stuck again.

adamjstewart commented 4 months ago

First, to clarify some confusion, len(dataset) is not the number of possible samples, it's the number of overlapping images. len(sampler) or len(dataloader) is the number of possible samples.
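A quick illustration of the difference (the length of 20 is an arbitrary example):

from torchgeo.samplers import RandomGeoSampler

print(len(dataset))  # number of overlapping images in the spatial index

sampler = RandomGeoSampler(dataset, size=256, length=20)
print(len(sampler))  # 20: the number of patches the sampler will yield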

If you tell me where you got stuck or share the code (not a picture of the code) I can try to help.

Sandipriz commented 4 months ago

Thank you very much for your support.

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

filepath = '/content/drive/My Drive/BCMCA/bcmca_3b.tif'
filepath1 = '/content/drive/My Drive/BCMCA/bcmca_mask.tif'

image = RasterDataset(filepath)
mask = RasterDataset(filepath1)
dataset = image & mask

# Note: len(dataset) counts overlapping images, not possible samples
num_samples = len(dataset)
print(f"Number of possible samples: {num_samples}")

# Define the size of the patches
patch_size = 256  # Example patch size, you can adjust this

# Create a random sampler to generate patches
sampler = RandomGeoSampler(dataset, size=patch_size, length=20)
print(sampler)

train_loader = DataLoader(dataset, sampler=sampler, batch_size=16, collate_fn=stack_samples)

#let's print a random element of the training dataset
random_element = np.random.randint(0, len(train_loader))
for idx, sample in enumerate(train_loader):
    if idx != random_element:
        continue

    # let's select the first sample from the batch
    image = sample["image"][0]
    target = sample["mask"][0]

    # Create a figure and a 1x2 grid of axes
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    # Plot the first image in the left axis
    # select only first 3 bands and cast to uint8
    rgb_image = np.transpose(image.numpy().squeeze()[0:3], (1, 2, 0)).astype('uint8')
    axes[0].imshow(rgb_image)
    axes[0].set_title('RGB image')

    # Plot the labels image in the right axis
    target_image = target.numpy().squeeze()
    axes[1].imshow(target_image)
    axes[1].set_title('Mask')

    # Adjust layout to prevent clipping of titles
    plt.tight_layout()

    # Show the plots
    plt.show()

I wanted to check to see if the image and mask are properly arranged in the dataloader.

My image size is 11719x9922, which makes 37 non-overlapping patches of size 256x256. I want to make 20 of those samples the training dataset and the rest the validation dataset.

adamjstewart commented 4 months ago

I want to make 20 of those samples the training dataset and the rest the validation dataset.

My suggestion would be to use one of TorchGeo's splitting functions to split the dataset into non-overlapping train and validation datasets. Otherwise, there's no guarantee that the 20 random tiles you sample during training and the 17 random tiles you sample during validation have no overlap.

Also note that RandomGeoSampler makes no guarantees of non-overlapping samples. For that, you probably want GridGeoSampler.
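A minimal sketch of that combination, assuming the dataset = image & mask object from your snippet and TorchGeo's random_bbox_assignment splitting function (the fractions are illustrative):

from torch.utils.data import DataLoader
from torchgeo.datasets import random_bbox_assignment, stack_samples
from torchgeo.samplers import GridGeoSampler

# Split into non-overlapping train/val regions first, then sample patches
train_ds, val_ds = random_bbox_assignment(dataset, [0.8, 0.2])

train_sampler = GridGeoSampler(train_ds, size=256, stride=256)
val_sampler = GridGeoSampler(val_ds, size=256, stride=256)

train_loader = DataLoader(train_ds, sampler=train_sampler, batch_size=16, collate_fn=stack_samples)
val_loader = DataLoader(val_ds, sampler=val_sampler, batch_size=16, collate_fn=stack_samples)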

Sandipriz commented 4 months ago

I patchified the images manually.

Number of images - /content/patches256/train/images/: 1254
Number of masks - /content/patches256/train/mask/: 1254

The images have the same names as the masks, are 256x256 in size, and are non-overlapping.

The rest I kept aside for validation:

Number of images - /content/patches256/val/images/: 539
Number of masks - /content/patches256/val/mask/: 539

It seems like I can now skip the sampling part and go directly to the dataloader part. I can't find anything in the documentation about how to load the patchified data from the directory directly as training data.

adamjstewart commented 4 months ago

You can either write a NonGeoDataset for loading patchified data, or use a GeoDataset with a sampler. For writing a NonGeoDataset, see a tutorial like: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html.
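A minimal sketch of the NonGeoDataset route, assuming your patch images and masks share filenames across the images/ and mask/ directories (the class itself is illustrative, not a TorchGeo built-in):

import os

import rasterio
import torch
from torchgeo.datasets import NonGeoDataset

class PatchSegmentationDataset(NonGeoDataset):
    """Load pre-patchified image/mask pairs that share filenames."""

    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.filenames = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        name = self.filenames[index]
        with rasterio.open(os.path.join(self.image_dir, name)) as src:
            image = torch.from_numpy(src.read()).float()
        with rasterio.open(os.path.join(self.mask_dir, name)) as src:
            mask = torch.from_numpy(src.read(1)).long()
        return {'image': image, 'mask': mask}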

adamjstewart commented 3 months ago

@Sandipriz do you have any other questions, or can we close this issue?

Sandipriz commented 2 months ago

Three issues:

  1. I am able to feed my 3-band image but not the 8-band one.
  2. I am lost when I try to augment the data for training.
  3. This is basically part of augmentation: I want to use a vegetation index like NDVI, but even though I am going through the documentation, I am unable to adapt it.
    
import rasterio
import rasterio.windows
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

# Define paths to local .tif files
image_path = "C:/Users/Chris/Documents/BCMCA/bcmca_small_true.tif"
mask_path = "C:/Users/Chris/Documents/BCMCA/bcmca_mask.tif"

class CustomGeoDataset(Dataset):
    def __init__(self, image_path, mask_path, patch_size=256):
        self.image_path = image_path
        self.mask_path = mask_path
        self.patch_size = patch_size

        # Open the images
        self.image_src = rasterio.open(self.image_path)
        self.mask_src = rasterio.open(self.mask_path)
        self.height, self.width = self.image_src.height, self.image_src.width

    def __len__(self):
        # We will define an arbitrary number of samples
        return 1000  # Modify as per your requirement

    def __getitem__(self, idx):
        # Sample a random patch location
        x = np.random.randint(0, self.width - self.patch_size)
        y = np.random.randint(0, self.height - self.patch_size)

        image = self.image_src.read(window=rasterio.windows.Window(x, y, self.patch_size, self.patch_size))
        mask = self.mask_src.read(1, window=rasterio.windows.Window(x, y, self.patch_size, self.patch_size))

        # Convert to torch tensors
        image = torch.from_numpy(image).float()
        mask = torch.from_numpy(mask).long()

        # Normalize the image to range [0, 1] if necessary
        image = image / 255.0

        sample = {"image": image, "mask": mask}
        return sample

# Create dataset instances
train_dataset = CustomGeoDataset(image_path, mask_path)
val_dataset = CustomGeoDataset(image_path, mask_path)  # Use the same for validation for now

# Set the ratio for splitting into training and validation sets
train_ratio = 0.8  # 80% for training, 20% for validation
total_samples = 100  # Number of samples after augmentation
train_size = int(train_ratio * total_samples)
val_size = total_samples - train_size

# Use RandomSampler to sample regions from the dataset
train_sampler = RandomSampler(train_dataset, num_samples=train_size, replacement=True)
val_sampler = RandomSampler(val_dataset, num_samples=val_size, replacement=True)

train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=16)
val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=16)

# Let's print a random element of the training dataset
random_element = np.random.randint(0, len(train_loader))
for idx, sample in enumerate(train_loader):
    if idx != random_element:
        continue

    # Let's select the first sample from the batch
    image = sample["image"][0]
    target = sample["mask"][0]

    # Ensure the image has 3 bands (RGB)
    if image.shape[0] == 3:
        # Create a figure and a 1x2 grid of axes
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))

        # Plot the first image in the left axis
        # Select only first 3 bands and cast to uint8
        rgb_image = np.transpose(image.numpy().squeeze(), (1, 2, 0))
        axes[0].imshow((rgb_image * 255).astype('uint8'))
        axes[0].set_title('RGB image')

        # Plot the labels image in the right axis
        target_image = target.numpy().squeeze()
        axes[1].imshow(target_image, cmap='gray')
        axes[1].set_title('Mask')

        # Adjust layout to prevent clipping of titles
        plt.tight_layout()

        # Show the plots
        plt.show()
    else:
        print("Image does not have 3 bands.")

import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its spatial size matches x2 before concatenation
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

from tqdm import tqdm

# Define necessary variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nb_channels = 3  # Number of input channels (RGB)
nb_classes = 1   # Number of output classes (binary mask)

# Initialize model, loss function, and optimizer
model = UNet(n_channels=nb_channels, n_classes=nb_classes).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
log_dict = {'loss_per_batch': [], 'loss_per_epoch': []}
best_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_loader):
        # NOTE: CustomGeoDataset.__getitem__ already divides by 255, so this second division double-normalizes
        inputs = batch["image"][:, 0:3, :, :] / 255.0  # Normalize inputs to [0, 1]
        targets = batch["mask"].unsqueeze(1)  # Add a channel dimension

        # Move data to device
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets.float())  # BCEWithLogitsLoss does not require sigmoid

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        log_dict['loss_per_batch'].append(loss.item())

    # Evaluation phase
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for val_batch in val_loader:
            val_inputs = val_batch["image"][:, 0:3, :, :] / 255.0  # Normalize inputs to [0, 1]
            val_targets = val_batch["mask"].unsqueeze(1)  # Add a channel dimension

            # Move data to device
            val_inputs = val_inputs.to(device)
            val_targets = val_targets.to(device)

            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs, val_targets.float())
            total_val_loss += val_loss.item()

    # Calculate the average loss
    average_val_loss = total_val_loss / len(val_loader)
    log_dict['loss_per_epoch'].append(average_val_loss)

    # Check if current performance is better than the best so far
    if average_val_loss < best_loss:
        best_loss = average_val_loss
        # Save the model checkpoint
        torch.save(model.state_dict(), 'best_model.pt')

    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {running_loss / len(train_loader.dataset)}, Validation Loss: {average_val_loss}")


adamjstewart commented 2 months ago

@Sandipriz you are no longer using TorchGeo in your latest code snippet. If you have any questions about TorchGeo, let me know.
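Regarding NDVI (your third point), TorchGeo's transforms include AppendNDVI, which appends the computed index as an extra channel; a minimal sketch (the NIR/red band positions are assumptions, check your sensor's band ordering):

import torch
from torchgeo.transforms import AppendNDVI

# Hypothetical 8-band layout with red at channel 4 and NIR at channel 6 (0-based)
ndvi = AppendNDVI(index_nir=6, index_red=4)

image = torch.rand(1, 8, 256, 256)  # dummy batch of one 8-band patch
image = ndvi(image)                 # NDVI is appended as a 9th channel
print(image.shape)                  # torch.Size([1, 9, 256, 256])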

Sandipriz commented 2 months ago

I was unable to feed my custom data into TorchGeo. A suggestion on that would be helpful, because the documentation didn't help me ingest the 8-band image with its mask into the model.

adamjstewart commented 2 months ago

If you upload the failing TorchGeo code and data needed to reproduce the issue, I can take a look.

https://torchgeo.readthedocs.io/en/stable/tutorials/custom_raster_dataset.html gives an example with 4-band Sentinel-2 imagery, it shouldn't be too hard to modify that example to support 8-band imagery.
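Following that tutorial, a rough 8-band adaptation could look like this (the class name, filename glob, and band names are assumptions for your data):

from torchgeo.datasets import RasterDataset

class WorldView2(RasterDataset):
    filename_glob = 'bcmca_img.tif'  # match your file naming
    is_image = True
    separate_files = False
    all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8']
    rgb_bands = ['B5', 'B3', 'B2']   # assumption: red/green/blue band positions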

Sandipriz commented 1 month ago

I don't know where to upload the data for your access.

The image is a single large raster with 8 bands, and the mask has 8 classes.

I am able to use a customized RasterDataset and combine the imagery and its masks. However, I am not sure how to split the data into train, val, and test. If we can, should that be done after using GridGeoSampler or before taking those samples?

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchgeo.datasets import RasterDataset, stack_samples

image_path = '/content/drive/My Drive/BCMCA_file/bcmca_img.tif'  # 8-band image
mask_path = '/content/drive/My Drive/BCMCA_file/bcmca_mask.tif'

## Customized raster dataset
class CustomRasterDataset(RasterDataset):
    def __init__(self, paths, crs, res, bands, transforms=None, cache=False):
        # Define the all_bands attribute which lists all 8 bands in the dataset
        self.all_bands = [f'band{i+1}' for i in range(8)]  # 8 bands total in the file

        # Ensure that the provided band indices are valid (0-based indexing)
        if not all(0 <= band < len(self.all_bands) for band in bands):
            raise ValueError("Specified bands must be within available bands")

        # Convert band indices to band names
        self.bands = [self.all_bands[band] for band in bands]

        # Call the parent class's constructor with the specified bands
        super().__init__(paths=paths, crs=crs, res=res, bands=self.bands, transforms=transforms, cache=cache)

# Select band indices (0-based indexing) you want to use
selected_bands = [0, 1, 2, 3, 4, 5, 6, 7]  # For the first 8 bands

# Instantiate the dataset with the selected bands
image = CustomRasterDataset(paths=image_path, crs="EPSG:32617", res=0.3, bands=selected_bands)

# For mask data
selected_bands = [1]
mask = CustomRasterDataset(paths=mask_path, crs="EPSG:32617", res=0.3, bands=selected_bands)

# Combine overlapping datasets
dataset = image & mask

## Generate samplers from the data

from torchgeo.samplers import GridGeoSampler, Units

# Define the size and stride of the patches
patch_size = (512, 512)    # Example size of the patches
patch_stride = (128, 128)  # Example stride between patches

# Create a GridGeoSampler instance for the combined dataset
sampler = GridGeoSampler(
    dataset=dataset,
    size=patch_size,
    stride=patch_stride,
    units=Units.PIXELS  # Use PIXELS if size and stride are in pixel units
)

# Generate patches
patches = [patch for patch in sampler]
print(patches)
print(len(patches))

dataloader = DataLoader(dataset, batch_size=16, sampler=sampler, collate_fn=stack_samples)
print(dataloader)

# let's print a random element of the training dataset
random_element = np.random.randint(0, len(dataloader))
for idx, sample in enumerate(dataloader):
    if idx != random_element:
        continue

    # let's select the first sample from the batch
    image = sample["image"][0]
    target = sample["mask"][0]

    # Create a figure and a 1x2 grid of axes
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    # Plot the first image in the left axis
    # select only first 3 bands and cast to uint8
    rgb_image = np.transpose(image.numpy().squeeze()[0:3], (1, 2, 0)).astype('uint8')
    axes[0].imshow(rgb_image)
    axes[0].set_title('RGB image')

    # Plot the labels image in the right axis
    target_image = target.numpy().squeeze()
    axes[1].imshow(target_image)
    axes[1].set_title('Mask')

    # Adjust layout to prevent clipping of titles
    plt.tight_layout()

    # Show the plots
    plt.show()

After combining the image and mask datasets and creating the dataloader, I wanted to display a random image and its mask to make sure everything was correct, but it gave me the following error.

IndexError: band index 2 out of range (not in (1,)).

However, all 8 bands are properly loaded.

adamjstewart commented 1 month ago

I don't know where to upload the data for your access.

Google Drive, Dropbox, OneDrive, etc.

However, I am not sure how to split the data into train, val, and test. If we can, should that be done after using GridGeoSampler or before taking those samples?

See https://torchgeo.readthedocs.io/en/stable/api/datasets.html#splitting-functions. This should be done before GridGeoSampler.

it gave me the following error.

Can you give me the full stacktrace so I can see where the error occurs?

adamjstewart commented 1 month ago

Thanks for sharing the code and data with me @Sandipriz. The issue is that you are using CustomRasterDataset for both your images and masks. However, only the images should have bands, and the masks should not have is_image = True. You can fix most of your bugs using:

class ImageDataset(RasterDataset):
    is_image = True

class MaskDataset(RasterDataset):
    is_image = False

image = ImageDataset(image_path)
mask = MaskDataset(mask_path)

I would suggest rereading https://torchgeo.readthedocs.io/en/stable/tutorials/custom_raster_dataset.html more carefully and setting all of the class attributes you need for both the image and mask portions of the dataset.

adamjstewart commented 1 month ago

@Sandipriz are there any other issues you're encountering? Trying to close out bug reports.