MIC-DKFZ / basic_unet_example

An example project of how to use a U-Net for segmentation on medical images with PyTorch.
Apache License 2.0

Question on exp.run() #5

JunMa11 closed this issue 5 years ago

JunMa11 commented 5 years ago

Hi @elpequeno, in run_train_pipeline.py:

    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                         # visdomlogger_kwargs={"auto_start": c.start_visdom},
                         loggers={
                             "visdom": ("visdom", {"auto_start": c.start_visdom})
                         }
                         )

    exp.run()
    exp.run_test(setup=False)

1) What do exp.run() and exp.run_test(setup=False) do? exp is an instance of UNetExperiment, but I cannot find definitions of the methods run() and run_test() in UNetExperiment.

2) I wrote a test() method in UNetExperiment.py:

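    def test(self):
        print('Test Demo: Implement your test() method here')
        self.elog.print('TEST')
        # nn.Module has no test() method; eval() switches off dropout and
        # makes batch norm use its running statistics
        self.model.eval()

        test_dice_list = []

        with torch.no_grad():
            for data_batch in self.test_data_loader:
                data = data_batch['data'][0].float().to(self.device)
                target = data_batch['seg'][0].long().to(self.device)

                pred = self.model(data)
                pred_softmax = F.softmax(pred, dim=1)

                # dice_loss is treated as the negative Dice score, hence the sign flip
                loss = self.dice_loss(pred_softmax, target.squeeze())
                print('test dice: ', -loss.item())
                test_dice_list.append(-loss.item())

        if test_dice_list:
            self.elog.print('mean test dice: {:.4f}'.format(
                sum(test_dice_list) / len(test_dice_list)))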

Once training has finished, how can I call this method to run inference on the test data?

3) I found data_augmentation.py in the datasets folder, but the augmentation does not seem to be applied during training.

These questions are not urgent. I'm looking forward to your reply after your vacation.

elpequeno commented 5 years ago

Hi @JunMa11,

UNetExperiment is derived from PytorchExperiment, a class from the TRIXI framework. Calling exp.run() starts the "life cycle" of the experiment, so to speak.

From the class documentation of PytorchExperiment:

The basic life cycle of a PytorchExperiment is the same as that of Experiment:

        setup()
        prepare()

        for epoch in n_epochs:
            train()
            validate()

        end()
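To make the life cycle concrete, here is a simplified, self-contained sketch of how run() and run_test() drive these hooks. It is not TRIXI's code: MiniExperiment and its calls list are hypothetical, and the real classes additionally handle logging, checkpointing and error states.

```python
class MiniExperiment:
    """Hypothetical, simplified stand-in for TRIXI's Experiment life cycle."""

    def __init__(self, n_epochs):
        self.n_epochs = n_epochs
        self.calls = []  # records the order in which the hooks are invoked

    # Hooks that a subclass such as UNetExperiment would override.
    def setup(self):
        self.calls.append("setup")

    def prepare(self):
        self.calls.append("prepare")

    def train(self, epoch):
        self.calls.append(f"train({epoch})")

    def validate(self, epoch):
        self.calls.append(f"validate({epoch})")

    def end(self):
        self.calls.append("end")

    def test(self):
        self.calls.append("test")

    def run(self):
        # Training life cycle: setup -> prepare -> (train, validate) per epoch -> end
        self.setup()
        self.prepare()
        for epoch in range(self.n_epochs):
            self.train(epoch=epoch)
            self.validate(epoch=epoch)
        self.end()

    def run_test(self, setup=True):
        # Test life cycle: optionally run setup() again, then call test()
        if setup:
            self.setup()
        self.test()


exp = MiniExperiment(n_epochs=2)
exp.run()
exp.run_test(setup=False)
print(exp.calls)
# ['setup', 'prepare', 'train(0)', 'validate(0)', 'train(1)', 'validate(1)', 'end', 'test']
```

In this sketch, passing setup=False after run() means test() sees the state that training left behind, rather than whatever a second setup() call would produce.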

See TRIXI's abstract Experiment class for the implementation of run() and run_test(). I usually call training and testing like this:

    c = get_config()
    exp = UNetExperiment(config=c, name=c.name, n_epochs=c.n_epochs,
                         seed=42, append_rnd_to_name=c.append_rnd_string, globs=globals(),
                         loggers={
                             "visdom": ("visdom", {"auto_start": c.start_visdom}),
                         }
                         )

    exp.run()
    exp.run_test(setup=False)

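The setup parameter controls whether setup() is called again before testing. After exp.run() the experiment is already set up, so run_test(setup=False) tests the freshly trained network. If you want to run testing without training first, call run_test() with setup enabled, and make sure your setup() loads a pre-trained network.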
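For the test-only case, setup() has to restore trained weights instead of building a randomly initialized network. Below is a minimal PyTorch sketch, assuming the weights were saved as a state_dict with torch.save(); build_model() and setup_for_testing() are hypothetical names, the small Conv2d only stands in for the U-Net, and TRIXI may provide its own checkpoint helpers that the repository uses instead.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for the network created in UNetExperiment.setup()
    return nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, padding=1)


def setup_for_testing(checkpoint_path, device="cpu"):
    """Hypothetical setup() body for test-only runs: build the network and
    load pre-trained weights instead of keeping the random initialization."""
    model = build_model().to(device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout, use running batch-norm statistics
    return model


# Usage: pretend a training run saved its weights, then load them for testing.
trained = build_model()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    torch.save(trained.state_dict(), path)
    model = setup_for_testing(path)
```

Saving only the state_dict keeps the checkpoint independent of the experiment object, so the same file can be loaded by a test-only run that never called train().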

3) Can you be more specific? Does the training augmentation work?

Pura vida, André

JunMa11 commented 5 years ago

Hi @elpequeno, thank you very much for your help.

Regarding 3) and your question "Can you be more specific? Does the training augmentation work?": during training, data augmentation does not seem to be executed, because I cannot find any place where data_augmentation.py is called.

elpequeno commented 5 years ago

data_augmentation.py contains the function get_transforms(), which is used in NumpyDataLoader (datasets/two_dim/NumpyDataLoader.py). The transforms are applied while the data is being loaded, so the training loop never calls them directly.
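The pattern of applying augmentation inside the data loader can be sketched as follows. This is a plain NumPy illustration, not the repository's code: MirrorTransform, Compose, ToyLoader and this get_transforms() are hypothetical stand-ins, while the real get_transforms() presumably composes library transforms. The key point is that image and segmentation are transformed together, inside iteration.

```python
import numpy as np


class MirrorTransform:
    """Hypothetical transform: flips data and seg together along the last
    axis with probability p, so image and label stay aligned."""

    def __init__(self, p=0.5, rng=None):
        self.p = p
        self.rng = rng if rng is not None else np.random.default_rng()

    def __call__(self, **batch):
        if self.rng.random() < self.p:
            batch["data"] = batch["data"][..., ::-1].copy()
            batch["seg"] = batch["seg"][..., ::-1].copy()
        return batch


class Compose:
    """Applies a list of transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, **batch):
        for t in self.transforms:
            batch = t(**batch)
        return batch


def get_transforms(p=0.5, rng=None):
    # Hypothetical counterpart of get_transforms() in data_augmentation.py
    return Compose([MirrorTransform(p=p, rng=rng)])


class ToyLoader:
    """Hypothetical loader: augmentation happens while batches are produced."""

    def __init__(self, images, segs, batch_size, transforms=None):
        self.images, self.segs = images, segs
        self.batch_size = batch_size
        self.transforms = transforms

    def __iter__(self):
        for start in range(0, len(self.images), self.batch_size):
            batch = {
                "data": self.images[start:start + self.batch_size],
                "seg": self.segs[start:start + self.batch_size],
            }
            if self.transforms is not None:
                batch = self.transforms(**batch)
            yield batch


# Usage: p=1.0 forces the flip so the effect is easy to see.
images = np.arange(32, dtype=np.float32).reshape(2, 1, 4, 4)
segs = (images > 15).astype(np.int64)
loader = ToyLoader(images, segs, batch_size=2, transforms=get_transforms(p=1.0))
batch = next(iter(loader))
```

Because the loader owns the transforms, the training loop only iterates over batches; that is why no explicit call to the augmentation code shows up in the experiment class.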

JunMa11 commented 5 years ago

Got it. Thanks for your guidance.