megvii-research / TreeEnergyLoss

[CVPR2022] Tree Energy Loss: Towards Sparsely Annotated Semantic Segmentation

Encountered NaN after using function 'TreeFilter2D' #1

Closed: WinterPan2017 closed this issue 2 years ago

WinterPan2017 commented 2 years ago

Hello, thanks for your excellent work. I tried to generate a refined mask using low-level features, but I got NaN after applying TreeFilter2D. For simplicity, I use the same image as both the low-level features and the prediction. My code is below:

import torch
from kernels.lib_tree_filter.modules.tree_filter import MinimumSpanningTree
from kernels.lib_tree_filter.modules.tree_filter import TreeFilter2D

mst_layers = MinimumSpanningTree(TreeFilter2D.norm2_distance)
tree_filter_layers = TreeFilter2D(groups=1, sigma=0.002)
# A 1x1x3x3 (N, C, H, W) image with one bright pixel in the center.
image = torch.tensor(
    [[[
        [0, 0, 0],
        [0, 0.7, 0],
        [0, 0, 0],
    ]]], dtype=torch.float)
tree = mst_layers(image)
print(tree)
AS = tree_filter_layers(feature_in=image, embed_in=image, tree=tree) 
print(AS)
Output:
tensor([[[0, 3],
         [0, 1],
         [2, 5],
         [1, 4],
         [3, 6],
         [6, 7],
         [5, 8],
         [1, 2]]], dtype=torch.int32)
tensor([[[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]]]])
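For reference, the tree printed above can be reproduced by hand. The sketch below runs Kruskal's algorithm on the 4-connected pixel grid, assuming row-major pixel indices (index = row * width + col) and that `norm2_distance` corresponds to a squared L2 distance between neighboring features. Both are assumptions, the helper names are hypothetical, and the repository's CUDA kernel may order edges or break ties differently. On this image only the four edges touching the bright center pixel have nonzero weight, so a minimum tree keeps exactly one of them.

```python
# Plain-Python sketch of a minimum spanning tree over a 4-connected pixel grid.
# Assumptions (not from the TreeEnergyLoss code): pixels are numbered in
# row-major order and each edge is weighted by the squared L2 distance between
# the two pixels' feature vectors. The real CUDA kernel may break ties differently.


def grid_edges(height, width):
    """Return every (u, v) pair of horizontal or vertical neighbors."""
    edges = []
    for r in range(height):
        for c in range(width):
            u = r * width + c
            if c + 1 < width:
                edges.append((u, u + 1))      # right neighbor
            if r + 1 < height:
                edges.append((u, u + width))  # neighbor below
    return edges


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def grid_mst(features):
    """Kruskal's algorithm on an H x W grid of feature tuples."""
    height, width = len(features), len(features[0])
    flat = [pixel for row in features for pixel in row]
    # sorted() is stable, so equal-weight edges keep their generation order.
    edges = sorted(grid_edges(height, width),
                   key=lambda e: squared_l2(flat[e[0]], flat[e[1]]))
    parent = list(range(height * width))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    tree = []
    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u != root_v:  # keep only edges that join two components
            parent[root_u] = root_v
            tree.append((u, v))
    return tree


# Same 3x3 single-channel image as in the issue.
image = [[(0.0,), (0.0,), (0.0,)],
         [(0.0,), (0.7,), (0.0,)],
         [(0.0,), (0.0,), (0.0,)]]
printed_edges = [(0, 3), (0, 1), (2, 5), (1, 4), (3, 6), (6, 7), (5, 8), (1, 2)]

tree = grid_mst(image)
print(len(tree), set(tree) == set(printed_edges))  # 8 True
```

With this edge ordering the sketch happens to select the same eight edges as the printed tensor, which suggests the tree itself is well formed and the NaNs arise in the filtering step.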

Is there something wrong with my usage? I'm looking forward to your reply.

liangzhiyuanCV commented 2 years ago

Thanks for your interest. The inputs to TreeFilter2D must be CUDA tensors. Just move the image tensor to the CUDA device first:

image = torch.tensor(
    [[[
        [0, 0, 0],
        [0, 0.7, 0],
        [0, 0, 0],
    ]]], dtype=torch.float).cuda()

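If you would rather have misplaced inputs fail loudly than silently produce NaNs, a small guard could be called right before the filter. This is a hypothetical helper (the name `require_cuda` is made up for illustration and is not part of the repository); it relies only on the boolean `is_cuda` attribute that `torch.Tensor` provides.

```python
# Hypothetical helper, not part of TreeEnergyLoss: raise a clear error instead
# of getting NaNs when a tree-filter input is still on the CPU.
def require_cuda(**tensors):
    """Raise ValueError naming every argument whose tensor is not on CUDA.

    torch.Tensor exposes a boolean `is_cuda` attribute; objects without it
    are treated as CPU-resident.
    """
    cpu_names = sorted(
        name for name, t in tensors.items() if not getattr(t, "is_cuda", False)
    )
    if cpu_names:
        raise ValueError(
            "TreeFilter2D inputs must be CUDA tensors; still on CPU: "
            + ", ".join(cpu_names)
        )
```

For example, calling `require_cuda(feature_in=image, embed_in=image, tree=tree)` just before `tree_filter_layers(...)` would have raised on the original CPU tensors from the question.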
WinterPan2017 commented 2 years ago

Got it, thanks a lot!