JuliaWolleb / Diffusion-based-Segmentation

This is the official PyTorch implementation of the paper "Diffusion Models for Implicit Image Segmentation Ensembles".
MIT License

How to fix the GroupNorm problem when applying the model to my custom dataset #4

Closed: Mentholatum closed this issue 1 year ago

Mentholatum commented 1 year ago

I changed the dataset to my own 3D MRIs and initialized the model with dims=3. But when I run segmentation_train.py, I get this error:

Traceback (most recent call last):
  File "/media/data1/jiachuang/projects/courses/test/scripts/segmentation_train.py", line 88, in <module>
    main()
  File "/media/data1/jiachuang/projects/courses/test/scripts/segmentation_train.py", line 44, in main
    TrainLoop(
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/train_util.py", line 187, in run_loop
    self.run_step(batch, cond)
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/train_util.py", line 208, in run_step
    sample = self.forward_backward(batch, cond)
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/train_util.py", line 239, in forward_backward
    losses1 = compute_losses()
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/gaussian_diffusion.py", line 919, in training_losses_segmentation
    model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/unet.py", line 657, in forward
    h = module(h, emb)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/unet.py", line 72, in forward
    x = layer(x, emb)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/unet.py", line 229, in forward
    return checkpoint(
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/nn.py", line 139, in checkpoint
    return func(*inputs)
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/unet.py", line 241, in _forward
    h = self.in_layers(x)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/data1/jiachuang/projects/courses/test/guided_diffusion/nn.py", line 19, in forward
    return super().forward(x.float()).type(x.dtype)
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 268, in forward
    return F.group_norm(
  File "/home/jiachuang/anaconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/functional.py", line 2499, in group_norm
    return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [128] and input of shape [128, 64, 256, 256]

Is it a problem with GroupNorm or the size of my input data? Looking forward to your answer.
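One reading of the shapes in the error, offered as an inference rather than a confirmed diagnosis: recent PyTorch versions let nn.Conv3d accept a 4-D tensor as a single unbatched volume (C, D, H, W). If a 4-D batch [B, 64, 256, 256] reaches the first 3-D convolution and B happens to equal its in_channels, the convolution returns [128, 64, 256, 256], and the GroupNorm at the start of the next ResBlock then reads dim 1 (64) as the channel count while its weight has 128 entries. The 64 would be consistent with 32 image slices plus 32 label slices concatenated along dim 1, but the issue does not show that step. The reproduction below uses shrunken, hypothetical sizes (in_channels=2, model_channels=8, 16x16 slices) and plain nn.Conv3d/nn.GroupNorm layers as stand-ins for the UNet's own.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the first Conv3d of a dims=3 UNet
# (in_channels=2, model_channels=8) and the GroupNorm that follows it.
conv = nn.Conv3d(in_channels=2, out_channels=8, kernel_size=3, padding=1)
norm = nn.GroupNorm(num_groups=4, num_channels=8)

# 4-D batch [B=2, 64, 16, 16]. Because B equals in_channels, Conv3d accepts it
# as ONE unbatched volume (C=2, D=64, H=16, W=16) instead of raising.
x4 = torch.randn(2, 64, 16, 16)
h4 = conv(x4)
print(tuple(h4.shape))  # (8, 64, 16, 16): dim 1 is 64, not the 8 conv channels

try:
    norm(h4)  # GroupNorm treats dim 1 (64) as channels, but its weight has 8 entries
    error = None
except RuntimeError as exc:
    error = str(exc)  # should be the same kind of RuntimeError as in the issue
print(error)

# 5-D batch [B=2, C=2, D=32, H=16, W=16]: channels sit at dim 1, so shapes line up.
x5 = torch.randn(2, 2, 32, 16, 16)
h5 = norm(conv(x5))
print(tuple(h5.shape))  # (2, 8, 32, 16, 16)
```

If this reading is right, the fix belongs in the data pipeline (an explicit channel axis) rather than in GroupNorm itself.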

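import numpy as np
import torch
from torch.utils.data import Dataset

class MRIDataset(Dataset):  # hypothetical class name; the issue only shows __getitem__
    def __init__(self, train_list, label_list):
        self.train_list, self.label_list = train_list, label_list  # .npy file paths

    def __len__(self):
        return len(self.train_list)

    def __getitem__(self, item):
        # Each file holds one [32, 256, 256] volume; no channel axis is added here
        img = np.load(self.train_list[item]).astype(np.float32)
        lab = np.load(self.label_list[item]).astype(np.float32)
        return torch.from_numpy(img), torch.from_numpy(lab)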

Data size: [32, 256, 256]
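A minimal sketch of one possible fix, assuming each .npy file holds a single [D, H, W] volume such as [32, 256, 256] and that the model built with dims=3 expects (N, C, D, H, W) batches: give every sample an explicit channel axis so the DataLoader produces 5-D batches. Both `load_pair` and `check_batch` are hypothetical helpers written for illustration, not functions from the repository.

```python
import numpy as np
import torch


def load_pair(img_path, lab_path):
    """Hypothetical helper: load one image/label pair, each stored as a [D, H, W] .npy volume."""
    img = torch.from_numpy(np.load(img_path).astype(np.float32))
    lab = torch.from_numpy(np.load(lab_path).astype(np.float32))
    # unsqueeze(0) turns [D, H, W] into [1, D, H, W]; the DataLoader's default collate
    # then stacks samples into [B, 1, D, H, W], the (N, C, D, H, W) layout nn.Conv3d expects.
    return img.unsqueeze(0), lab.unsqueeze(0)


def check_batch(x, dims=3):
    """Hypothetical guard: a dims-D model needs dims + 2 axes, i.e. (N, C, *spatial)."""
    if x.dim() != dims + 2:
        raise ValueError(f"expected a {dims + 2}-D batch (N, C, ...), got {tuple(x.shape)}")
    return x
```

Inside the existing __getitem__, the equivalent change would be returning `torch.FloatTensor(img).unsqueeze(0), torch.FloatTensor(lab).unsqueeze(0)`. Calling `check_batch` on the first batch of the training loop turns a silent 4-D/5-D mix-up into an explicit error before it reaches GroupNorm. The model's in_channels would also have to agree with the number of image and label channels actually fed in; how that is set depends on the training script, which the issue does not show.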