WindVChen / DRENet

The official implementation of DRENet (Degraded Reconstruction Enhancement Network) for tiny ship detection in remote sensing images
GNU General Public License v3.0
45 stars 6 forks

Error when doing inference using augmentation #10

Open ramdhan1989 opened 1 year ago

ramdhan1989 commented 1 year ago

Hi, I get an error when running inference with augment=True. The error is shown below. Please advise.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_53796\2715512915.py in <module>
     26                 postprocess_type  = "NMS",
     27                 postprocess_match_metric = "IOU",
---> 28                 perform_standard_pred=False)
     29         result_len = result.to_coco_annotations()
     30         for pred in result_len:

~\anaconda3\envs\bird\lib\site-packages\sahi\predict.py in get_sliced_prediction(image, detection_model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio, perform_standard_pred, postprocess_type, postprocess_match_metric, postprocess_match_threshold, postprocess_class_agnostic, verbose, merge_buffer_length, auto_slice_resolution)
    243             full_shape=[
    244                 slice_image_result.original_image_height,
--> 245                 slice_image_result.original_image_width,
    246             ],
    247         )

~\anaconda3\envs\bird\lib\site-packages\sahi\predict.py in get_prediction(image, detection_model, shift_amount, full_shape, postprocess, verbose)
     89     # get prediction
     90     time_start = time.time()
---> 91     detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
     92     time_end = time.time() - time_start
     93     durations_in_seconds["prediction"] = time_end

~\AppData\Local\Temp\ipykernel_53796\30530663.py in perform_inference(self, img, image_size)
     23         with torch.no_grad():
     24             # Run model
---> 25             (out, train_out), pdg = self.model(img, augment=True)  # inference and training outputs
     26          # Run NMS
     27         prediction_result = non_max_suppression(out, conf_thres=self.confidence_threshold, iou_thres=self.iou_thres, labels=[], multi_label=True)

~\anaconda3\envs\bird\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

~\birds\DRENet\models\yolo.py in forward(self, x, augment, profile)
    121                 yi = self.forward_once(xi)[0]  # forward
    122                 # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
--> 123                 yi[..., :4] /= si  # de-scale
    124                 if fi == 2:
    125                     yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud

TypeError: tuple indices must be integers or slices, not tuple

Thanks

WindVChen commented 1 year ago

Hi @ramdhan1989 ,

Sorry for the late reply. You can change line 121 of yolo.py from this:

yi = self.forward_once(xi)[0]  # forward

to this:

yi = self.forward_once(xi)[0][0]  # forward

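The error is raised because in DRENet, forward_once() also returns the degraded reconstruction image, so forward_once(xi)[0] is still a tuple rather than the detection tensor, and indexing it with yi[..., :4] fails.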
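To see why the extra [0] is needed, here is a minimal sketch with a hypothetical stand-in, fake_forward_once(), that returns ((inference_out, train_out), reconstruction). That nesting is an assumption inferred from the [0][0] fix and from the (out, train_out), pdg unpacking in the traceback, not DRENet's actual code, and the shapes are illustrative only.

```python
import torch

def fake_forward_once(x):
    """Hypothetical stand-in for DRENet's forward_once() (assumed output nesting)."""
    inference_out = torch.rand(x.shape[0], 10, 6)  # (batch, boxes, xywh + conf + cls), illustrative
    train_out = [torch.rand(1)]                    # placeholder for raw detection-head outputs
    reconstruction = torch.rand_like(x)            # degraded reconstruction image
    return (inference_out, train_out), reconstruction

xi = torch.rand(1, 3, 64, 64)
si = 0.83  # example de-scaling factor used by the augmented forward pass

# Original line 121: [0] still yields a tuple, so tensor-style indexing fails.
error_message = None
yi = fake_forward_once(xi)[0]
try:
    yi[..., :4] /= si
except TypeError as err:
    error_message = str(err)  # same message as in the traceback above

# Fixed line 121: [0][0] selects the inference tensor, and de-scaling works.
yi = fake_forward_once(xi)[0][0]
yi[..., :4] /= si
```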

============

It seems that you want to use multi-scale inference by setting augment=True. However, I'm afraid the current C3ResAtnMHSA structure may not support different input sizes, because of its fixed-size positional encoding.

Thus, if you want to use multi-scale inference, you can either replace C3ResAtnMHSA or change its structure. For the structure change, you could turn the current fixed-size positional encoding into an adaptive one (maybe by bilinear interpolation?).

You can give it a try.

ramdhan1989 commented 1 year ago

Noted, thank you.

ramdhan1989 commented 1 year ago

Hi @WindVChen, I am interested in modifying the code to accommodate different image sizes. In my opinion, performance could benefit from inference with augmentation, and also from inference at the original image size (to capture larger objects) in addition to inference on sliced images. Would you mind guiding me on how to start the modification? Do I only need to change the part below?

class C3ResAtnMHSA(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, size=14, shortcut=True, e=0.5):  # ch_in, ch_out, number, size, shortcut, expansion
        super(C3ResAtnMHSA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*[BottleneckResAtnMHSA(c_, size, shortcut=True) for _ in range(n)])
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

Thanks and regards,

Ramdhan

WindVChen commented 1 year ago

PROBLEM

Actually you only need to change the following part:

class BottleneckResAtnMHSA(nn.Module):
    # Standard bottleneck
    def __init__(self, n_dims, size, shortcut=True):  # ch_in, rel_h/rel_w grid size, shortcut
        super(BottleneckResAtnMHSA, self).__init__()

        height=size
        width=size
        self.cv1 = Conv(n_dims, n_dims//2, 1, 1)
        self.cv2 = Conv(n_dims//2, n_dims, 1, 1)
        '''MHSA PARAMS'''
        self.query = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)
        self.key = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)
        self.value = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)

        self.rel_h = nn.Parameter(torch.randn([1, n_dims//2, height, 1]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, n_dims//2, 1, width]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)
        self.add = shortcut

    def forward(self, x):
        x1=self.cv1(x)
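        n_batch, C, width, height = x1.size()  # note: size() returns (N, C, H, W), so these names are swapped; harmless because out.view() reuses the order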
        q = self.query(x1).view(n_batch, C, -1)
        k = self.key(x1).view(n_batch, C, -1)
        v = self.value(x1).view(n_batch, C, -1)

        content_content = torch.bmm(q.permute(0, 2, 1), k)

        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(n_batch, C, width, height)

        return x + self.cv2(out) if self.add else self.cv2(out)

More specifically, as previous issues show, errors caused by the input resolution usually come from this part:

def forward(self, x):
    ...
    content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
    content_position = torch.matmul(content_position, q)

    energy = content_content + content_position
    ...

This is because the sizes of self.rel_h and self.rel_w are fixed by the settings in DRE.yaml:

def __init__(self, n_dims, size, shortcut=True):
    ...
    self.rel_h = nn.Parameter(torch.randn([1, n_dims//2, height, 1]), requires_grad=True)
    self.rel_w = nn.Parameter(torch.randn([1, n_dims//2, 1, width]), requires_grad=True)
    ...
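A minimal sketch reproducing the mismatch with plain tensors, assuming C = 192 channels and the size = 16 grid used for a 512x512 input (random tensors stand in for the learned parameters and the query/key projections). Because rel_h + rel_w always flattens to 256 positions, content_position stays (1, 256, h*w) while content_content grows to (1, h*w, h*w).

```python
import torch

C, size = 192, 16                   # channels and grid size fixed for a 512x512 input
rel_h = torch.randn(1, C, size, 1)  # stand-ins for the learned parameters
rel_w = torch.randn(1, C, 1, size)

def attention_terms(h, w):
    """Compute both energy terms for an h x w feature map, as forward() does."""
    q = torch.randn(1, C, h * w)  # random stand-ins for the query/key projections
    k = torch.randn(1, C, h * w)
    content_content = torch.bmm(q.permute(0, 2, 1), k)                  # (1, h*w, h*w)
    content_position = (rel_h + rel_w).view(1, C, -1).permute(0, 2, 1)  # always (1, 256, C)
    content_position = torch.matmul(content_position, q)                # (1, 256, h*w)
    return content_content, content_position

cc, cp = attention_terms(16, 16)
energy = cc + cp                  # 16x16 map: both terms are (1, 256, 256)

mismatch_error = None
cc, cp = attention_terms(32, 32)  # 32x32 map, e.g. a 1024x1024 input
try:
    energy = cc + cp              # (1, 1024, 1024) + (1, 256, 1024) cannot broadcast
except RuntimeError as err:
    mismatch_error = str(err)
```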

SOLUTION

Having identified the problem above, a straightforward solution is to make the following line resolution-adaptive:

content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)

My suggestion is to add a line in def forward() that interpolates self.rel_h and self.rel_w according to the input size. Then the block will support inputs of different resolutions at inference, and during training there will be no need to change DRENet.yaml every time the input resolution changes.

Since I am not sure whether this (somewhat brute-force) solution gives good results, I would be very glad if you could share your experimental results on whether it is effective.
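One way this idea could look, as a sketch rather than the repository's code: the attention core of BottleneckResAtnMHSA with rel_h + rel_w resized inside forward(). The class name AdaptiveRelPosAttention is hypothetical, and DRENet's cv1/cv2 convolutions and the shortcut are left out so the block runs on its own. It resizes to the feature map's actual (height, width), which also covers non-square inputs; whether the interpolated encoding keeps accuracy is untested.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveRelPosAttention(nn.Module):
    """Hypothetical sketch: attention core of BottleneckResAtnMHSA with a
    resolution-adaptive position term (cv1/cv2 and the shortcut are omitted)."""

    def __init__(self, dim, size):
        super().__init__()
        self.query = nn.Conv2d(dim, dim, kernel_size=1)
        self.key = nn.Conv2d(dim, dim, kernel_size=1)
        self.value = nn.Conv2d(dim, dim, kernel_size=1)
        # Learned at the training resolution (size x size), as in the original block.
        self.rel_h = nn.Parameter(torch.randn(1, dim, size, 1))
        self.rel_w = nn.Parameter(torch.randn(1, dim, 1, size))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, height, width = x.size()  # size() returns (N, C, H, W)
        q = self.query(x).view(n_batch, C, -1)
        k = self.key(x).view(n_batch, C, -1)
        v = self.value(x).view(n_batch, C, -1)
        content_content = torch.bmm(q.permute(0, 2, 1), k)  # (N, H*W, H*W)

        pos = self.rel_h + self.rel_w  # (1, C, size, size)
        if pos.shape[-2:] != (height, width):
            # Resize the 4D grid before flattening; targeting (height, width)
            # directly also handles non-square maps such as 32x16.
            pos = F.interpolate(pos, size=(height, width), mode="bilinear", align_corners=False)
        content_position = pos.view(1, C, -1).permute(0, 2, 1)  # (1, H*W, C)
        content_position = torch.matmul(content_position, q)    # (N, H*W, H*W)

        attention = self.softmax(content_content + content_position)
        out = torch.bmm(v, attention.permute(0, 2, 1))
        return out.view(n_batch, C, height, width)

attn = AdaptiveRelPosAttention(dim=192, size=16)
for h, w in [(16, 16), (20, 20), (32, 16)]:
    print(h, w, tuple(attn(torch.randn(1, 192, h, w)).shape))
```

Since F.interpolate is differentiable, gradients still reach rel_h and rel_w during training, and at the trained resolution the interpolation is skipped, so the computation there matches the original block.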

ramdhan1989 commented 1 year ago

Hi, I printed the tensor sizes at every step of the BottleneckResAtnMHSA and C3ResAtnMHSA classes in common.py and got the summary below:

input size | rel_h | rel_w | content_position | content_content
-- | -- | -- | -- | --
512×512 | (192×16×1) | (192×1×16) | (256×256) | (256×256)
1024×1024 | (192×16×1) | (192×1×16) | (256×1024) | (1024×1024)
1024×512 | (192×16×1) | (192×1×16) | (256×512) | (512×512)
640×640 | (192×16×1) | (192×1×16) | (256×400) | (400×400)

rel_h and rel_w have the same shape for every image size. The error happens in the forward method because content_position and content_content do not match. I am still thinking about how to make rel_h and rel_w flexible. Any thoughts?
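The rows of this table follow from the feature-map stride. A small sketch, assuming a stride of 32 for the blocks built with size=16 (512 / 16 = 32), computes the grid each input produces. Since content_content is (h*w)x(h*w), rel_h + rel_w would have to cover an h x w grid, which is not square for the 1024×512 input.

```python
STRIDE = 32        # assumed: 512 / 16 for the size=16 MHSA blocks
TRAINED = 16 * 16  # positions covered by rel_h + rel_w at size=16

def attention_grid(img_h, img_w, stride=STRIDE):
    """Feature-map grid (h, w) that reaches the size=16 attention blocks."""
    return img_h // stride, img_w // stride

for img_h, img_w in [(512, 512), (1024, 1024), (1024, 512), (640, 640)]:
    h, w = attention_grid(img_h, img_w)
    status = "matches" if h * w == TRAINED else "mismatch"
    print(f"{img_h}x{img_w}: grid {h}x{w}, content_content {h * w}x{h * w}, {status}")
```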

WindVChen commented 1 year ago

You could interpolate the output of the following line:

def forward(self, x):
    ...
    content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
    ...

From the table you provided, after adding up rel_h and rel_w, the result's size should be 192x16x16. To make it flexible, you can interpolate this sum to match the size of content_content, doing so on the 4D tensor before the .view(1, C, -1).permute(0, 2, 1) flattening. For example, for a (1024, 1024) content_content, you should interpolate 192x16x16 to 192x32x32.
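A shape-only sketch of this suggestion for the 1024x1024 case, with random tensors in place of the learned rel_h/rel_w and the query. The interpolation is applied to the 4D sum, because PyTorch's bilinear mode expects (N, C, H, W) input, and the result is flattened only afterwards.

```python
import torch
import torch.nn.functional as F

C = 192
rel_h = torch.randn(1, C, 16, 1)  # stand-ins for the learned parameters
rel_w = torch.randn(1, C, 1, 16)
q = torch.randn(1, C, 32 * 32)    # query for a 32x32 feature map (1024x1024 input)

pos = rel_h + rel_w  # broadcasts to (1, 192, 16, 16)
pos = F.interpolate(pos, size=(32, 32), mode="bilinear", align_corners=False)  # (1, 192, 32, 32)
content_position = pos.view(1, C, -1).permute(0, 2, 1)  # (1, 1024, 192)
content_position = torch.matmul(content_position, q)    # (1, 1024, 1024), same as content_content
```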

ramdhan1989 commented 1 year ago

It is still not working; the error occurs after several iterations of the forward pass of BottleneckResAtnMHSA. Here I added print statements:

class BottleneckResAtnMHSA(nn.Module):
    # Standard bottleneck
    def __init__(self, n_dims, size, shortcut=True):  # ch_in, rel_h/rel_w grid size, shortcut
        super(BottleneckResAtnMHSA, self).__init__()
        height=size
        width=size
        self.cv1 = Conv(n_dims, n_dims//2, 1, 1)
        self.cv2 = Conv(n_dims//2, n_dims, 1, 1)
        '''MHSA PARAMS'''
        self.query = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)
        self.key = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)
        self.value = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)

        self.rel_h = nn.Parameter(torch.randn([1, n_dims//2, height, 1]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, n_dims//2, 1, width]), requires_grad=True)
        print('BottleneckResAtnMHSA',size,n_dims,self.rel_h.shape,self.rel_w.shape)
        self.softmax = nn.Softmax(dim=-1)
        self.add = shortcut

    def forward(self, x):
        x1=self.cv1(x)
        n_batch, C, width, height = x1.size()
        q = self.query(x1).view(n_batch, C, -1)
        k = self.key(x1).view(n_batch, C, -1)
        v = self.value(x1).view(n_batch, C, -1)

        content_content = torch.bmm(q.permute(0, 2, 1), k)

        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
        print('BottleneckResAtnMHSA forward',n_batch, C, width, height, content_position.shape, q.shape)

        content_position = torch.matmul(content_position, q)
        print('BottleneckResAtnMHSA forward last',content_position.shape, content_content.shape)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(n_batch, C, width, height)
        return x + self.cv2(out) if self.add else self.cv2(out)

The output is as follows:

BottleneckResAtnMHSA 16 384 torch.Size([1, 192, 16, 1]) torch.Size([1, 192, 1, 16])
BottleneckResAtnMHSA 16 384 torch.Size([1, 192, 16, 1]) torch.Size([1, 192, 1, 16])
BottleneckResAtnMHSA 32 192 torch.Size([1, 96, 32, 1]) torch.Size([1, 96, 1, 32])
BottleneckResAtnMHSA 32 192 torch.Size([1, 96, 32, 1]) torch.Size([1, 96, 1, 32])
BottleneckResAtnMHSA 64 96 torch.Size([1, 48, 64, 1]) torch.Size([1, 48, 1, 64])
BottleneckResAtnMHSA 64 96 torch.Size([1, 48, 64, 1]) torch.Size([1, 48, 1, 64])
BottleneckResAtnMHSA 32 192 torch.Size([1, 96, 32, 1]) torch.Size([1, 96, 1, 32])
BottleneckResAtnMHSA 32 192 torch.Size([1, 96, 32, 1]) torch.Size([1, 96, 1, 32])
BottleneckResAtnMHSA 16 384 torch.Size([1, 192, 16, 1]) torch.Size([1, 192, 1, 16])
BottleneckResAtnMHSA 16 384 torch.Size([1, 192, 16, 1]) torch.Size([1, 192, 1, 16])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 48 64 64 torch.Size([1, 4096, 48]) torch.Size([1, 48, 4096])
BottleneckResAtnMHSA forward last torch.Size([1, 4096, 4096]) torch.Size([1, 4096, 4096])
BottleneckResAtnMHSA forward 1 48 64 64 torch.Size([1, 4096, 48]) torch.Size([1, 48, 4096])
BottleneckResAtnMHSA forward last torch.Size([1, 4096, 4096]) torch.Size([1, 4096, 4096])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 48 64 64 torch.Size([1, 4096, 48]) torch.Size([1, 48, 4096])
BottleneckResAtnMHSA forward last torch.Size([1, 4096, 4096]) torch.Size([1, 4096, 4096])
BottleneckResAtnMHSA forward 1 48 64 64 torch.Size([1, 4096, 48]) torch.Size([1, 48, 4096])
BottleneckResAtnMHSA forward last torch.Size([1, 4096, 4096]) torch.Size([1, 4096, 4096])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 96 32 32 torch.Size([1, 1024, 96]) torch.Size([1, 96, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 1024, 1024]) torch.Size([1, 1024, 1024])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
BottleneckResAtnMHSA forward 1 192 16 16 torch.Size([1, 256, 192]) torch.Size([1, 192, 256])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
04/13/2023 15:48:14 - INFO - utils.torch_utils -   Model Summary: 462 layers, 15523082 parameters, 15523082 gradients, 46.7 GFLOPS
04/13/2023 15:48:14 - INFO - models.yolo -   
BottleneckResAtnMHSA forward 1 192 32 32 torch.Size([1, 256, 192]) torch.Size([1, 192, 1024])
BottleneckResAtnMHSA forward last torch.Size([1, 256, 1024]) torch.Size([1, 1024, 1024])

The error is: RuntimeError: The size of tensor a (1024) must match the size of tensor b (256) at non-singleton dimension 1. Please advise.

Following your suggestion, I interpolated the result of self.rel_h + self.rel_w, but it still produces an error from the first iteration. Please advise, thank you.

ramdhan1989 commented 1 year ago

Hi, I made a modification and it now runs successfully with different image sizes.

from math import sqrt  # needed by the adjuster computation in forward()

class BottleneckResAtnMHSA(nn.Module):
    # Standard bottleneck
    def __init__(self, n_dims, size, shortcut=True):  # ch_in, rel_h/rel_w grid size, shortcut
        super(BottleneckResAtnMHSA, self).__init__()
        height=size
        width=size
        self.cv1 = Conv(n_dims, n_dims//2, 1, 1)
        self.cv2 = Conv(n_dims//2, n_dims, 1, 1)
        '''MHSA PARAMS'''
        self.query = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)
        self.key = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)
        self.value = nn.Conv2d(n_dims//2, n_dims//2, kernel_size=1)

        self.rel_h = nn.Parameter(torch.randn([1, n_dims//2, height, 1]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, n_dims//2, 1, width]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)
        self.add = shortcut

    def forward(self, x):
        x1=self.cv1(x)
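        n_batch, C, width, height = x1.size()  # note: size() returns (N, C, H, W), so these names are swapped; harmless because out.view() reuses the order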
        q = self.query(x1).view(n_batch, C, -1)
        k = self.key(x1).view(n_batch, C, -1)
        v = self.value(x1).view(n_batch, C, -1)

        content_content = torch.bmm(q.permute(0, 2, 1), k)
        adjuster = sqrt(content_content.shape[1]/(self.rel_h.shape[2]*self.rel_w.shape[3]))

        add = torch.nn.functional.interpolate(self.rel_h + self.rel_w, size=(int(self.rel_h.shape[2]*adjuster), int(self.rel_w.shape[3]*adjuster)), mode='bilinear', align_corners=False)
        content_position = add.view(1, C, -1).permute(0, 2, 1)

        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(n_batch, C, width, height)
        return x + self.cv2(out) if self.add else self.cv2(out)
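One caveat, shown with a small pure-Python check: the adjuster assumes the feature map is square, because it scales both sides by the same factor derived from the token count. For the 1024×512 row of the table above (a 32x16 grid, 512 tokens) it targets 22x22 = 484 positions, so content_position would come out as (1, 484, 512) and could not be added to the (1, 512, 512) content_content. The helper name adjuster_target is hypothetical; it mirrors the adjuster arithmetic with size=16, and the same square assumption applies to any variant that takes the side length from the square root of the token count.

```python
from math import sqrt

def adjuster_target(tokens, size=16):
    """Mirror of the adjuster line: target (h, w) for rel_h/rel_w trained at size x size."""
    adjuster = sqrt(tokens / (size * size))
    return int(size * adjuster), int(size * adjuster)

for grid_h, grid_w in [(32, 32), (20, 20), (32, 16)]:
    target_h, target_w = adjuster_target(grid_h * grid_w)
    ok = target_h * target_w == grid_h * grid_w
    print(f"grid {grid_h}x{grid_w} -> interpolated to {target_h}x{target_w}, ok={ok}")
```

Interpolating directly to the (height, width) returned by x1.size() would avoid this, at least in terms of shapes.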

However, when I set augment=True, I got the error below:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_21980\2069210895.py in <module>
     28                 postprocess_type  = "NMS",
     29                 postprocess_match_metric = "IOU",
---> 30                 perform_standard_pred=False)
     31         result_len = result.to_coco_annotations()
     32         for pred in result_len:

~\anaconda3\envs\bird\lib\site-packages\sahi\predict.py in get_sliced_prediction(image, detection_model, slice_height, slice_width, overlap_height_ratio, overlap_width_ratio, perform_standard_pred, postprocess_type, postprocess_match_metric, postprocess_match_threshold, postprocess_class_agnostic, verbose, merge_buffer_length, auto_slice_resolution)
    243             full_shape=[
    244                 slice_image_result.original_image_height,
--> 245                 slice_image_result.original_image_width,
    246             ],
    247         )

~\anaconda3\envs\bird\lib\site-packages\sahi\predict.py in get_prediction(image, detection_model, shift_amount, full_shape, postprocess, verbose)
     89     # get prediction
     90     time_start = time.time()
---> 91     detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
     92     time_end = time.time() - time_start
     93     durations_in_seconds["prediction"] = time_end

~\AppData\Local\Temp\ipykernel_21980\114499772.py in perform_inference(self, img, image_size)
     35         with torch.no_grad():
     36             # Run model
---> 37             (out, train_out), pdg = self.model(img, augment=True)  # inference and training outputs
     38          # Run NMS
     39         prediction_result = non_max_suppression(out, conf_thres=self.confidence_threshold, iou_thres=self.iou_thres, labels=[], multi_label=True)

~\anaconda3\envs\bird\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

~\codelab\DRENet\DRENet\models\yolo.py in forward(self, x, augment, profile)
    121                 yi = self.forward_once(xi)[0]  # forward
    122                 # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
--> 123                 yi[..., :4] /= si  # de-scale
    124                 if fi == 2:
    125                     yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud

TypeError: tuple indices must be integers or slices, not tuple

Do you have any idea how to solve this? Thank you.

WindVChen commented 1 year ago

Modifications:

class BottleneckResAtnMHSA(nn.Module):
        ...
        # content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)
        content_position = (self.rel_h + self.rel_w)
        content_position = nn.functional.interpolate(content_position, (int(content_content.shape[-1]**0.5), int(content_content.shape[-1]**0.5)), mode='bilinear')
        content_position = content_position.view(1, C, -1).permute(0, 2, 1)
        ...

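As for the TypeError, it seems to be the same one as before: line 121 of yolo.py still reads yi = self.forward_once(xi)[0], so the [0][0] change from my first reply should fix it.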

ramdhan1989 commented 1 year ago

Noted, it is working now. In my experiments, using TTA for inference did not improve performance in my case. In this link, you mentioned the possibility of training the model with images of different sizes. I think the dataloader could raise an error with images of different sizes, couldn't it?
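If images of different sizes do end up in the same batch, a collate step that stacks them with torch.stack fails. Below is a minimal sketch of one workaround, a hypothetical pad_collate that zero-pads each (C, H, W) image on the right and bottom to the largest size in the batch. It is an assumption, not part of DRENet: YOLOv5-style loaders usually letterbox every image to a common img_size, in which case this is unnecessary, and labels stored in normalized coordinates would need rescaling to the padded size.

```python
import torch
import torch.nn.functional as F

def pad_collate(images):
    """Hypothetical collate: zero-pad (C, H, W) images to the batch's max H and W."""
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    padded = [F.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1])) for img in images]
    return torch.stack(padded)

batch = [torch.rand(3, 512, 512), torch.rand(3, 640, 640)]

stack_error = None
try:
    torch.stack(batch)  # plain stacking fails on mismatched sizes
except RuntimeError as err:
    stack_error = str(err)

padded_batch = pad_collate(batch)  # torch.Size([2, 3, 640, 640])
```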

WindVChen commented 1 year ago

You're right. By "train the model using different sizes of image" I only meant that, if the program can adapt to different input sizes, we would not need to recalculate the parameters in DRENet.yaml for input images whose size is not 512.

Actually, I'm curious whether such an adaptive operation hurts performance. For example, on 640x640 inputs, how do the results differ between the adaptive code (parameters initialized for 512 input resolution) and manually recomputed parameters for 640 resolution in DRENet.yaml? Has this been tried?
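For the manually recomputed baseline in this comparison, the size values in DRENet.yaml would need to match the new resolution. A small sketch, assuming (from the sizes printed earlier in this thread: 16, 32 and 64 for a 512x512 input) that the MHSA blocks sit at strides 32, 16 and 8; mhsa_sizes is a hypothetical helper, not part of the repository. The adaptive variant would instead keep 16/32/64 and interpolate at run time.

```python
# Assumption inferred from the printed sizes: at a 512x512 input the MHSA blocks
# use size=16, 32 and 64, i.e. feature strides of 32, 16 and 8.
STRIDES = (32, 16, 8)

def mhsa_sizes(img_size, strides=STRIDES):
    """Hypothetical helper: `size` values DRENet.yaml would need for a square input."""
    if any(img_size % s for s in strides):
        raise ValueError(f"img_size {img_size} must be divisible by {max(strides)}")
    return [img_size // s for s in strides]

print(mhsa_sizes(512))  # [16, 32, 64]
print(mhsa_sizes(640))  # [20, 40, 80]
```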