SuperMedIntel / MedSegDiff

Medical Image Segmentation with Diffusion Model
MIT License

How to run multi-label segmentation? #82

Open gulubao opened 1 year ago

gulubao commented 1 year ago

I encountered some issues with multi-label segmentation, and I would like to ask for your help.

The demo ISIC dataset has a single label, and the demo BRATS dataset has multiple labels, but they are merged into one in the BRATSDataset3D class.

I am interested in performing multi-label segmentation on my own dataset, and I am wondering how to set the dataset and model for this purpose.

Could you please provide a demo for multi-label segmentation?

theneao commented 1 year ago

Did you try using v2? The architecture diagram for v2 looks multi-class, but I can't find any code that actually implements this. I tried to modify v1 myself, but with my limited ability it is hard to adapt the loss and the multi-class output: the code runs without errors, yet the model still has no multi-class capability. Have you found anything, or do you have any ideas on this issue?

gulubao commented 1 year ago

I am attempting to train on v2. I am not sure whether the author adjusted the loss function for multi-class segmentation, since neither the MSE nor the VB loss seems to limit the number of classes.

Here are the adjustments I made for multilabel:

  1. Preprocess the mask labels.
     1.1 Remap the labels across the entire dataset to 0, 1, ..., n.
     1.2 When loading in a custom Dataset, normalize the mask with mask = mask / n.
     1.3 Resize masks with transforms.Resize((args.image_size, args.image_size), interpolation=InterpolationMode.NEAREST).
     1.4 Delete torch.where(mask > 0, 1, 0).
  2. In gaussian_diffusion.training_losses_segmentation, change res = torch.where(mask > 0, 1, 0) to res = torch.softmax(mask, dim=0).

Due to limited computational resources, I had to reduce the batch_size, so I replaced all nn.BatchNorm2d with nn.InstanceNorm2d.
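The norm-layer swap can be done recursively without editing every module by hand. A sketch, assuming the model is a plain nn.Module tree (`batchnorm_to_instancenorm` is a hypothetical helper, not part of the repo):

```python
import torch.nn as nn

def batchnorm_to_instancenorm(module: nn.Module) -> None:
    """Recursively replace every nn.BatchNorm2d with an nn.InstanceNorm2d
    of the same channel count, for training with very small batch sizes."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
        else:
            batchnorm_to_instancenorm(child)
```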

I am still in the training process, and I am not sure if these adjustments will work.

I hope the author could share the method for multilabel segmentation in Fig. 2 of the paper.

theneao commented 1 year ago

I made similar adjustments on v1 before, but in the end I found that while the loss is computed on the predicted noise (which can be changed to predict x0), the prediction itself has only one channel, so it cannot represent multiple classes.

The main reason, I think, is that out_channels in the UNet definition was never changed.

After changing it, though, I hit model_output, model_var_values = th.split(model_output, C, dim=1): it is unclear how the channels should be divided between the two parts in the multi-class case. Also, assigning into model_output with slicing like model_output[:, :0, :] works on the first pass but raises a dimension error on the next loop iteration, even though the dimensions are correct after the initial run. Of course, my own inexperience may be the cause of some of these mistakes; you can try it yourself.
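For reference, this is how the split behaves shape-wise if out_channels were doubled per class. This follows the guided-diffusion learn-sigma convention of halving the output into a prediction and a learned variance; whether MedSegDiff intends the same allocation for multiple classes is an assumption, not confirmed by the authors:

```python
import torch as th

n_classes = 5
B, H, W = 2, 64, 64
# Hypothetical: UNet defined with out_channels = 2 * n_classes, where the
# first half is the per-class prediction and the second half the variance.
model_out = th.randn(B, 2 * n_classes, H, W)
model_output, model_var_values = th.split(model_out, n_classes, dim=1)
```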

If it's convenient, you can directly contact me through my homepage email, or let me know your other contact information by email.

theneao commented 1 year ago

Regarding the channel split: I don't know why torch.split() behaves differently from slicing model_output[:, :0, :, :] directly, but I can only set the class split with model_output, model_var_values = th.split(model_output, 4, dim=1) or model_output, model_var_values = th.split(model_output, [4, 1], dim=1) (mine is a 5-class task, so I think out_channels can be set to 5 or 8). However, execution still fails further on, in

def _predict_xstart_from_eps(self, x_t, t, eps):
    assert x_t.shape == eps.shape

where x_t still triggers the shape assertion. x_t is the forward-diffused image that is fed in, and I believe it could be copied to the same number of channels as eps for the calculation; I have not tried this yet. In addition, comparing the two code versions, some network-structure changes related to the v2 paper were made, but nothing related to the number of classes was changed.
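The channel-repeat idea above could be sketched like this. This is only an illustration of the workaround suggested in the comment, not code from the repo, and the coefficient arguments stand in for the extracted alpha-cumprod terms:

```python
import torch as th

def predict_xstart_from_eps(x_t, eps, sqrt_recip_ac, sqrt_recipm1_ac):
    """If x_t was diffused with a single mask channel but eps now has one
    channel per class, repeat x_t across channels before the shape check.
    Untested workaround; sqrt_recip_ac / sqrt_recipm1_ac are the broadcast
    alpha-cumprod coefficients."""
    if x_t.shape[1] != eps.shape[1]:
        x_t = x_t.repeat(1, eps.shape[1] // x_t.shape[1], 1, 1)
    assert x_t.shape == eps.shape
    return sqrt_recip_ac * x_t - sqrt_recipm1_ac * eps
```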

gulubao commented 1 year ago

I did not try to add channels to do multi-segmentation but tried to generate multiple pixel values representing different classes in the last channel of the current code.

I conducted experiments on BRATS, however, the effect was very poor. My adjustments were in the comment above.

The figure below shows the 50th slice in folder slice0001. The dark image is the GT segmentation; the gray one is the model output.

[images: slice0001_slice50_GT, slice0001_slice50_output_ens]

theneao commented 1 year ago

The results are indeed very poor. Even binary segmentation is almost as poor on my dataset, with weak fine-grained performance, far below an ordinary U-Net. It is unclear what causes this; I am inclined to give up on this project.

Recently, there have been many segmentation networks based on diffusion models for similar tasks. If you are interested, we can discuss them through private email.


lixiang007666 commented 1 year ago

@gulubao, When I was dealing with the sample code of multi-label classification, I encountered some problems, can you communicate with me?

email: lixiang007@std.uestc.edu.cn

theneao commented 1 year ago

> I conducted experiments on BRATS, however, the effect was very poor.

I think it may be a problem with the loss-function design.

jaceqin commented 1 year ago

> I conducted experiments on BRATS, however, the effect was very poor.

Have you solved multi-label segmentation yet?

saisusmitha commented 1 year ago

@gulubao @theneao Hi guys, I think the output sample also includes the image: it segments the brain border as well. Is this the case for you too? Looking at the results, it seems the same in your outputs. Please let me know, and correct me if I am missing something.


agentdr1 commented 1 year ago

Any updates on this? I am also interested in multi-class segmentation.

smallboy-code commented 1 year ago

Here are some results of multi-class segmentation on the BraTS dataset, but I don't know how to threshold the output mask to look like the ground truth.

[images: 0, 2, 3, 5]
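One possible way to snap a continuous single-channel output back to discrete labels, assuming the mask was normalized as value / n during training as described earlier in the thread. `quantize_to_labels` is a hypothetical helper, and the label values are only illustrative (BraTS ground truth conventionally uses 0, 1, 2, 4):

```python
import torch

def quantize_to_labels(output: torch.Tensor, labels) -> torch.Tensor:
    """Snap each pixel of a continuous output in [0, 1] to the nearest
    normalized class value, then map back to the original integer labels."""
    labels_t = torch.tensor(labels, dtype=torch.float32)
    n = labels_t.max()
    # distance from each pixel to each normalized label value
    dist = (output.unsqueeze(-1) - labels_t / n).abs()
    idx = dist.argmin(dim=-1)
    return labels_t[idx].long()
```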

thd2020 commented 2 months ago

> Here are the adjustments I made for multilabel: […]

@gulubao Bro, how are you dealing with the masks: multi-channel or one-hot? And did you modify args.in_ch?
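In case it helps, a sketch of the one-hot option. In v1 the noisy mask is concatenated with the image as conditioning, so a one-hot mask with n classes would presumably require raising args.in_ch by n - 1 relative to the single-channel setup; that is an assumption about the repo's wiring, not something the authors have confirmed:

```python
import torch
import torch.nn.functional as F

def mask_to_onehot(mask: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Turn an integer mask (B, H, W) into one-hot channels (B, n_classes, H, W)."""
    return F.one_hot(mask.long(), n_classes).permute(0, 3, 1, 2).float()

# e.g. for a 5-class task the noisy mask then contributes 5 channels, so
# (assumption) args.in_ch becomes image_channels + 5 instead of + 1
```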

Destinycjk commented 1 week ago

> There are some results of multi-class segmentation for the BraTS dataset. […]

Your results look great! Could you please explain how you adjusted the source code to achieve multi-class segmentation?