Closed: AbishekRajVG closed this PR 10 months ago.
Codecov report: All modified and coverable lines are covered by tests. Comparison is base (4a041ae) 99.85% compared to head (130ade2) 99.85%.
This PR tracks the changes suggested by @mostafajahanifar in PR https://github.com/TissueImageAnalytics/tiatoolbox/pull/635
Suggestion 1
To follow the usual convention of moving the model to the device first and only then wrapping it with DataParallel, I would suggest improving this function accordingly.
This would also avoid the unnecessary overhead of DataParallel when only one GPU is available.
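A minimal sketch of the move-then-wrap order, assuming a standalone helper; the name model_to and its signature are hypothetical and not tiatoolbox's actual function:

```python
import torch
from torch import nn


def model_to(model: nn.Module, device: str = "cpu") -> nn.Module:
    """Move a model to ``device``, using DataParallel only for multiple GPUs.

    Hypothetical helper illustrating the suggestion; not tiatoolbox's real code.
    """
    torch_device = torch.device(device)
    # Usual convention: move the model to the target device first ...
    model = model.to(torch_device)
    # ... and only then wrap it, and only when more than one GPU is available.
    # On CPU or with a single GPU, the DataParallel overhead is skipped.
    if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model


# Usage with device="cpu": the model comes back unwrapped.
model = model_to(nn.Linear(4, 2), device="cpu")
```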
Again, this can be integrated as a method into the ModelABC class. It should already have a .to() method inherited from nn.Module. However, since we need torch.nn.DataParallel, we can override that .to() method with this one. Users can then simply call my_model.to(device).
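A sketch of the method version, assuming a simplified, hypothetical ModelABCSketch class in place of the real ModelABC; it overrides nn.Module.to and wraps the model in DataParallel only when several GPUs are present (TinyNet is likewise a hypothetical example model):

```python
import torch
from torch import nn


class ModelABCSketch(nn.Module):
    """Hypothetical, simplified stand-in for ModelABC (not the real class)."""

    def to(self, device="cpu"):  # narrower signature than nn.Module.to, for illustration
        # nn.Module.to moves parameters and buffers in place and returns self.
        model = super().to(device)
        torch_device = torch.device(device)
        # Wrap in DataParallel only when more than one GPU is available.
        if torch_device.type == "cuda" and torch.cuda.device_count() > 1:
            return nn.DataParallel(model)
        return model


class TinyNet(ModelABCSketch):
    """Hypothetical example model built on the sketch class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# The result may be a wrapper, so callers should reassign it.
my_model = TinyNet().to("cpu")
output = my_model(torch.zeros(1, 4))
```

One trade-off of this design: nn.Module.to normally returns the same module object, so code that calls my_model.to(device) without reassigning the result would silently miss the DataParallel wrapper.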
Suggestion 2
Why not move this function into the ModelABC class as a method, so that users can load weights for our models just like they do with normal PyTorch models? Would something like the following be possible:
my_model.load_weights_from_path(path) or my_model.load(path)
I assume that, because ModelABC inherits from nn.Module, it already has a load_state_dict method.
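A minimal sketch of the load_weights_from_path option, assuming the checkpoint file holds a plain state dict; ModelABCSketch and TinyNet are hypothetical stand-ins, and the method simply delegates to the inherited load_state_dict:

```python
from __future__ import annotations

import tempfile
from pathlib import Path

import torch
from torch import nn


class ModelABCSketch(nn.Module):
    """Hypothetical stand-in for ModelABC showing a weight-loading method."""

    def load_weights_from_path(self, path: str | Path) -> ModelABCSketch:
        # Map to CPU so checkpoints saved on a GPU also load on CPU-only machines.
        state_dict = torch.load(path, map_location="cpu")
        # Reuse load_state_dict, which every nn.Module already inherits.
        self.load_state_dict(state_dict)
        return self


class TinyNet(ModelABCSketch):
    """Hypothetical example model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Round trip: save one model's weights, then load them into a fresh instance.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "weights.pth"
    source = TinyNet()
    torch.save(source.state_dict(), ckpt)
    restored = TinyNet().load_weights_from_path(ckpt)
```

Returning self keeps the call chainable, e.g. TinyNet().load_weights_from_path(path).eval().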