Beckschen / TransUNet

This repository contains the official implementation of TransUNet, presented in our paper: TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation.

TransUnet for RGB-images? #31

Open andife opened 3 years ago

andife commented 3 years ago

Hello, it seems that the code currently only works on grayscale images. I am interested in processing images with 3 channels (RGB). Has anyone already modified the code accordingly? What do I have to pay attention to?

SM-93 commented 3 years ago

I have a question about running the code. Have you actually run the code, i.e., trained the model and tested it on their images? And did you get the same or similar results as in their paper?

Thanks

Beckschen commented 3 years ago

> Hello, it seems that the code currently only works on grayscale images. I am interested in processing images with 3 channels (RGB). Has anyone already modified the code accordingly? What do I have to pay attention to?

@andife Hello, this repo also supports RGB images with 3 channels.

The network originally supports 3-channel input (see lines 386-387 in vit_seg_modeling.py): `if x.size()[1] == 1: x = x.repeat(1, 3, 1, 1)`
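
For reference, a minimal sketch of what that snippet does, using dummy tensors rather than the repo's actual forward pass: a `(B, 1, H, W)` grayscale batch is tiled to three channels, while a `(B, 3, H, W)` RGB batch passes through unchanged.

```python
import torch

x_gray = torch.randn(12, 1, 224, 224)  # grayscale batch
x_rgb = torch.randn(12, 3, 224, 224)   # RGB batch

for x in (x_gray, x_rgb):
    if x.size()[1] == 1:          # single-channel input only
        x = x.repeat(1, 3, 1, 1)  # tile the channel dim to 3
    print(x.shape)                # both print torch.Size([12, 3, 224, 224])
```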

andife commented 3 years ago

Hello, it seems that I still have problems preparing the dataset.

So at least the class RandomGenerator cannot be used directly, because `x, y = image.shape` raises `ValueError: too many values to unpack (expected 2)`.

I tried the pipeline with RGB data and got the following:

```
Traceback (most recent call last):
  File "train.py", line 103, in <module>
    trainer[dataset_name](args, net, snapshot_path)
  File "project_TransUNet/TransUNet/trainer.py", line 133, in trainer_owndataset
    for i_batch, sampled_batch in enumerate(trainloader):
  File "anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "anaconda3/lib/python3.8/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "project_TransUNet/TransUNet/datasets/dataset_owndataset.py", line 74, in __getitem__
    sample = self.transform(sample)
  File "anaconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 70, in __call__
    img = t(img)
  File "project_TransUNet/TransUNet/datasets/dataset_owndataset.py", line 39, in __call__
    x, y = image.shape
ValueError: too many values to unpack (expected 2)
```

aneeshgupta42 commented 3 years ago

> > Hello, it seems that the code currently only works on grayscale images. I am interested in processing images with 3 channels (RGB). Has anyone already modified the code accordingly? What do I have to pay attention to?
>
> @andife Hello, this repo also supports RGB images with 3 channels.
>
> The network originally supports 3-channel input (see lines 386-387 in vit_seg_modeling.py): `if x.size()[1] == 1: x = x.repeat(1, 3, 1, 1)`

@Beckschen I'm trying to use this model for RGB images. I removed the random rotations (they seemed buggy for RGB images), and instead now get an error on the lines you mentioned (386-387 in vit_seg_modeling.py). The error is as follows: `RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor`
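
A plausible cause, judging from the 5-D shape andife reports below: `x.repeat(1, 3, 1, 1)` assumes a 4-D `(B, C, H, W)` tensor and raises exactly this error when the tensor has more dimensions than the four repeat factors, e.g. a `(B, 1, 3, H, W)` batch with a spurious singleton dimension. A minimal sketch of squeezing it out before the forward pass (shapes and variable names are illustrative):

```python
import torch

# Hypothetical batch with a spurious singleton dim: (B, 1, 3, H, W)
image_batch = torch.randn(12, 1, 3, 224, 224)

# repeat(1, 3, 1, 1) needs a 4-D tensor; a 5-D input triggers
# "Number of dimensions of repeat dims can not be smaller than
# number of dimensions of tensor". Drop the extra dim first.
if image_batch.dim() == 5 and image_batch.size(1) == 1:
    image_batch = image_batch.squeeze(1)  # -> (12, 3, 224, 224)

print(image_batch.shape)
```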

andife commented 3 years ago

Hello,

Can someone point me to the solution? Does one have working code?

Currently TransUNet expects/uses shape `x: torch.Size([12, 1, 224, 224])` for the Synapse dataset.

When I tried to use my files with RGB channels, I got shape `x: torch.Size([12, 1, 3, 736, 736])`.

Obviously, the dimensions did not fit; I think I have to get rid of the '1'. I squeezed the dataset, but then I got the following error:

```
Traceback (most recent call last):
  File "train.py", line 114, in <module>
    trainer[dataset_name](args, net, snapshot_path)
  File "/home/andife/project_TransUNet/TransUNet/trainer.py", line 223, in trainer_ulm3D
    outputs = model(image_batch)
  File "/home/andife/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andife/project_TransUNet/TransUNet/networks/vit_seg_modeling.py", line 393, in forward
    x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
  File "/home/andife/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andife/project_TransUNet/TransUNet/networks/vit_seg_modeling.py", line 254, in forward
    embedding_output, features = self.embeddings(input_ids)
  File "/home/andife/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/andife/project_TransUNet/TransUNet/networks/vit_seg_modeling.py", line 163, in forward
    embeddings = x + self.position_embeddings
RuntimeError: The size of tensor a (2116) must match the size of tensor b (196) at non-singleton dimension 1
```

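An editorial note on this mismatch, not from the thread: the position embeddings were built for 224x224 inputs, giving (224/16)^2 = 196 patches, while a 736x736 input yields (736/16)^2 = 2116 patches. In the repo's train.py the patch grid is derived from args.img_size, so constructing the network with the actual input size should make the shapes agree. A minimal sketch, assuming the repo's usual entry points (treat the exact names and values as illustrative):

```python
# Build the model for 736x736 inputs so the number of position
# embeddings matches the number of patches (46 * 46 = 2116).
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

img_size = 736          # actual input resolution
vit_patches_size = 16   # ViT patch size

config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 2                    # set to your number of classes
config_vit.patches.grid = (img_size // vit_patches_size,
                           img_size // vit_patches_size)

net = ViT_seg(config_vit, img_size=img_size, num_classes=config_vit.n_classes)
```
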
Some1OutThere commented 3 years ago

> > > Hello, it seems that the code currently only works on grayscale images. I am interested in processing images with 3 channels (RGB). Has anyone already modified the code accordingly? What do I have to pay attention to?
> >
> > @andife Hello, this repo also supports RGB images with 3 channels. The network originally supports 3-channel input (see lines 386-387 in vit_seg_modeling.py): `if x.size()[1] == 1: x = x.repeat(1, 3, 1, 1)`
>
> @Beckschen I'm trying to use this model for RGB images. I removed the random rotations (they seemed buggy for RGB images), and instead now get an error on the lines you mentioned (386-387 in vit_seg_modeling.py). The error is as follows: `RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor`

Did you fix this? I am also trying to repeat this for RGB images.

PatrickWilliams44 commented 1 year ago

Hello, I have the same problem. Have you solved it? @Some1OutThere

1183498834 commented 1 year ago

> Hello, it seems that I still have problems preparing the dataset.
>
> So at least the class RandomGenerator cannot be used directly, because `x, y = image.shape` raises `ValueError: too many values to unpack (expected 2)`.
>
> I tried the pipeline with RGB data and got the following:
>
> ```
> Traceback (most recent call last):
>   File "train.py", line 103, in <module>
>     trainer[dataset_name](args, net, snapshot_path)
>   File "project_TransUNet/TransUNet/trainer.py", line 133, in trainer_owndataset
>     for i_batch, sampled_batch in enumerate(trainloader):
>   File "anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
>     data = self._next_data()
>   File "anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
>     return self._process_data(data)
>   File "anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
>     data.reraise()
>   File "anaconda3/lib/python3.8/site-packages/torch/_utils.py", line 394, in reraise
>     raise self.exc_type(msg)
> ValueError: Caught ValueError in DataLoader worker process 0.
> Original Traceback (most recent call last):
>   File "anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
>     data = fetcher.fetch(index)
>   File "anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
>     data = [self.dataset[idx] for idx in possibly_batched_index]
>   File "anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
>     data = [self.dataset[idx] for idx in possibly_batched_index]
>   File "project_TransUNet/TransUNet/datasets/dataset_owndataset.py", line 74, in __getitem__
>     sample = self.transform(sample)
>   File "anaconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 70, in __call__
>     img = t(img)
>   File "project_TransUNet/TransUNet/datasets/dataset_owndataset.py", line 39, in __call__
>     x, y = image.shape
> ValueError: too many values to unpack (expected 2)
> ```

I have solved this issue. If the image is an RGB image, `image.shape` is a tuple like `(h, w, 3)`; the original code `x, y = image.shape` unpacks two elements, but `image.shape` now has three. So you can fix it by changing the code to `x, y, z = image.shape`.
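
Building on that fix, a minimal sketch of a `RandomGenerator.__call__` that handles both grayscale `(h, w)` and RGB `(h, w, 3)` arrays; this is an illustrative rewrite with the augmentations omitted, not the repo's exact transform:

```python
import numpy as np
import torch
from scipy.ndimage import zoom

class RandomGeneratorRGB(object):
    """RGB-aware variant of RandomGenerator (rotations/flips omitted).
    Assumes sample['image'] is (h, w) grayscale or (h, w, 3) RGB."""

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        x, y = image.shape[:2]  # safe for both 2-D and 3-D arrays
        if (x, y) != tuple(self.output_size):
            zx, zy = self.output_size[0] / x, self.output_size[1] / y
            if image.ndim == 3:
                image = zoom(image, (zx, zy, 1), order=3)  # keep channel axis
            else:
                image = zoom(image, (zx, zy), order=3)
            label = zoom(label, (zx, zy), order=0)
        image = torch.from_numpy(image.astype(np.float32))
        # (h, w) -> (1, h, w); (h, w, 3) -> (3, h, w)
        image = image.unsqueeze(0) if image.ndim == 2 else image.permute(2, 0, 1)
        label = torch.from_numpy(label.astype(np.float32))
        return {'image': image, 'label': label}
```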

heweapon commented 4 months ago

> > > > Hello, it seems that the code currently only works on grayscale images. I am interested in processing images with 3 channels (RGB). Has anyone already modified the code accordingly? What do I have to pay attention to?
> > >
> > > @andife Hello, this repo also supports RGB images with 3 channels. The network originally supports 3-channel input (see lines 386-387 in vit_seg_modeling.py): `if x.size()[1] == 1: x = x.repeat(1, 3, 1, 1)`
> >
> > @Beckschen I'm trying to use this model for RGB images. I removed the random rotations (they seemed buggy for RGB images), and instead now get an error on the lines you mentioned (386-387 in vit_seg_modeling.py). The error is as follows: `RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor`
>
> Did you fix this? I am also trying to repeat this for RGB images.

Hello, I had the same problem when running test.py. Did you solve it?