JuliaWolleb / Diffusion-based-Segmentation

This is the official Pytorch implementation of the paper "Diffusion Models for Implicit Image Segmentation Ensembles".
MIT License

Sample during training problem #52

Closed YI-HAO-SU closed 8 months ago

YI-HAO-SU commented 8 months ago

Recently, I tried adapting the great code the author provided to my own task. I train on three-channel images and expect the model to sample high-quality images conditioned on single-channel masks. However, the sampling output was not as expected: the shapes of objects are recognizable, but the colors are very different from the source images.

I see two possible causes: either the model was not trained well, or the sampling step has a bug.

To check that the model is trained well and has converged, I would like to sample during training. However, the generated images are just noise. Can anyone help me solve this problem? By the way, I cannot run Visdom smoothly on my remote server.

Below is the code I modified for this purpose: `

def run_loop(self):
    i = 0
    data_iter = iter(self.dataloader)
    while (
        not self.lr_anneal_steps
        or self.step + self.resume_step < self.lr_anneal_steps
    ):
        try:
            batch, cond = next(data_iter)
        except StopIteration:
            # StopIteration is thrown if dataset ends
            # reinitialize data loader
            data_iter = iter(self.dataloader)
            batch, cond = next(data_iter)

        sample = self.run_step(batch, cond)
        save_image(sample[0], '/work/kevin20307/Difseg/NCKU/model/model_step_resume_1000/img1.png')

        i += 1

        if self.step % self.log_interval == 0:
            logger.dumpkvs()
        if self.step % self.save_interval == 0:
            self.save()
            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                return
        self.step += 1
    # Save the last checkpoint if it wasn't already saved.
    if (self.step - 1) % self.save_interval != 0:
        self.save()

def run_step(self, batch, cond):
    batch = th.cat((batch, cond), dim=1)
    cond = {}
    # print(batch.shape)
    sample = self.forward_backward(batch, cond)
    # print(sample.shape)
    took_step = self.mp_trainer.optimize(self.opt)
    if took_step:
        self._update_ema()
    self._anneal_lr()
    self.log_step()
    return sample`

And the generated img1 looks like this: (image)

JuliaWolleb commented 8 months ago

Hi

During training, the model only predicts the noise epsilon that is subtracted from x_t to get x_{t-1}. So your sample is just the predicted noise, which is why img1.png looks like noise. If you want to validate during training, you need to run the full reverse process, going through all the steps from t=1000 down to t=0, to generate a synthetic image.
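To make the point above concrete, here is a minimal, framework-free sketch of the reverse process using the standard DDPM ancestral sampling update. It is not the repository's sampling code: `predict_eps`, `linear_betas`, `sample_loop`, and the schedule values are assumptions chosen for illustration, and plain Python lists stand in for tensors so the arithmetic is visible.

```python
import math
import random


def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule; the start and end values are illustrative assumptions."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def sample_loop(predict_eps, mask, num_pixels, betas, rng):
    """Start from pure noise and denoise step by step, from the last timestep to the first.

    predict_eps(x_t, mask, t) is a hypothetical stand-in for the trained network.
    It must return the predicted noise as a list with the same length as x_t.
    """
    alphas = [1.0 - b for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)  # cumulative product of the alphas

    x = [rng.gauss(0.0, 1.0) for _ in range(num_pixels)]  # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        eps = predict_eps(x, mask, t)  # one network call per timestep
        coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        mean = [(xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(x, eps)]
        if t > 0:
            sigma = math.sqrt(betas[t])
            x = [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean  # no fresh noise is added at the final step
    return x


# Usage with a dummy predictor (always predicts zero noise), just to show the call.
dummy = lambda x, mask, t: [0.0] * len(x)
image = sample_loop(dummy, mask=[0, 1, 1, 0], num_pixels=4,
                    betas=linear_betas(1000), rng=random.Random(0))
```

A single call to the network, as in the modified `run_step`, corresponds to one `predict_eps` call here. An image only emerges after the whole loop has run, so validating every training step this way is expensive; sampling only every few thousand steps is a common compromise.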

I am also not sure about what you want to generate. Do you want to generate a binary segmentation mask for your input image? If so, the color should not be an issue, since you want to have a binary (or multi-class) segmentation.

YI-HAO-SU commented 8 months ago

Thanks for your reply.

In my task, I would like to input a binary mask and generate a color image according to it in the sampling stage. The result looks like the pair below.

20-00020-HE_17920_155136

Immune_Cell_output_test

However, it should be more like this one.

20-00020-HE_17920_155136

Can you give me some hints to implement the process of generating a synthetic image during the training stage?

JuliaWolleb commented 8 months ago

Hi, how many channels does the third image you sent me have? During training, you need to use the third image as ground truth and add t steps of noise to it on all channels. This gives your noisy ground truth image x_t. Then you concatenate the mask (first image) as the last channel and pass the result to the model. The prediction is the noise epsilon (same number of channels as x_t). If you train it well, the diffusion model will generate images that follow your training distribution, so similar to image 3. It seems like your diffusion model did not learn to generate real-looking images in the first place. The segmentation mask is just the condition, so that the model has a clue where it needs to paint structures.
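The training step described above can be sketched as follows. This is a hedged illustration rather than the repository's code: `q_sample`, `training_loss`, and `predict_eps` are hypothetical names, images are lists of channels (each a flat list of pixels) instead of tensors, and the loss is the plain mean squared error between the true and predicted noise.

```python
import math
import random


def q_sample(x0, alpha_bar_t, rng):
    """Noise every channel of the ground truth x0 to timestep t in closed form.

    alpha_bar_t is the cumulative product of (1 - beta) up to timestep t.
    Returns (x_t, eps), where eps is the Gaussian noise that was added.
    """
    eps = [[rng.gauss(0.0, 1.0) for _ in channel] for channel in x0]
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    x_t = [[a * p + b * e for p, e in zip(channel, eps_channel)]
           for channel, eps_channel in zip(x0, eps)]
    return x_t, eps


def training_loss(predict_eps, x0, mask, alpha_bar_t, t, rng):
    """One noise-prediction training step for mask-conditioned image generation.

    predict_eps(model_input, t) is a hypothetical stand-in for the network.
    It receives the noisy image channels plus the mask and must return
    predicted noise with as many channels as the noisy image.
    """
    x_t, eps = q_sample(x0, alpha_bar_t, rng)
    model_input = x_t + [mask]  # the clean mask is appended as the last channel
    eps_pred = predict_eps(model_input, t)
    squared = [(p - e) ** 2
               for pred_channel, eps_channel in zip(eps_pred, eps)
               for p, e in zip(pred_channel, eps_channel)]
    return sum(squared) / len(squared)


# Usage: a 3-channel image with 4 pixels and a 1-channel mask.
rng = random.Random(0)
x0 = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 0.0, 0.2]]
mask = [0.0, 1.0, 1.0, 0.0]
zero_model = lambda model_input, t: [[0.0] * len(c) for c in model_input[:-1]]
loss = training_loss(zero_model, x0, mask, alpha_bar_t=0.5, t=10, rng=rng)
```

Note that only the image channels are noised; the mask stays clean because it is the condition. The same channel order (image channels first, mask last) has to be used again at sampling time.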

YI-HAO-SU commented 8 months ago

Thanks for your response once more. The third image has 3 channels. I used other images like the third one as ground truth for training my denoising model. I wanted to generate samples during training in order to monitor the model's convergence. However, I found that my sample results were always awful because of an input error: the right image and condition mask were not being combined during the sampling stage. With that problem solved, my previous questions have become clearer, and I can now both assess the model's capabilities and generate samples successfully. Thank you very much.