TissueImageAnalytics / tiatoolbox

Computational Pathology Toolbox developed by TIA Centre, University of Warwick.
https://warwick.ac.uk/tia

♻️ Update `model_to()` and `load_torch_model()` methods in `ModelABC` #733

AbishekRajVG closed this pull request 10 months ago

AbishekRajVG commented 11 months ago

This PR tracks the changes as per suggestions from @mostafajahanifar in PR https://github.com/TissueImageAnalytics/tiatoolbox/pull/635

Suggestion 1

To follow the usual convention of moving the model to the device first and then wrapping it in DataParallel, I would suggest improving this function as follows:

```python
import torch


def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
    """Transfers model to cpu/gpu.

    Args:
        model (torch.nn.Module):
            PyTorch defined model.
        device (str):
            Transfers model to the specified device. Default is "cpu".

    Returns:
        torch.nn.Module:
            The model after being moved to cpu/gpu.

    """
    device = torch.device(device)
    model = model.to(device)

    # If target device is CUDA and more than one GPU is available, use DataParallel
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    return model
```

This also avoids the unnecessary overhead of DataParallel if only one GPU is available.

Again, this can be integrated as a method into the ModelABC class. It should already have a `to` method inherited from nn.Module. However, if we need torch.nn.DataParallel, we can override that `to` method with this one. Then users can call: my_model.to(device)
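A minimal sketch of how the helper could live on the model class itself. `SketchModelABC` is a hypothetical stand-in for ModelABC, not the toolbox's actual class. Instead of overriding `to`, it adds a separate `model_to` method (the name used in this PR's title). The reason is that `nn.Module.to` also accepts dtype and other arguments and is expected to return `self`, while a DataParallel-wrapped result is a different object.

```python
import torch
from torch import nn


class SketchModelABC(nn.Module):
    """Hypothetical stand-in for ModelABC, used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    def model_to(self, device: str = "cpu") -> nn.Module:
        """Move the model to `device`, wrapping it in DataParallel on multi-GPU hosts."""
        target = torch.device(device)
        # nn.Module.to() moves parameters in place and returns self.
        model = self.to(target)
        # Only pay the DataParallel overhead when more than one GPU is visible.
        if target.type == "cuda" and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        return model


# On a CPU-only machine the same instance comes back, unwrapped.
my_model = SketchModelABC()
moved = my_model.model_to("cpu")
```

Callers that need the plain module from a wrapped result can still reach it through DataParallel's `.module` attribute.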

Suggestion 2

Why not move this function into the ModelABC class as a method, so users can load weights for our models just as they do with normal PyTorch models? Is something like the following possible:

my_model.load_weights_from_path(path) or my_model.load(path)

I assume that because ModelABC inherits from nn.Module, it already has a load_state_dict method.
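A minimal sketch of the suggested `my_model.load_weights_from_path(path)` API. It builds on the inherited `load_state_dict`. `SketchModelABC` is again a hypothetical stand-in for ModelABC. The sketch also assumes the file at `path` holds a plain state dict saved with `torch.save(model.state_dict(), path)`. `weights_only=True` requires PyTorch 1.13 or newer.

```python
import os
import tempfile

import torch
from torch import nn


class SketchModelABC(nn.Module):
    """Hypothetical stand-in for ModelABC, used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    def load_weights_from_path(self, path: str, device: str = "cpu") -> "SketchModelABC":
        """Load a saved state dict from `path` into this model and return it."""
        # map_location lets GPU-saved weights load on a CPU-only machine;
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict = torch.load(path, map_location=device, weights_only=True)
        # Inherited from nn.Module; strict=True raises on missing or unexpected keys.
        self.load_state_dict(state_dict, strict=True)
        return self


# Round trip: save one model's weights, then load them into a fresh instance.
source = SketchModelABC()
with tempfile.TemporaryDirectory() as tmp:
    weights_path = os.path.join(tmp, "weights.pth")
    torch.save(source.state_dict(), weights_path)
    target = SketchModelABC().load_weights_from_path(weights_path)
```

Returning `self` lets the call chain with construction, as in the last line.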

codecov[bot] commented 11 months ago

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (4a041ae) 99.85% compared to head (130ade2) 99.85%.

Additional details and impacted files

```diff
@@           Coverage Diff            @@
##           develop     #733   +/-   ##
========================================
  Coverage    99.85%   99.85%
========================================
  Files           65       65
  Lines         7508     7517    +9
  Branches      1460     1460
========================================
+ Hits          7497     7506    +9
  Misses           4        4
  Partials         7        7
```
