RichardObi / medigan

medigan - A Python Library of Pretrained Generative Models for Medical Image Synthesis
https://medigan.readthedocs.io/en/latest/
MIT License

Feature/return dataloader #9

Closed. RichardObi closed this PR 2 years ago.

RichardObi commented 2 years ago

Returns the data generated by a model as a torch DataLoader or a torch Dataset.

This PR should make it quicker and easier for users to train their models on the data generated by one of medigan's generative models.
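For background on what returning a torch Dataset involves: a map-style dataset, which torch.utils.data.DataLoader consumes, only needs `__len__` and `__getitem__`. The following is a minimal sketch, not medigan's implementation. It assumes the generated images and masks are already in memory as Python lists. The class name `GeneratedSampleDataset` is hypothetical. Each item is a dict with "sample" and "mask" keys, which mirrors how the test script in this PR reads `data_dict`.

```python
class GeneratedSampleDataset:
    """Hypothetical map-style dataset over pre-generated samples (and optional masks).

    Implements the __len__/__getitem__ protocol that torch.utils.data.DataLoader
    expects from a map-style dataset; it is duck-typed, so torch is not required here.
    """

    def __init__(self, samples, masks=None):
        # Masks, when given, must pair one-to-one with samples
        if masks is not None and len(masks) != len(samples):
            raise ValueError("samples and masks must have the same length")
        self.samples = list(samples)
        self.masks = list(masks) if masks is not None else None

    def __len__(self):
        # Number of individual samples (not batches)
        return len(self.samples)

    def __getitem__(self, idx):
        # Return a dict so callers can use item.get("sample") / item.get("mask")
        item = {"sample": self.samples[idx]}
        if self.masks is not None:
            item["mask"] = self.masks[idx]
        return item
```

With torch installed, an instance of this class can be passed directly to `torch.utils.data.DataLoader(dataset, batch_size=...)`, which then handles batching and collation.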

This PR can be tested via:

from matplotlib import pyplot as plt
import numpy as np
from medigan import Generators
generators = Generators()
# Generate 2 samples with the mask-to-mass pix2pix model and wrap them in a torch DataLoader
dataloader = generators.get_as_torch_dataloader(model_id="00004_PIX2PIX_MASKTOMASS_BREAST_MG_SYNTHESIS", num_samples=2)
# One column per batch: top row shows the synthetic sample, bottom row its mask.
# squeeze=False keeps img_array 2-D even if the dataloader yields a single batch.
f, img_array = plt.subplots(2, len(dataloader), squeeze=False)
for batch_idx, data_dict in enumerate(dataloader):
    # Convert to numpy and drop singleton batch/channel dimensions for imshow
    sample = np.squeeze(np.asarray(data_dict.get("sample")))
    mask = np.squeeze(np.asarray(data_dict.get("mask")))
    img_array[0][batch_idx].imshow(sample, interpolation='nearest', cmap='gray')
    img_array[0][batch_idx].axis('off')
    img_array[1][batch_idx].imshow(mask, interpolation='nearest', cmap='gray')
    img_array[1][batch_idx].axis('off')
plt.savefig('img.png', transparent=True, bbox_inches='tight')
plt.show()
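In the test script, `len(dataloader)` is the number of batches, not the number of samples, so the figure gets one column per batch. For a map-style torch DataLoader, that count is ceil(num_samples / batch_size), or floor(num_samples / batch_size) when `drop_last=True`. Below is a small sketch of that arithmetic. `num_batches` is a hypothetical helper, and the PR does not state the default batch size used by `get_as_torch_dataloader`.

```python
import math


def num_batches(num_samples: int, batch_size: int = 1, drop_last: bool = False) -> int:
    """Hypothetical helper: len() of a map-style torch DataLoader over num_samples items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if drop_last:
        # Incomplete final batch is discarded
        return num_samples // batch_size
    # Incomplete final batch is kept
    return math.ceil(num_samples / batch_size)
```

With `num_samples=2`, a batch size of 1 gives two columns with one image each. A batch size of 2 gives a single column whose batch holds both images, so `np.squeeze` would leave a 3-D array that `imshow` cannot display as one grayscale image.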

Apart from that, this PR updates readme.md and introduces generators.list_models().