Closed: haideralimughal closed this issue 1 year ago.
Hi, I am trying to get UNet model results with the following configuration in the train.py file. Could you kindly give me some suggestions?
elif model_type == 'UNet':
    config_vit = config.get_CTranS_config()
    model = UNet(config_vit, n_classes=config.n_labels)
But I am getting this error (UNet
Traceback (most recent call last):
  File "C:\Users\lenovo\PycharmProjects\LViT-main\train_model.py", line 217, in
Unfortunately, you should use the parameter settings that belong to UNet rather than simply carrying the ViT hyperparameter settings over to UNet. If you have further questions, please email me.
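A minimal sketch of the mismatch the reply points at, assuming (this thread does not confirm it) that the repository's UNet constructor follows the common `UNet(n_channels, n_classes)` pattern instead of taking a transformer config as its first argument. The `UNet` class below is a hypothetical pure-Python stand-in with no PyTorch layers, and `build_model` is a hypothetical helper; check the actual UNet definition in your copy of the code before relying on these parameter names.

```python
class UNet:
    """Hypothetical stand-in for a plain UNet whose constructor takes
    integer channel/class counts, not a ViT/CTrans config object."""

    def __init__(self, n_channels=3, n_classes=9):
        # Assumed signature. The real model's error message will differ;
        # this check only makes the argument mismatch explicit.
        if not isinstance(n_channels, int):
            raise TypeError(
                f"n_channels must be an int, got {type(n_channels).__name__}"
            )
        self.n_channels = n_channels
        self.n_classes = n_classes


def build_model(model_type, n_channels, n_labels):
    """Hypothetical model factory: the UNet branch gets plain integers only."""
    if model_type == 'UNet':
        # Do not pass config_vit here; a plain UNet has no transformer settings.
        return UNet(n_channels=n_channels, n_classes=n_labels)
    raise ValueError(f"unsupported model_type: {model_type}")


if __name__ == "__main__":
    model = build_model('UNet', n_channels=3, n_labels=1)
    print(model.n_channels, model.n_classes)

    # Reproduces the pattern from the question: a config object in the
    # n_channels position is rejected.
    try:
        UNet(object(), n_classes=1)
    except TypeError as err:
        print("TypeError:", err)
```

In train.py this would correspond to something like `UNet(n_channels=config.n_channels, n_classes=config.n_labels)`, where `config.n_channels` is an assumed setting; `config_vit` would then be passed only to the transformer-based models.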
Can this setup be used for UNet-series models? If not, how can a new configuration for UNet-series models be created? Or is it only applicable to transformers?
import ml_collections

##########################################################################
# CTrans configs
##########################################################################
def get_CTranS_config():
    config = ml_collections.ConfigDict()
    config.transformer = ml_collections.ConfigDict()
    config.KV_size = 960  # KV_size = Q1 + Q2 + Q3 + Q4
    config.transformer.num_heads = 4
    config.transformer.num_layers = 4
    config.expand_ratio = 4  # MLP channel dimension expand ratio
    config.transformer.embeddings_dropout_rate = 0.1
    config.transformer.attention_dropout_rate = 0.1
    config.transformer.dropout_rate = 0
    config.patch_sizes = [16, 8, 4, 2]
    config.base_channel = 24  # base channel of U-Net
    config.n_classes = 1
    return config
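One possible approach to the follow-up question, sketched under assumptions: keep the CTrans config for the transformer models and give UNet-series models a separate, smaller config that holds only the fields a convolutional UNet uses. The sketch swaps `ml_collections.ConfigDict` for a standard-library `dataclass` so it runs without extra packages. `UNetConfig`, `get_UNet_config`, and `kv_size_for` are hypothetical names, and the defaults are placeholders, not tuned settings. `kv_size_for` reads the `KV_size = Q1 + Q2 + Q3 + Q4` comment as the sum of four encoder widths that double from `base_channel`; that reading is an assumption about the model, not something stated in this thread.

```python
from dataclasses import dataclass


@dataclass
class UNetConfig:
    # Hypothetical config for a plain convolutional UNet: no transformer fields
    # (num_heads, patch_sizes, KV_size, ...) because it has no attention blocks.
    n_channels: int = 3     # input image channels (placeholder: RGB)
    n_classes: int = 1      # output segmentation channels
    base_channel: int = 64  # width of the first encoder stage (placeholder)


def get_UNet_config() -> UNetConfig:
    """Hypothetical counterpart to get_CTranS_config() for UNet-series models."""
    return UNetConfig()


def kv_size_for(base_channel: int, num_levels: int = 4) -> int:
    """Sum of encoder widths Q1..Q4, assuming each stage doubles the previous one."""
    return sum(base_channel * 2 ** i for i in range(num_levels))


if __name__ == "__main__":
    print(get_UNet_config())  # UNetConfig(n_channels=3, n_classes=1, base_channel=64)
    print(kv_size_for(64))    # 64 + 128 + 256 + 512 = 960
    print(kv_size_for(24))    # 24 + 48 + 96 + 192 = 360
```

Under that assumption, `KV_size = 960` corresponds to `base_channel = 64`, while `base_channel = 24` would give 360, so these two values in the pasted config may be worth checking together.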