lorenmt / mtan

The implementation of "End-to-End Multi-Task Learning with Attention" [CVPR 2019].
https://shikun.io/projects/multi-task-attention-network
MIT License

Training with my own dataset #44

Closed: Njuod closed this 3 years ago

Njuod commented 3 years ago

Hi lorenmt!

Thank you for sharing your code. I would like to train model_segnet_mtan.py with my own dataset. I have prepared the data so that every sample has the same size (500x500) and the same format as in create_dataset.py. However, when I started training I got the following error:

Parameter Space: ABS: 44229076.0, REL: 1.7705
LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR | NORMAL_LOSS MEAN MED <11.25 <22.5 <30
Standard training strategy without data augmentation.
image: torch.Size([3, 500, 500])
semantic: torch.Size([500, 500])
depth: torch.Size([500, 500])
normal: torch.Size([500, 500])
Traceback (most recent call last):
  File "model_segnet_mtan.py", line 222, in <module>
    200)
  File "/home/models/mtan/im2im_pred/utils.py", line 158, in multi_task_trainer
    train_pred, logsigma = multi_task_model(train_data)
  File "/home/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "model_segnet_mtan.py", line 143, in forward
    g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
  File "/home/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/miniconda3/lib/python3.7/site-packages/torch/nn/modules/pooling.py", line 356, in forward
    self.padding, output_size)
  File "/home/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 598, in max_unpool2d
    return torch._C._nn.max_unpool2d(input, indices, output_size)
RuntimeError: Shape of input must match shape of indices

It seems the error is caused by the input shape. Do I need to resize the images, or can the model accept images of size 500x500? How can I modify the indices shape? Could you help me solve this error? I'm new to PyTorch.

lorenmt commented 3 years ago

Hello. I believe that is because your image size of 500 is not divisible by 16: when the images are downsampled, the spatial dimensions get rounded, so the saved pooling indices no longer match the size of the upsampled features.

An easy fix is to resize the input images to a size that is divisible by 16, such as 256.
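Something along these lines should work (a minimal sketch; the sizes, variable names, and class count are just an example, and the label map is resized with nearest-neighbour so the class indices stay intact):

import torch
import torch.nn.functional as F

# toy sample: 500 is not divisible by 16, so resize to 512
image = torch.rand(3, 500, 500)              # RGB image
label = torch.randint(0, 13, (500, 500))     # semantic map (class indices)

image_r = F.interpolate(image.unsqueeze(0), size=(512, 512),
                        mode='bilinear', align_corners=True).squeeze(0)
label_r = F.interpolate(label[None, None].float(), size=(512, 512),
                        mode='nearest').squeeze().long()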

Let me know whether that solves your issue.

Njuod commented 3 years ago

Hi! Thank you for your quick answer. I've padded the images to 512 so they are divisible by 16, and that seems to solve the issue. Thanks!

Njuod commented 3 years ago

Hi again :),

My dataset is very large, which caused the following error:

RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 5.79 GiB total capacity; 4.48 GiB already allocated; 10.31 MiB free; 25.81 MiB cached)

I see the model uses only 1 GPU, so I'm trying to train it on 2 GPUs by wrapping the model in nn.DataParallel as follows:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.DataParallel(SegNet())
SegNet_MTAN = model.to(device)

But I got this error:

Standard training strategy without data augmentation.
Traceback (most recent call last):
  File "model_segnet_mtan.py", line 230, in <module>
    5)
  File "/home/models/mtan/im2im_pred/utils.py", line 152, in multi_task_trainer
    conf_mat = ConfMatrix(multi_task_model.class_nb)
  File "/home/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
    type(self).__name__, name))
AttributeError: 'DataParallel' object has no attribute 'class_nb'

Could you help me solve it?

lorenmt commented 3 years ago

I think after you wrap the model for multi-GPU training, the original model is stored under model.module, so you need to access class_nb as multi_task_model.module.class_nb.
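Concretely, something like this (SegNet and ConfMatrix are the classes already defined in model_segnet_mtan.py and utils.py):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SegNet_MTAN = torch.nn.DataParallel(SegNet()).to(device)

# the wrapped model lives on .module, so e.g. in utils.py:
conf_mat = ConfMatrix(SegNet_MTAN.module.class_nb)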

I would suggest checking out the PyTorch documentation on DataParallel for these questions.

Best,

Njuod commented 3 years ago

Thanks, that solves the error.

I'm using the model for three semantic segmentation tasks, each with a different number of classes. So, instead of semantic, depth, and normal, I have semantic label, semantic label_2, and semantic label_3. I've modified the model_segnet_mtan.py and utils.py files, but when I use semantic label_2 I get nan values:

Parameter Space: ABS: 44234471.0, REL: 1.7707
LOSS FORMAT: SEMANTIC-1_LOSS MEAN_IOU PIX_ACC | SEMANTIC-1_LOSS MEAN_IOU PIX_ACC | SEMANTIC-2_LOSS MEAN_IOU PIX_ACC
Standard training strategy without data augmentation.
Epoch: 0000 | TRAIN: 1.8811 0.0339 0.5995 | 1.5766 0.0374 0.7284 | 2.6665 0.0122 0.6375 ||TEST: 1.0568 0.0376 0.7893 | 1.0827 0.0376 0.7893 | 1.5529 nan 0.7953  
Epoch: 0001 | TRAIN: 0.9069 0.0390 0.8198 | 0.9206 0.0390 0.8198 | 1.2247 0.0142 0.8263 ||TEST: 1.0012 0.0376 0.7893 | 1.0153 0.0376 0.7893 | 1.3278 nan 0.7953  
Epoch: 0002 | TRAIN: 0.8436 0.0390 0.8198 | 0.8585 0.0390 0.8198 | 1.0681 0.0142 0.8263 ||TEST: 0.9553 0.0376 0.7893 | 0.9683 0.0376 0.7893 | 1.2469 nan 0.7953  
Epoch: 0003 | TRAIN: 0.8302 0.0390 0.8198 | 0.8416 0.0390 0.8198 | 1.0364 0.0142 0.8263 ||TEST: 0.9744 0.0376 0.7891 | 0.9848 0.0377 0.7884 | 1.2454 nan 0.7953  
Epoch: 0004 | TRAIN: 0.8391 0.0391 0.8198 | 0.8458 0.0392 0.8197 | 1.0312 0.0142 0.8263 ||TEST: 1.9153 0.0377 0.7890 | 1.6111 0.0386 0.7873 | 2.0755 nan 0.7953

Meanwhile, when I use semantic label_3 I get the following error:

Parameter Space: ABS: 44237786.0, REL: 1.7709
LOSS FORMAT: SEMANTIC-1_LOSS MEAN_IOU PIX_ACC | SEMANTIC-1_LOSS MEAN_IOU PIX_ACC | SEMANTIC-3_LOSS MEAN_IOU PIX_ACC
Standard training strategy without data augmentation.
Traceback (most recent call last):
  File "model_segnet_mtan.py", line 230, in <module>
    10)
  File "/mtan/im2im_pred/utils.py", line 165, in multi_task_trainer
    model_fit(train_pred[2], train_label3, 'semantic')]
  File "/mtan/im2im_pred/utils.py", line 22, in model_fit
    loss = F.nll_loss(x_pred, x_output, ignore_index=-1)
  File "/home/.local/lib/python3.7/site-packages/torch/nn/functional.py", line 1873, in nll_loss
    ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /pytorch/aten/src/THNN/generic/SpatialClassNLLCriterion.c:109

Could you give me some advice? Thank you.

lorenmt commented 3 years ago

For 3 segmentation tasks, you need to make sure the number of prediction classes for each classifier head is correct... It looks like there is a mismatch between the classes a head predicts and the class indices in the corresponding ground truth. You also need to define a confusion matrix for each semantic task so that the mIoU scores are computed correctly.
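As a quick sanity check (not something in the repo), you can verify that every ground-truth index lies inside the range the corresponding head predicts, which is exactly what the nll_loss2d assertion is complaining about:

import torch

def check_labels(label_map, n_classes, ignore_index=-1):
    # raise if any label other than ignore_index falls outside [0, n_classes)
    valid = label_map[label_map != ignore_index]
    assert valid.min() >= 0 and valid.max() < n_classes, (
        f"labels span [{valid.min().item()}, {valid.max().item()}] "
        f"but this head only predicts {n_classes} classes")

# e.g. check_labels(train_label3, n_classes) with each task's class count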

Njuod commented 3 years ago

Thanks for your reply!

Can you clarify this part a bit more: "You also need to define a confusion matrix for each semantic task so that the mIoU scores are computed correctly"?

This is how I modified utils.py:

def multi_task_trainer(train_loader, test_loader, multi_task_model, device, optimizer, scheduler, opt, total_epoch=200):
    train_batch = len(train_loader)
    test_batch = len(test_loader)
    T = opt.temp
    avg_cost = np.zeros([total_epoch, 19], dtype=np.float32)
    lambda_weight = np.ones([3, total_epoch])
    for index in range(total_epoch):
        cost = np.zeros(19, dtype=np.float32)

        # apply Dynamic Weight Average
        if opt.weight == 'dwa':
            if index == 0 or index == 1:
                lambda_weight[:, index] = 1.0
            else:
                w_1 = avg_cost[index - 1, 0] / avg_cost[index - 2, 0]
                w_2 = avg_cost[index - 1, 3] / avg_cost[index - 2, 3]
                w_3 = avg_cost[index - 1, 6] / avg_cost[index - 2, 6]
                lambda_weight[0, index] = 3 * np.exp(w_1 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
                lambda_weight[1, index] = 3 * np.exp(w_2 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))
                lambda_weight[2, index] = 3 * np.exp(w_3 / T) / (np.exp(w_1 / T) + np.exp(w_2 / T) + np.exp(w_3 / T))

        # iteration for all batches
        multi_task_model.train()
        train_dataset = iter(train_loader)
        conf_mat = ConfMatrix(multi_task_model.module.class_nb)
        conf_mat2 = ConfMatrix(multi_task_model.module.class_nb2)
        conf_mat3 = ConfMatrix(multi_task_model.module.class_nb3)
        for k in range(train_batch):
            train_data, train_label, train_label2, train_label3 = next(train_dataset)
            train_data, train_label = train_data.to(device), train_label.long().to(device)
            train_label2, train_label3 = train_label2.long().to(device), train_label3.long().to(device)

            train_pred, logsigma = multi_task_model(train_data)

            optimizer.zero_grad()
            train_loss = [model_fit(train_pred[0], train_label, 'semantic'),
                          model_fit(train_pred[1], train_label2, 'semantic'),
                          model_fit(train_pred[2], train_label3, 'semantic')]

            if opt.weight == 'equal' or opt.weight == 'dwa':
                loss = sum([lambda_weight[i, index] * train_loss[i] for i in range(3)])
            else:
                loss = sum(1 / (2 * torch.exp(logsigma[i])) * train_loss[i] + logsigma[i] / 2 for i in range(3))

            loss.backward()
            optimizer.step()

            # accumulate label prediction for every pixel in training images
            conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())
            conf_mat2.update(train_pred[1].argmax(1).flatten(), train_label2.flatten())
            conf_mat3.update(train_pred[2].argmax(1).flatten(), train_label3.flatten())

            cost[0] = train_loss[0].item()
            cost[3] = train_loss[1].item()
            cost[6] = train_loss[2].item()
            avg_cost[index, :9] += cost[:9] / train_batch

        # compute mIoU and acc
        avg_cost[index, 1:3] = conf_mat.get_metrics()
        avg_cost[index, 4:6] = conf_mat2.get_metrics()
        avg_cost[index, 7:9] = conf_mat3.get_metrics()

        # evaluating test data
        multi_task_model.eval()
        conf_mat = ConfMatrix(multi_task_model.module.class_nb)
        conf_mat2 = ConfMatrix(multi_task_model.module.class_nb2)
        conf_mat3 = ConfMatrix(multi_task_model.module.class_nb3)
        with torch.no_grad():  # operations inside don't track history
            test_dataset = iter(test_loader)
            for k in range(test_batch):
                test_data, test_label, test_label2, test_label3 = next(test_dataset)
                test_data, test_label = test_data.to(device), test_label.long().to(device)
                test_label2, test_label3 = test_label2.long().to(device), test_label3.long().to(device)

                test_pred, _ = multi_task_model(test_data)
                test_loss = [model_fit(test_pred[0], test_label, 'semantic'),
                             model_fit(test_pred[1], test_label2, 'semantic'),
                             model_fit(test_pred[2], test_label3, 'semantic')]

                conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())
                conf_mat2.update(test_pred[1].argmax(1).flatten(), test_label2.flatten())
                conf_mat3.update(test_pred[2].argmax(1).flatten(), test_label3.flatten())

                cost[9] = test_loss[0].item()
                cost[12] = test_loss[1].item()
                cost[15] = test_loss[2].item()
                avg_cost[index, 9:] += cost[9:] / test_batch

            # compute mIoU and acc
            avg_cost[index, 10:12] = conf_mat.get_metrics()
            avg_cost[index, 13:15] = conf_mat2.get_metrics()
            avg_cost[index, 16:18] = conf_mat3.get_metrics()

        scheduler.step()
        print('Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} ||'
            'TEST: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f}  '
            .format(index, avg_cost[index, 0], avg_cost[index, 1], avg_cost[index, 2], avg_cost[index, 3],
                    avg_cost[index, 4], avg_cost[index, 5], avg_cost[index, 6], avg_cost[index, 7], avg_cost[index, 8],
                    avg_cost[index, 9], avg_cost[index, 10], avg_cost[index, 11], avg_cost[index, 12], avg_cost[index, 13],
                    avg_cost[index, 14], avg_cost[index, 15], avg_cost[index, 16], avg_cost[index, 17], avg_cost[index, 18]))

Is that what you meant, i.e. defining a separate confusion matrix for each task?

lorenmt commented 3 years ago

Yes. This modification looks correct to me.

Njuod commented 3 years ago

Hello, sorry to bother you. I still get nan values, even when I use model_segnet_single.py. Do you think it is related to the linear layer? The number of classes I have is 58.

lorenmt commented 3 years ago

Maybe try removing F.log_softmax from the prediction and changing the loss function to F.cross_entropy(pred, gt).

If you still have the issue, then it's definitely coming from your dataset or your implementation.

Njuod commented 3 years ago

OK, I modified it like the following:

        t1_pred = F.cross_entropy(self.pred_task1(atten_decoder[0][-1][-1]), GT)
        t2_pred = F.cross_entropy(self.pred_task2(atten_decoder[1][-1][-1]), Gt)
        t3_pred = F.cross_entropy(self.pred_task3(atten_decoder[2][-1][-1]), GT)

        return [t1_pred, t2_pred, t3_pred], self.logsigma

Sorry, could you let me know what I should pass as the GTs?

lorenmt commented 3 years ago

Sorry, what I meant is, do:

t1_pred = self.pred_task1(atten_decoder[0][-1][-1])
t2_pred = self.pred_task2(atten_decoder[1][-1][-1])
...

and modify the loss function in utils.py (https://github.com/lorenmt/mtan/blob/268c5c1796bf972d1e0850dcf1e2f2e6598cc589/im2im_pred/utils.py#L22) to F.cross_entropy(x_pred, x_output, ignore_index=-1).

Here -1 is the label index you wish to ignore. You probably need to adjust that as well, according to how your dataset is configured.
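With that change, the semantic branch of model_fit in utils.py would look roughly like this (a sketch, assuming the model now returns raw logits with no F.log_softmax):

import torch.nn.functional as F

def model_fit(x_pred, x_output, task_type):
    if task_type == 'semantic':
        # cross_entropy applies log_softmax + nll_loss internally;
        # pixels labelled -1 are ignored
        return F.cross_entropy(x_pred, x_output, ignore_index=-1)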

Njuod commented 3 years ago

Oh!!! That was a silly mistake! I thought my dataset was 0-indexed, but I found that it is 1-indexed!! So sorry to take up your time. Thank you very much for the help!
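For reference, the shift is just something like this (assuming labels run from 1 to N, with 0, if present, meaning unlabelled):

# map 1-indexed classes to the 0-indexed range the loss expects;
# any 0 (unlabelled) pixel becomes -1 and is skipped by ignore_index=-1
label = label.long() - 1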

Njuod commented 3 years ago

Hi, sorry to bother you.

I would like to modify the mtan model to do two tasks instead of three. Based on the stan model, I modified the code as follows:

  1. reduce j from 3 to 2:
        for j in range(2):
            if j < 2:
                self.encoder_att.append(nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])]))
                self.decoder_att.append(nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])]))
            for i in range(4):
                self.encoder_att[j].append(self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]]))
                self.decoder_att[j].append(self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]]))
  2. reduce i from 3 to 2:
        # define task dependent attention module
        for i in range(2):
            for j in range(5):
                if j == 0:
  3. reduce logsigma to two entries:
        self.logsigma = nn.Parameter(torch.FloatTensor([-0.5, -0.5]))
  4. make the mtan model return only the two task predictions plus logsigma:
        return [t1_pred, t2_pred], self.logsigma

So, I'm wondering if I did that correctly, especially for steps 1 and 3.

Also, why is the total parameter count divided by 24981069?

Thanks!

lorenmt commented 3 years ago

# Multi-task Attention Network:
class MTANSegNet(nn.Module):
    def __init__(self, tasks=['segmentation', 'depth', 'normal'], 
                 out_channels={'segmentation': 13, 'depth': 1, 'normal': 3}):
        super(MTANSegNet, self).__init__()
        # initialise network parameters
        filter = [64, 128, 256, 512, 512]
        self.tasks = tasks
        self.num_tasks = len(tasks)

        # define encoder decoder layers
        self.encoder_block = nn.ModuleList([self.conv_layer([3, filter[0]])])
        self.decoder_block = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])

        for i in range(4):
            self.encoder_block.append(self.conv_layer([filter[i], filter[i + 1]]))
            self.decoder_block.append(self.conv_layer([filter[i + 1], filter[i]]))

        # define convolution layer
        self.conv_block_enc = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])
        self.conv_block_dec = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])

        for i in range(4):
            if i == 0:
                self.conv_block_enc.append(self.conv_layer([filter[i + 1], filter[i + 1]]))
                self.conv_block_dec.append(self.conv_layer([filter[i], filter[i]]))
            else:
                self.conv_block_enc.append(nn.Sequential(self.conv_layer([filter[i + 1], filter[i + 1]]),
                                                         self.conv_layer([filter[i + 1], filter[i + 1]])))
                self.conv_block_dec.append(nn.Sequential(self.conv_layer([filter[i], filter[i]]),
                                                         self.conv_layer([filter[i], filter[i]])))

        # define task attention layers
        self.encoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])])])
        self.decoder_att = nn.ModuleList([nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])])])
        self.encoder_block_att = nn.ModuleList([self.conv_layer([filter[0], filter[1]])])
        self.decoder_block_att = nn.ModuleList([self.conv_layer([filter[0], filter[0]])])

        for j in range(self.num_tasks):
            if j < (self.num_tasks - 1):
                self.encoder_att.append(nn.ModuleList([self.att_layer([filter[0], filter[0], filter[0]])]))
                self.decoder_att.append(nn.ModuleList([self.att_layer([2 * filter[0], filter[0], filter[0]])]))
            for i in range(4):
                self.encoder_att[j].append(self.att_layer([2 * filter[i + 1], filter[i + 1], filter[i + 1]]))
                self.decoder_att[j].append(self.att_layer([filter[i + 1] + filter[i], filter[i], filter[i]]))

        for i in range(4):
            if i < 3:
                self.encoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 2]]))
                self.decoder_block_att.append(self.conv_layer([filter[i + 1], filter[i]]))
            else:
                self.encoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 1]]))
                self.decoder_block_att.append(self.conv_layer([filter[i + 1], filter[i + 1]]))

        self.pred_task = nn.ModuleList([self.conv_layer([filter[0], out_channels[t]], pred=True) for t in tasks])

        # define pooling and unpooling functions
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

    def conv_layer(self, channel, pred=False):
        if not pred:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=3, padding=1),
                nn.BatchNorm2d(num_features=channel[1]),
                nn.ReLU(inplace=True),
            )
        else:
            conv_block = nn.Sequential(
                nn.Conv2d(in_channels=channel[0], out_channels=channel[0], kernel_size=3, padding=1),
                nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0),
            )
        return conv_block

    def att_layer(self, channel):
        att_block = nn.Sequential(
            nn.Conv2d(in_channels=channel[0], out_channels=channel[1], kernel_size=1, padding=0),
            nn.BatchNorm2d(channel[1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=channel[1], out_channels=channel[2], kernel_size=1, padding=0),
            nn.BatchNorm2d(channel[2]),
            nn.Sigmoid(),
        )
        return att_block

    def forward(self, x):
        g_encoder, g_decoder, g_maxpool, g_upsampl, indices = ([0] * 5 for _ in range(5))
        for i in range(5):
            g_encoder[i], g_decoder[-i - 1] = ([0] * 2 for _ in range(2))

        # define attention list for tasks
        atten_encoder, atten_decoder = ([0] * self.num_tasks for _ in range(2))
        for i in range(self.num_tasks):
            atten_encoder[i], atten_decoder[i] = ([0] * 5 for _ in range(2))

        for i in range(self.num_tasks):
            for j in range(5):
                atten_encoder[i][j], atten_decoder[i][j] = ([0] * 3 for _ in range(2))

        # define global shared network
        for i in range(5):
            if i == 0:
                g_encoder[i][0] = self.encoder_block[i](x)
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])
            else:
                g_encoder[i][0] = self.encoder_block[i](g_maxpool[i - 1])
                g_encoder[i][1] = self.conv_block_enc[i](g_encoder[i][0])
                g_maxpool[i], indices[i] = self.down_sampling(g_encoder[i][1])

        for i in range(5):
            if i == 0:
                g_upsampl[i] = self.up_sampling(g_maxpool[-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])
            else:
                g_upsampl[i] = self.up_sampling(g_decoder[i - 1][-1], indices[-i - 1])
                g_decoder[i][0] = self.decoder_block[-i - 1](g_upsampl[i])
                g_decoder[i][1] = self.conv_block_dec[-i - 1](g_decoder[i][0])

        # define task dependent attention module
        for i in range(self.num_tasks):
            for j in range(5):
                if j == 0:
                    atten_encoder[i][j][0] = self.encoder_att[i][j](g_encoder[j][0])
                    atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
                    atten_encoder[i][j][2] = self.encoder_block_att[j](atten_encoder[i][j][1])
                    atten_encoder[i][j][2] = F.max_pool2d(atten_encoder[i][j][2], kernel_size=2, stride=2)
                else:
                    atten_encoder[i][j][0] = self.encoder_att[i][j](torch.cat((g_encoder[j][0], atten_encoder[i][j - 1][2]), dim=1))
                    atten_encoder[i][j][1] = (atten_encoder[i][j][0]) * g_encoder[j][1]
                    atten_encoder[i][j][2] = self.encoder_block_att[j](atten_encoder[i][j][1])
                    atten_encoder[i][j][2] = F.max_pool2d(atten_encoder[i][j][2], kernel_size=2, stride=2)

            for j in range(5):
                if j == 0:
                    atten_decoder[i][j][0] = F.interpolate(atten_encoder[i][-1][-1], scale_factor=2, mode='bilinear', align_corners=True)
                    atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](atten_decoder[i][j][0])
                    atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1))
                    atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]
                else:
                    atten_decoder[i][j][0] = F.interpolate(atten_decoder[i][j - 1][2], scale_factor=2, mode='bilinear', align_corners=True)
                    atten_decoder[i][j][0] = self.decoder_block_att[-j - 1](atten_decoder[i][j][0])
                    atten_decoder[i][j][1] = self.decoder_att[i][-j - 1](torch.cat((g_upsampl[j], atten_decoder[i][j][0]), dim=1))
                    atten_decoder[i][j][2] = (atten_decoder[i][j][1]) * g_decoder[j][-1]

        # define task prediction layers
        out = [0 for _ in self.tasks]
        for i, t in enumerate(self.tasks):
            out[i] = self.pred_task[i](atten_decoder[i][-1][-1])
            if t == 'normal':
                out[i] = out[i] / torch.norm(out[i], p=2, dim=1, keepdim=True)
        return out

You can follow this implementation, which should be cleaner in terms of scaling the number of tasks.

And 24981069 is the parameter count of the single-task network, so it's easier to compare parameter sizes across different networks relative to that baseline.
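A usage sketch for two segmentation tasks would look something like this (the task names and class counts are placeholders):

import torch

model = MTANSegNet(tasks=['seg_a', 'seg_b'],
                   out_channels={'seg_a': 13, 'seg_b': 58})
model.eval()
with torch.no_grad():
    preds = model(torch.rand(1, 3, 512, 512))   # one [1, C_t, 512, 512] map per task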

Njuod commented 3 years ago

Thank you, your reply is much appreciated.