HUANGLIZI / LViT

[IEEE Transactions on Medical Imaging/TMI] This repo is the official implementation of "LViT: Language meets Vision Transformer in Medical Image Segmentation"
MIT License

Question about how to test the UNet model? #18

Closed: haideralimughal closed this issue 1 year ago

haideralimughal commented 1 year ago

Can this configuration be used for the UNet-series models? If not, how should a new UNet-series configuration be created? Or does it only apply to the transformer models?

```python
import ml_collections

##########################################################################
# CTrans configs
##########################################################################
def get_CTranS_config():
    config = ml_collections.ConfigDict()
    config.transformer = ml_collections.ConfigDict()
    config.KV_size = 960  # KV_size = Q1 + Q2 + Q3 + Q4
    config.transformer.num_heads = 4
    config.transformer.num_layers = 4
    config.expand_ratio = 4  # MLP channel dimension expand ratio
    config.transformer.embeddings_dropout_rate = 0.1
    config.transformer.attention_dropout_rate = 0.1
    config.transformer.dropout_rate = 0
    config.patch_sizes = [16, 8, 4, 2]
    config.base_channel = 24  # base channel of U-Net
    config.n_classes = 1
    return config
```

haideralimughal commented 1 year ago

Hi, I am trying to get results for the UNet model with this configuration in the training script (train_model.py). Could you give me some suggestions? This is the UNet branch I am using:

```python
elif model_type == 'UNet':
    config_vit = config.get_CTranS_config()
    model = UNet(config_vit, n_classes=config.n_labels)
```

but I get this error:

```
UNet
Traceback (most recent call last):
  File "C:\Users\lenovo\PycharmProjects\LViT-main\train_model.py", line 217, in <module>
    model = main_loop(model_type=config.model_name, tensorboard=True)
  File "C:\Users\lenovo\PycharmProjects\LViT-main\train_model.py", line 128, in main_loop
    model = UNet(config_vit, n_classes=config.n_labels)
  File "C:\Users\lenovo\PycharmProjects\LViT-main\nets\UNet.py", line 73, in __init__
    self.inc = ConvBatchNorm(n_channels, in_channels)
  File "C:\Users\lenovo\PycharmProjects\LViT-main\nets\UNet.py", line 25, in __init__
    kernel_size=3, padding=1)
  File "C:\Users\lenovo\anaconda3\envs\trans\lib\site-packages\torch\nn\modules\conv.py", line 412, in __init__
    False, _pair(0), groups, bias, padding_mode)
  File "C:\Users\lenovo\anaconda3\envs\trans\lib\site-packages\torch\nn\modules\conv.py", line 50, in __init__
    if in_channels % groups != 0:
TypeError: unsupported operand type(s) for %: 'ConfigDict' and 'int'
```
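The frames in this traceback show the chain of calls: config_vit is passed as the first positional argument of UNet, which nets/UNet.py appears to bind to n_channels and forward to ConvBatchNorm, and from there it reaches the convolution's divisibility check `in_channels % groups`. A ConfigDict does not support `%` with an int, hence the TypeError. The pure-Python sketch below reproduces that chain without PyTorch; StandInConfigDict, conv_channel_check() and make_unet() are hypothetical stand-ins, not LViT or PyTorch code.

```python
class StandInConfigDict:
    # Hypothetical stand-in for ml_collections.ConfigDict; like the object
    # in the traceback, it does not support the % operator with an int.
    pass


def conv_channel_check(in_channels, groups=1):
    # Simplified version of the check at the bottom of the traceback
    # (torch/nn/modules/conv.py): channels must be divisible by groups.
    if in_channels % groups != 0:
        raise ValueError("in_channels must be divisible by groups")


def make_unet(n_channels, n_classes):
    # Hypothetical stand-in for UNet.__init__: the first positional
    # parameter is the number of input channels, which flows straight
    # into the first convolution's channel check.
    conv_channel_check(n_channels)
    return n_channels, n_classes


def error_for(n_channels):
    # Return the TypeError message raised for a given first argument, or None.
    try:
        make_unet(n_channels, n_classes=1)
    except TypeError as err:
        return str(err)
    return None


print(error_for(StandInConfigDict()))
# unsupported operand type(s) for %: 'StandInConfigDict' and 'int'
print(error_for(3))
# None
```

Passing an integer channel count in place of the config object is enough to get past this check.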

HUANGLIZI commented 1 year ago

Unfortunately, you should use the parameter settings for UNet rather than simply carrying the ViT hyperparameter settings over to UNet. If you have further questions, please email me.
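Following this advice, a UNet branch would build the model from plain integer settings and not call get_CTranS_config() at all. The sketch below assumes, based on the traceback, that nets/UNet.py accepts UNet(n_channels, n_classes). UNetConfig, get_UNet_config(), select_model() and the stand-in UNet class are hypothetical, and the default values (3 input channels, 1 label) are illustrative rather than LViT's own settings.

```python
from dataclasses import dataclass


@dataclass
class UNetConfig:
    # Hypothetical UNet-specific settings; values are illustrative only.
    n_channels: int = 3  # input image channels (assumed RGB)
    n_labels: int = 1    # output classes, matching n_classes = 1 above


def get_UNet_config() -> UNetConfig:
    # Hypothetical counterpart to get_CTranS_config() that holds only what a
    # plain UNet needs: no transformer, patch-size or KV_size fields.
    return UNetConfig()


class UNet:
    # Stand-in for nets/UNet.py so the sketch runs without PyTorch.
    # Assumed signature: UNet(n_channels, n_classes).
    def __init__(self, n_channels: int, n_classes: int):
        if not isinstance(n_channels, int) or not isinstance(n_classes, int):
            raise TypeError("n_channels and n_classes must be integers")
        self.n_channels = n_channels
        self.n_classes = n_classes


def select_model(model_type: str, config: UNetConfig):
    # Hypothetical helper mirroring the elif chain in train_model.py;
    # only the UNet branch is shown.
    if model_type == 'UNet':
        # Pass integers by keyword instead of a ViT config object.
        return UNet(n_channels=config.n_channels, n_classes=config.n_labels)
    raise ValueError(f"unsupported model_type: {model_type}")


model = select_model('UNet', get_UNet_config())
print(model.n_channels, model.n_classes)  # 3 1
```

Passing the arguments by keyword makes it harder to bind a config object to n_channels by accident, which is exactly what happened in the traceback above.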