AWF-GAUG / TreeCrownDelineation

Individual tree crown delineation in optical remote sensing images
MIT License

Inference fails, and the error message suggests that the weight parameters do not match the model #11

Open · panbo-bridge opened this issue 9 months ago

panbo-bridge commented 9 months ago

Here are my input parameters:

    "-i", "input/2021-09-02-sbl-z3-rgb-cog.tif",
    "-o", "result",
    "-m", "models/Unet-resnet18_epochs=209_lr=0.0001_width=224_bs=32_divby=255_custom_color_augs_k=0_jitted.pt"

This is my error message:

Exception occurred: RuntimeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/treecrowndelineation/model/tcd_model.py", line 12, in forward
    dist_model = self.dist_model
    seg_model = self.seg_model
    _0 = (seg_model).forward(img, )
    _1 = [_0, (dist_model).forward(_0, img, )]
    return torch.cat(_1, 1)
  File "code/__torch__/treecrowndelineation/model/segmentation_model.py", line 11, in forward
    img: Tensor) -> Tensor:
    model = self.model
    return (model).forward(img, )
            ~~~~~~~~~~~~~~ <--- HERE
  File "code/__torch__/segmentation_models_pytorch/unet/model.py", line 14, in forward
    decoder = self.decoder
    encoder = self.encoder
    _0, _1, _2, _3, _4, = (encoder).forward(img, )
                           ~~~~~~~~~~~~~~~~ <--- HERE
    _5 = (decoder).forward(_0, _1, _2, _3, _4, )
    return (segmentation_head).forward(_5, )
  File "code/__torch__/segmentation_models_pytorch/encoders/resnet.py", line 24, in forward
    bn1 = self.bn1
    conv1 = self.conv1
    _0 = (bn1).forward((conv1).forward(img, ), )
                        ~~~~~~~~~~~~~~ <--- HERE
    _1 = (relu).forward(_0, )
    _2 = (layer1).forward((maxpool).forward(_1, ), )
  File "code/__torch__/torch/nn/modules/conv.py", line 10, in forward
    img: Tensor) -> Tensor:
    weight = self.weight
    input = torch._convolution(img, weight, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1, False, False, True, True)
            ~~~~~~~~~~~~~~~~~~ <--- HERE
    return input

Traceback of TorchScript, original code (most recent call last):
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/conv.py(442): _conv_forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/conv.py(446): forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/container.py(141): forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/segmentation_models_pytorch/encoders/resnet.py(62): forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/segmentation_models_pytorch/base/model.py(15): forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/treecrowndelineation-0.1.0-py3.8.egg/treecrowndelineation/model/segmentation_model.py(51): forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/treecrowndelineation-0.1.0-py3.8.egg/treecrowndelineation/model/tcd_model.py(38): forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/jit/_trace.py(958): trace_module
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/jit/_trace.py(741): trace
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py(1778): to_torchscript
/home/max/.conda/envs/dr2/lib/python3.8/site-packages/torch/autograd/grad_mode.py(28): decorate_context
training_BKG_k_fold.py(160): <module>
RuntimeError: Given groups=1, weight of size [64, 5, 7, 7], expected input[16, 4, 256, 256] to have 5 channels, but got 4 channels instead
**How do I fix this? If needed, I can provide the running instance, the model parameters, and the input data.**
maxfreu commented 9 months ago

The error says that the model received 4 channels but expects 5. The readme.md file in the zip archive containing the weights mentions that the model was trained on RGBI images plus NDVI (= 5 channels), and at the bottom it gives an example inference call that looks like this:

```
./inference.py -i input_file.tif -o output_file.sqlite -m ~/models/*jitted* -w 224 --ndvi --min-dist 10 --sigma 2 -l 0.4 -b 0.05 -s 0.21 --div 255 --rescale-ndvi --save-prediction ~/intermediate_output --sigmoid
```

This does a few things:

1. divides the 8-bit RGBI input image by 255
2. computes the NDVI on the fly (see the sketch below)
3. rescales the NDVI to 0..1
4. applies a sigmoid to the network output
5. sets the input image size to 224
6. sets the minimum distance between tree crowns to 10 pixels, which affects over-/undersegmentation
7. sets some thresholds for crown detection that control the size of the resulting crowns (the l and b parameters; see the help via inference.py -h)
8. saves the network output from which the crown polygons are computed (useful for debugging)
9. sets the polygon simplification to 0.21 m, so that the resulting files are smaller; delineating single pixels is usually overkill

You can additionally specify the indices of your red and NIR bands for the NDVI computation via --red and --nir.
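For reference, here is a minimal sketch of what the --ndvi / --rescale-ndvi steps amount to, assuming a band-first (bands, H, W) RGBI array with red at index 0 and NIR at index 3; the indices are assumptions (that is what --red and --nir are for), and the exact implementation in inference.py may differ:

```python
import numpy as np

def ndvi(img, red=0, nir=3, rescale=True):
    """NDVI from a band-first (bands, H, W) RGBI array.

    The band indices and the rescaling are assumptions; inference.py
    exposes them via --red, --nir and --rescale-ndvi.
    """
    r = img[red].astype(np.float32)
    n = img[nir].astype(np.float32)
    out = (n - r) / (n + r + 1e-8)  # classic NDVI, range [-1, 1]
    if rescale:
        out = (out + 1.0) / 2.0     # map to [0, 1]
    return out
```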

Hope that helps!

panbo-bridge commented 9 months ago

Thanks, I have now been able to run the program. But probably because I don't have RGBI images, I didn't get the desired results using RGBA instead.

maxfreu commented 9 months ago

Hmmm, currently I haven't trained any model on RGB only and don't have time to retrain one. What you could do, however, is load the model and remove the fourth and fifth channel from the first conv layer's weights. Then it works with RGB, but I don't know at what quality. Feel free to close the issue if that fixes your problem.
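A rough sketch of what that surgery could look like on the jitted checkpoint: the attribute path seg_model.model.encoder.conv1 is read off the traceback above, the RGB-first channel order is taken from the readme's RGBI+NDVI description, and the output filename is made up. Note that per the traceback the distance branch also receives the raw image, so its first conv would likely need analogous treatment; which of its input channels correspond to the image is not visible here.

```python
import torch

model = torch.jit.load(
    "models/Unet-resnet18_epochs=209_lr=0.0001_width=224_bs=32_divby=255_custom_color_augs_k=0_jitted.pt",
    map_location="cpu",
)

# First conv of the segmentation branch; attribute path taken from the
# traceback (tcd_model.seg_model -> segmentation_model.model -> encoder.conv1).
conv1 = model.seg_model.model.encoder.conv1
print(conv1.weight.shape)  # torch.Size([64, 5, 7, 7]) -> 5 expected input channels

# Keep only the first three (RGB) input channels: [64, 5, 7, 7] -> [64, 3, 7, 7].
with torch.no_grad():
    conv1.weight.data = conv1.weight.data[:, :3].clone()

model.save("model_rgb_jitted.pt")  # hypothetical output name
```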

TobyZhouWei commented 8 months ago

Hi, my images are in PNG format. The annotations are organized in MS COCO format and stored in JSON files. How can I preprocess the data correctly?

maxfreu commented 8 months ago

Would make more sense to open another issue for this, but currently coco format is not supported because the rasterization algorithm relies on georeferenced data. However, generating the training data from coco format should be straight forward as well. You basically need three annotations per image: the masks, the outlines and the distance transform. So you can load the coco annotations for one image, compute the mask, compute the outlines, then subtract the outlines from the masks and compute the distance transform (from scipy.ndimage import distance_transform_edt). Then save all three and you're good to go. The rasterization scripts can serve as rough guidance on how I did it with shapely polygons.
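A minimal sketch of those steps with pycocotools, assuming an annotation file named annotations.json, approximating a crown outline as the mask minus its binary erosion, and normalizing the distance transform to 0..1 (the outline width and the normalization are assumptions, not necessarily what the rasterization scripts do):

```python
import numpy as np
from pycocotools.coco import COCO
from scipy.ndimage import binary_erosion, distance_transform_edt

coco = COCO("annotations.json")          # hypothetical annotation file
img_id = coco.getImgIds()[0]
info = coco.loadImgs(img_id)[0]
h, w = info["height"], info["width"]

mask = np.zeros((h, w), dtype=bool)      # all crowns
outlines = np.zeros((h, w), dtype=bool)  # crown boundaries
for ann in coco.loadAnns(coco.getAnnIds(imgIds=img_id)):
    m = coco.annToMask(ann).astype(bool)  # one crown as a binary mask
    outlines |= m & ~binary_erosion(m)    # ~1 px wide boundary
    mask |= m

interior = mask & ~outlines               # subtract outlines from the masks
dist = distance_transform_edt(interior)   # distance to the nearest crown border
dist = dist / max(dist.max(), 1.0)        # assumed rescale to 0..1
# save mask, outlines and dist as the three training rasters
```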

TobyZhouWei commented 8 months ago

Thank you for your guidance. If I still have questions about this topic, I will open a new issue and follow up on your suggestions there.