KidsWithTokens / MedSegDiff

Medical Image Segmentation with Diffusion Model
MIT License
979 stars 147 forks

Some questions about DPM-solver #131

Open LiuTingWed opened 10 months ago

LiuTingWed commented 10 months ago

Hi, great job! I have some questions about using "DPM-solver=True".

1: Must the batch size be 1 when using DPM-solver? If not, then in dpm_solver.py:

    def dynamic_thresholding_fn(self, x0, t):
        """
        The dynamic thresholding method.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        s = s.item()
        x0 = torch.clamp(x0, -s, s) / s
        return x0

the line s = s.item() raises an error when the batch size is greater than 1.

2: With DPM-solver=True, the inferred masks still contain a lot of noise (only the rough outline is visible). When I compared against the original DPM-solver code, I had a question: why is there an extra model forward pass in dpm_solver.py that the original DPM-solver does not have?

    cal = None
    out = self.model(torch.cat((self.img,x), dim=1).to(dtype = torch.float), t)
    if isinstance(out, tuple):
        x, cal = out
    if return_intermediate:
        return x, intermediates
    else:
        return x, cal

I guess this extra pass is why there is so much noise in the prediction map.
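On question 1, here is a minimal sketch of a batch-friendly variant. It assumes the threshold can stay a per-sample tensor instead of being turned into a Python scalar with .item(). The standalone function name dynamic_thresholding, its default arguments, and the reshape used in place of the repository's expand_dims helper are illustrative assumptions, not the repository's code.

```python
import torch


def dynamic_thresholding(x0, ratio=0.995, max_val=1.0):
    """Per-sample dynamic thresholding that works for any batch size.

    Hypothetical standalone version: `ratio` and `max_val` stand in for
    self.dynamic_thresholding_ratio and self.thresholding_max_val, and
    the default values here are assumptions.
    """
    batch = x0.shape[0]
    # One quantile per sample over all non-batch dimensions: shape (B,)
    s = torch.quantile(torch.abs(x0).reshape(batch, -1), ratio, dim=1)
    # Never let the threshold drop below max_val
    s = torch.clamp(s, min=max_val)
    # Reshape to (B, 1, ..., 1) so it broadcasts against x0
    s = s.reshape(batch, *([1] * (x0.dim() - 1)))
    # Tensor bounds keep a separate threshold per sample (no .item())
    return torch.clamp(x0, -s, s) / s


x = torch.randn(4, 1, 8, 8) * 3.0
y = dynamic_thresholding(x)
print(y.shape, float(y.abs().max()))
```

With a batch of size 1 this should give the same result as the scalar version, since a single threshold is then applied to the only sample.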

Devin-Pi commented 10 months ago

Hi, I have run into this problem too. I found that the noise sometimes fills the entire predicted mask and sometimes does not. Have you tried tuning the hyperparameters following the DPM-Solver instructions? Looking forward to your reply!

zxk72 commented 2 months ago

@LiuTingWed How can this be solved? Thanks in advance for your reply!