IMSY-DKFZ / htc

Semantic organ segmentation for hyperspectral images.

[Question] Loading and using a 16-mixed precision pretrained model #21

Closed. alfieroddan closed this issue 8 months ago.

alfieroddan commented 8 months ago

Question

Just a quick question: I am trying to use a pretrained model to run over a validation set with a custom script (using htc).

I get the following error:

Traceback (most recent call last):
  File "/home/tay/Code/ms-seg/iterate_htc_loader.py", line 150, in <module>
    run_iteration(model, loader, device, config)
  File "/home/tay/Code/ms-seg/iterate_htc_loader.py", line 43, in run_iteration
    torch.isclose(x.abs().sum(dim=1), torch.tensor(1.0, device=x.device), atol=0.1)
RuntimeError: Half did not match Float

The batch["features"] type is float-16 and the torch.tensor is float, so obviously torch.isclose raises an error.

Am I creating the data loaders wrong?

Any help would be much appreciated. I'm starting to get familiar with htc; thanks for the hard work creating this package :).

Description

htc_data.py

# htc
from htc import (Config, DataSpecification, StreamDataLoader,
                 HierarchicalSampler, DatasetImageBatch)

# external
from torch.utils.data.sampler import RandomSampler
from pathlib import Path

# https://github.com/IMSY-DKFZ/htc/blob/main/htc/models/image/LightningImage.py
# line 76
def dataset_from_config_paths(**kwargs):
    if kwargs["train"]:
        if kwargs["config"]["input/hierarchical_sampling"]:
            sampler = HierarchicalSampler(kwargs["paths"], kwargs["config"])
        else:
            sampler = RandomSampler(
                kwargs["paths"], replacement=True, 
                num_samples=kwargs["config"]["input/epoch_size"])

        return DatasetImageBatch(sampler=sampler, **kwargs)
    else:
        # We want every image from the validation/test dataset
        sampler = list(range(len(kwargs["paths"])))
        return DatasetImageBatch(sampler=sampler, **kwargs)

def get_loaders(config):
    """
    returns train, val and test loaders from config... 
    """
    # dataspec
    data_spec = DataSpecification.from_config(config)
    # the data spec contains multiple folds; for now, only use the first fold
    folds = list(data_spec.folds.keys())
    fold_name = folds[0]

    # collate fold paths into train, val and test
    train_paths = []  # list of paths
    val_paths = []  # list of lists of paths
    test_paths = []  # list of paths

    # train and val paths
    for name, paths in data_spec.folds[fold_name].items():
        if name.startswith("train"):
            train_paths += paths
        elif name.startswith("val"):
            val_paths.append(paths)

    # train dataset
    train_dataset = dataset_from_config_paths(
         paths=train_paths,
         train=True,
         config=config,
         fold_name=fold_name
    )

    # val datasets
    datasets_val = []  # list of datasets
    val_path_count = 0
    for paths in val_paths:
        val_path_count += len(paths)
        datasets_val.append(
            dataset_from_config_paths(
                 paths=paths,
                 train=False,
                 config=config,
                 fold_name=fold_name
            )
        )

    # test dataset
    with data_spec.activated_test_set():
        test_paths = data_spec.fold_paths(fold_name, "^test")
    test_dataset = dataset_from_config_paths(
         paths=test_paths,
         train=False,
         config=config,
         fold_name=fold_name
    )

    # datasets -> loaders
    train_loader = StreamDataLoader(train_dataset)
    val_loaders = [StreamDataLoader(d) for d in datasets_val]
    test_loader = StreamDataLoader(test_dataset)
    return (train_loader, val_loaders, test_loader) 

custom script

import argparse
import torch
import platform
from htc import Config, ModelImage, LabelMapping
from pathlib import Path
from htc_tools.htc_data import get_loaders
import logging
from tqdm import tqdm
import torch.nn.functional as F

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

def run_iteration(model, loader, device, config):
    # number of classes
    label_map = LabelMapping.from_config(config)
    num_classes = len(label_map.label_indices())

    # to devices
    model = model.to(device)

    for batch in tqdm(loader): 
        # x 
        x = batch["features"].to(device)
        x = x.permute(0, 3, 1, 2)  # Input dimension for UNet needs to be [N, C, H, W]
        # y
        predictions = model(x)

        # htc stuff
        labels = batch["labels"]
        valid_pixels = batch["valid_pixels"]
        used_labels = labels[valid_pixels].unique()
        labels = labels.masked_fill(~valid_pixels, 0)
        labels = F.one_hot(labels, num_classes=num_classes).to(torch.float16)  # [BHWC]
        # Calculate the losses only for the valid pixels
        # Keep the class dimension
        valid_predictions = predictions.permute(0, 2, 3, 1)[valid_pixels]  # (samples, class)
        valid_labels = labels[valid_pixels]  # (samples, class)
        assert valid_predictions.shape == valid_labels.shape, "Invalid shape"

        n_invalid = (~valid_predictions.isfinite()).sum()
        if n_invalid > 0:
            valid_predictions.nan_to_num_()
            logger.warning(
                f"Found {n_invalid} invalid values in prediction of the annotated area"
            )
            logger.warning(
                "nan_to_num will be applied to the predictions but please note that this is only a workaround and no"
                " real solution. It is very likely that the model does not learn properly (this message is not shown"
                " again)"
            )

if __name__ == "__main__":
    # arg parser
    parser = argparse.ArgumentParser(
                    prog='Iterate',
                    description='Iterates over val or test loaders')

    # config path
    parser.add_argument(
         "--fold_dir",
         type=str,
         help="Path to training config (full config not dataset)",
         default="results/training/image/2023-12-13_09-24-20_htc-config/fold_0",
    )
    parser.add_argument(
        "--loader",
        type=str,
        help="val or test",
        default="val",
        choices=['val', 'test'],
    )
    # args
    args = parser.parse_args()

    # device
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
        device = torch.device('cuda:0')
    else:
        print("\n[INFO] GPU not found. Using CPU: {}\n".format(platform.processor()))
        device = torch.device('cpu')

    # pathlib necessary
    args.fold_dir = Path(args.fold_dir)
    args.config_path = Path(args.fold_dir) / "config.json"
    args.fold_name = args.fold_dir.name
    args.run_dir = args.fold_dir.parent
    logger.info("Using run_dir: %s and fold: %s"%(args.fold_name, args.run_dir))

    # load config
    config = Config(Path(args.config_path))

    # loaders
    _, val_loaders, test_loader = get_loaders(config)
    assert len(val_loaders) == 1, "Only one val loader is currently supported"
    val_loader = val_loaders[0]
    if args.loader == "val":
        loader = val_loader
        del test_loader
    elif args.loader == "test":
        loader = test_loader
        del val_loaders, val_loader
    else:
        raise Exception("No loader chosen...")
    logger.info("Using %s loader" % args.loader)

    # number of classes
    label_map = LabelMapping.from_config(config)
    args.num_classes = len(label_map.label_indices())

    # number of channels
    args.num_channels = config["input"]["n_channels"]

    # load model
    model = ModelImage.pretrained_model(
        model="image",
        run_folder=args.run_dir.name,
    )
    logger.info("Using model")
    logger.info(model)

    # run iteration over loader
    run_iteration(model, loader, device, config)

htc config.json

{
    "config_name": "default",
    "dataloader_kwargs": {
        "batch_size": 5,
        "num_workers": 1
    },
    "inherits": "models/image/configs/default",
    "input": {
        "annotation_name": [
            "polygon#annotator1",
            "polygon#annotator2",
            "polygon#annotator3"
        ],
        "data_spec": "2fold-dataspec.json",
        "epoch_size": 500,
        "merge_annotations": "union",
        "n_channels": 100,
        "preprocessing": "L1",
        "transforms_gpu": [
            {
                "class": "KorniaTransform",
                "degrees": 45,
                "p": 0.5,
                "padding_mode": "reflection",
                "scale": [
                    0.9,
                    1.1
                ],
                "transformation_name": "RandomAffine",
                "translate": [
                    0.0625,
                    0.0625
                ]
            },
            {
                "class": "KorniaTransform",
                "p": 0.25,
                "transformation_name": "RandomHorizontalFlip"
            },
            {
                "class": "KorniaTransform",
                "p": 0.25,
                "transformation_name": "RandomVerticalFlip"
            }
        ]
    },
    "label_mapping": "htc.tissue_atlas.settings_atlas>label_mapping",
    "lightning_class": "htc.models.image.LightningImage>LightningImage",
    "model": {
        "architecture_kwargs": {
            "encoder_name": "efficientnet-b5",
            "encoder_weights": "imagenet"
        },
        "architecture_name": "Unet",
        "model_name": "ModelImage",
        "pretrained_model": {
            "model": "image",
            "run_folder": "2023-02-08_14-48-02_organ_transplantation_0.8"
        }
    },
    "optimization": {
        "lr_scheduler": {
            "gamma": 0.99,
            "name": "ExponentialLR"
        },
        "optimizer": {
            "lr": 0.001,
            "name": "Adam",
            "weight_decay": 0
        }
    },
    "swa_kwargs": {
        "annealing_epochs": 0
    },
    "trainer_kwargs": {
        "accelerator": "gpu",
        "devices": 1,
        "max_epochs": 100,
        "precision": "16-mixed"
    },
    "validation": {
        "checkpoint_metric": "dice_metric",
        "dataset_index": 0
    }
}
JanSellner commented 8 months ago

Thanks for your interest in our work :-)

If you only want to get the assert working, you need to change it to

torch.isclose(x.abs().sum(dim=1), torch.tensor(1.0, device=x.device, dtype=x.dtype), atol=0.1)

It is expected that the batch features have dtype float16, and this is usually what you want.
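
If you need float32 for your own computations, a local cast is enough (a minimal sketch; the loader itself still yields float16):

x = batch["features"].float()  # cast a copy to float32 only where you need it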

By the way: I recently wrote a SinglePredictor class which can be used to create predictions outside of our usual producer/consumer architecture. Using it, the basic prediction pipeline may look like this:

# imports assumed to come from the htc package (SinglePredictor may live in a submodule depending on your htc version)
from htc import Config, DataSpecification, DatasetImageBatch, SinglePredictor

config = Config(args.config_path)
spec = DataSpecification.from_config(config)
paths = spec.paths("val")
dataloader = DatasetImageBatch.batched_iteration(paths, config)
predictor = SinglePredictor(model="image", run_folder="2023-02-08_14-48-02_organ_transplantation_0.8")
for batch in dataloader:
    logits = predictor.predict_batch(batch)["class"]
    logits.shape  # 5, 19, 480, 640
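
To get a segmentation map from these logits, taking the argmax over the class dimension should work (a sketch; dim=1 is the class dimension per the shape above):

predicted_labels = logits.argmax(dim=1)  # [5, 480, 640], one label index per pixel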
alfieroddan commented 8 months ago

Cool, worked like a charm! I'll use the SinglePredictor as my entry point from now on. Thanks.

Just a question about model loading: why does HTCModel.find_pretrained_run(model, run_folder, path) load the pretrained weights for the backbone when using a model checkpoint?

(env) tay@tay:~/Code/ms-seg$ PATH_Tivita_HeiPorSPECTRAL=/media/tay/4TB/Datasets/HeiPorSPECTRAL PATH_HTC_RESULTS=./results/ python test.py
[INFO][htc.no_duplicates] Found pretrained run in the local results dir at                                          HTCModel.py:477
/home/tay/Code/ms-seg/results/training/image/2023-12-13_09-24-20_htc-config                                                        
[INFO][htc.no_duplicates] Found pretrained run in the local hub dir at                                              HTCModel.py:484
/home/tay/.cache/torch/hub/htc_checkpoints/image/2023-02-08_14-48-02_organ_transplantation_0.8                                     
[INFO][htc] Successfully loaded the pretrained model (2 keys were skipped:                                          HTCModel.py:312
['model.architecture.segmentation_head.0.weight', 'model.architecture.segmentation_head.0.bias']).                                 

It seems odd that it skips keys when loading a model checkpoint, and there is also no logging info on whether a pretrained run has been loaded.

JanSellner commented 8 months ago

I am not entirely sure I understand the question but maybe this clarification helps:

HTCModel.find_pretrained_run(model, run_folder, path) does not load a checkpoint; it only searches for the specified training run.

With my code from above (using the SinglePredictor), there should be no skipped weights because it directly loads the lightning class; you can then use the model as-is (i.e. based on the checkpoint) to make predictions. So this is a good entrypoint for inference.

If you use the model classes directly, e.g. ModelImage.pretrained_model(), then you get a model instance which you can use in your own lightning class for further training. The weights for the segmentation head are skipped because those weights depend on the number of classes you want to use (there is also an n_classes parameter). Hence, they are initialized randomly so that you can kick off your training with your own class setting. So this is a good entrypoint for training.
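
A minimal sketch of that training entrypoint, assuming the n_classes parameter mentioned above (check the exact signature of ModelImage.pretrained_model in your htc version):

# n_classes=19 is just an example value; the segmentation head is re-initialized for that many classes
model = ModelImage.pretrained_model(
    model="image",
    run_folder="2023-02-08_14-48-02_organ_transplantation_0.8",
    n_classes=19,
)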

alfieroddan commented 8 months ago

All understood! Thank you again for your help.

I look forward to experimenting more with your package. A great piece of work :)