zhu-xlab / DOFA

Code for Neural Plasticity-Inspired Foundation Model for Observing the Earth Crossing Modalities
MIT License

Problem loading the data & How to do inference after model training #12

Open mostaan66 opened 3 weeks ago

mostaan66 commented 3 weeks ago

Hi, I am trying to run the fine-tuning script for segmentation on the TUM SegMunich dataset and I have a problem loading the data. I used your proposed structure:

dataset/SegMunich/
├── dataset/
│   ├── train.txt
│   ├── val.txt
├── train/
│   ├── img
│   │   ├── xxx.tif
│   │   ├── xxx.tif
│   ├── label
│   │   ├── xxx.tif
│   │   ├── xxx.tif
├── val/
│   ├── img
│   │   ├── xxx.tif
│   │   ├── xxx.tif
│   ├── label
│   │   ├── xxx.tif
│   │   ├── xxx.tif
...

and also the following, which is introduced in the mmsegmentation documentation:

├── data
│   ├── my_dataset
│   │   ├── img_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{img_suffix}
│   │   │   │   ├── yyy{img_suffix}
│   │   │   │   ├── zzz{img_suffix}
│   │   │   ├── val
│   │   ├── ann_dir
│   │   │   ├── train
│   │   │   │   ├── xxx{seg_map_suffix}
│   │   │   │   ├── yyy{seg_map_suffix}
│   │   │   │   ├── zzz{seg_map_suffix}
│   │   │   ├── val

With the first structure, I received an error indicating that it is not the structure the dataloader expects. With the second one, I get the following error:

File "/usr/local/lib/python3.10/dist-packages/mmengine/dataset/base_dataset.py", line 768, in _serialize_data
    data_bytes = np.concatenate(data_list)
ValueError: need at least one array to concatenate

Can you please give me a hint on how to load the data and what the correct file structure is? Thanks in advance.
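For context, the on-disk layout only has to match what the dataset config declares; a minimal mmsegmentation 1.x dataset-config sketch is shown below, where the dataset type, data_root, class names, and suffixes are assumptions. The "need at least one array to concatenate" error typically means these fields resolved to zero samples on disk.

# Sketch of the dataset part of an mmsegmentation 1.x config.
# All names below (dataset type, data_root, classes, suffixes) are placeholders
# and must match the actual files on disk.
train_dataloader = dict(
    batch_size=8,
    num_workers=4,
    dataset=dict(
        type='BaseSegDataset',
        metainfo=dict(classes=('background', 'target')),  # placeholder class names
        data_root='data/SegMunich',
        data_prefix=dict(
            img_path='img_dir/train',        # training images
            seg_map_path='ann_dir/train'),   # training labels
        img_suffix='.tif',                   # must match the real file extension
        seg_map_suffix='.tif',
        pipeline=[]))                        # load / augment / pack transforms go here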

xiong-zhitong commented 3 weeks ago

Thanks for your interest. We have uploaded the preprocessed SegMunich dataset to Hugging Face. Please try that one.

FYI, we plan to release an improved version of DOFA weights. Please stay tuned :)
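A minimal download sketch with huggingface_hub, assuming the preprocessed dataset sits in a Hugging Face dataset repository (the repo_id below is a placeholder, not the real one):

from huggingface_hub import snapshot_download

# Pull the preprocessed SegMunich dataset from the Hugging Face Hub.
# NOTE: the repo_id is a placeholder; use the repository linked by the authors.
local_path = snapshot_download(
    repo_id="<org>/<SegMunich-preprocessed>",
    repo_type="dataset",
    local_dir="data/SegMunich",
)
print("Dataset downloaded to:", local_path)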

xiong-zhitong commented 1 week ago

Hi, I will look into it and get back to you this week. Btw, our new version of the DOFA weights has been uploaded to HF: https://huggingface.co/XShadow/DOFA.

Best, Zhitong

mostaan66 @.***> wrote on Tue, Aug 27, 2024, 17:51:

Thanks for your response. I downloaded the new preprocessed SegMunich dataset and ran the fine-tuning code (by adapting a few lines of the config file and of /usr/local/lib/python3.10/dist-packages/mmseg/datasets/isaid.py) for optical PNG files. However, I have not yet been able to run it on multispectral data, as it seems it cannot load .tif files with more than 3 bands, or I cannot figure out how. Can you please briefly explain how to solve this issue? Thanks


mostaan66 commented 1 week ago

Hi Zhitong,

Thanks for your response. I’ve already deleted my issue from GitHub because I found the solution and managed to run the model for fine-tuning. The issue was related to preprocessing that was specific to RGB data (dict(type='PhotoMetricDistortion')), so I commented out that part of the code.
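For reference, the change amounts to disabling that one transform in the train pipeline; the surrounding transforms below are assumptions based on a typical mmsegmentation config:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='RandomResize', scale=(256, 256), ratio_range=(0.5, 2.0), keep_ratio=True),
    dict(type='RandomCrop', crop_size=(256, 256), cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    # dict(type='PhotoMetricDistortion'),  # assumes 3-channel RGB input; disabled for multispectral .tif files
    dict(type='PackSegInputs'),
]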

Currently, I'm facing another issue. I fine-tuned my model for a binary segmentation task, but during inference with the generated checkpoint, it doesn't produce binary predictions. Could you help clarify why this might be happening and how to resolve it?

Best, Mostaan


zhitong-xiong commented 1 week ago

Could you please provide more info about your current model and output?

mostaan66 commented 2 days ago

I downloaded your previous "DOFA_ViT_base_e100.pt" model checkpoint and fine-tuned it for a binary segmentation task.

Now, my fine-tuned model is:

OFAViT(
  (fc_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (patch_embed): Dynamic_MLP_OFA(
    (weight_generator): TransformerWeightGenerator(
      (transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
            )
            (linear1): Linear(in_features=128, out_features=2048, bias=True)
            (dropout): Dropout(p=False, inplace=False)
            (linear2): Linear(in_features=2048, out_features=128, bias=True)
            (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=False, inplace=False)
            (dropout2): Dropout(p=False, inplace=False)
          )
        )
      )
      (fc_weight): Linear(in_features=128, out_features=196608, bias=True)
      (fc_bias): Linear(in_features=128, out_features=768, bias=True)
    )
    (fclayer): FCResLayer(
      (nonlin1): ReLU(inplace=True)
      (nonlin2): ReLU(inplace=True)
      (w1): Linear(in_features=128, out_features=128, bias=True)
      (w2): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (ls2): Identity()
      (drop_path2): Identity()
    )
  )
  (head_drop): Dropout(p=0.0, inplace=False)
  (head): Linear(in_features=768, out_features=45, bias=True)
)

which is similar to the checkpoint you shared. However, unlike the original checkpoint, it has the following check_point['state_dict'] keys:

odict_keys(['backbone.cls_token', 'backbone.pos_embed', 'backbone.patch_embed.weight_generator.weight_tokens', 'backbone.patch_embed.weight_generator.bias_token', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.self_attn.in_proj_weight', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.self_attn.in_proj_bias', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.self_attn.out_proj.weight', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.self_attn.out_proj.bias', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.linear1.weight', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.linear1.bias', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.linear2.weight', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.linear2.bias', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.norm1.weight', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.norm1.bias', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.norm2.weight', 'backbone.patch_embed.weight_generator.transformer_encoder.layers.0.norm2.bias', 'backbone.patch_embed.weight_generator.fc_weight.weight', 'backbone.patch_embed.weight_generator.fc_weight.bias', 'backbone.patch_embed.weight_generator.fc_bias.weight', 'backbone.patch_embed.weight_generator.fc_bias.bias', 'backbone.patch_embed.fclayer.w1.weight', 'backbone.patch_embed.fclayer.w1.bias', 'backbone.patch_embed.fclayer.w2.weight', 'backbone.patch_embed.fclayer.w2.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.1.norm1.weight', 'backbone.blocks.1.norm1.bias', 'backbone.blocks.1.attn.qkv.weight', 'backbone.blocks.1.attn.qkv.bias', 'backbone.blocks.1.attn.proj.weight', 'backbone.blocks.1.attn.proj.bias', 'backbone.blocks.1.norm2.weight', 'backbone.blocks.1.norm2.bias', 'backbone.blocks.1.mlp.fc1.weight', 'backbone.blocks.1.mlp.fc1.bias', 'backbone.blocks.1.mlp.fc2.weight', 'backbone.blocks.1.mlp.fc2.bias', 'backbone.blocks.2.norm1.weight', 'backbone.blocks.2.norm1.bias', 'backbone.blocks.2.attn.qkv.weight', 'backbone.blocks.2.attn.qkv.bias', 'backbone.blocks.2.attn.proj.weight', 'backbone.blocks.2.attn.proj.bias', 'backbone.blocks.2.norm2.weight', 'backbone.blocks.2.norm2.bias', 'backbone.blocks.2.mlp.fc1.weight', 'backbone.blocks.2.mlp.fc1.bias', 'backbone.blocks.2.mlp.fc2.weight', 'backbone.blocks.2.mlp.fc2.bias', 'backbone.blocks.3.norm1.weight', 'backbone.blocks.3.norm1.bias', 'backbone.blocks.3.attn.qkv.weight', 'backbone.blocks.3.attn.qkv.bias', 'backbone.blocks.3.attn.proj.weight', 'backbone.blocks.3.attn.proj.bias', 'backbone.blocks.3.norm2.weight', 'backbone.blocks.3.norm2.bias', 'backbone.blocks.3.mlp.fc1.weight', 'backbone.blocks.3.mlp.fc1.bias', 'backbone.blocks.3.mlp.fc2.weight', 'backbone.blocks.3.mlp.fc2.bias', 'backbone.blocks.4.norm1.weight', 'backbone.blocks.4.norm1.bias', 'backbone.blocks.4.attn.qkv.weight', 'backbone.blocks.4.attn.qkv.bias', 'backbone.blocks.4.attn.proj.weight', 'backbone.blocks.4.attn.proj.bias', 'backbone.blocks.4.norm2.weight', 'backbone.blocks.4.norm2.bias', 
'backbone.blocks.4.mlp.fc1.weight', 'backbone.blocks.4.mlp.fc1.bias', 'backbone.blocks.4.mlp.fc2.weight', 'backbone.blocks.4.mlp.fc2.bias', 'backbone.blocks.5.norm1.weight', 'backbone.blocks.5.norm1.bias', 'backbone.blocks.5.attn.qkv.weight', 'backbone.blocks.5.attn.qkv.bias', 'backbone.blocks.5.attn.proj.weight', 'backbone.blocks.5.attn.proj.bias', 'backbone.blocks.5.norm2.weight', 'backbone.blocks.5.norm2.bias', 'backbone.blocks.5.mlp.fc1.weight', 'backbone.blocks.5.mlp.fc1.bias', 'backbone.blocks.5.mlp.fc2.weight', 'backbone.blocks.5.mlp.fc2.bias', 'backbone.blocks.6.norm1.weight', 'backbone.blocks.6.norm1.bias', 'backbone.blocks.6.attn.qkv.weight', 'backbone.blocks.6.attn.qkv.bias', 'backbone.blocks.6.attn.proj.weight', 'backbone.blocks.6.attn.proj.bias', 'backbone.blocks.6.norm2.weight', 'backbone.blocks.6.norm2.bias', 'backbone.blocks.6.mlp.fc1.weight', 'backbone.blocks.6.mlp.fc1.bias', 'backbone.blocks.6.mlp.fc2.weight', 'backbone.blocks.6.mlp.fc2.bias', 'backbone.blocks.7.norm1.weight', 'backbone.blocks.7.norm1.bias', 'backbone.blocks.7.attn.qkv.weight', 'backbone.blocks.7.attn.qkv.bias', 'backbone.blocks.7.attn.proj.weight', 'backbone.blocks.7.attn.proj.bias', 'backbone.blocks.7.norm2.weight', 'backbone.blocks.7.norm2.bias', 'backbone.blocks.7.mlp.fc1.weight', 'backbone.blocks.7.mlp.fc1.bias', 'backbone.blocks.7.mlp.fc2.weight', 'backbone.blocks.7.mlp.fc2.bias', 'backbone.blocks.8.norm1.weight', 'backbone.blocks.8.norm1.bias', 'backbone.blocks.8.attn.qkv.weight', 'backbone.blocks.8.attn.qkv.bias', 'backbone.blocks.8.attn.proj.weight', 'backbone.blocks.8.attn.proj.bias', 'backbone.blocks.8.norm2.weight', 'backbone.blocks.8.norm2.bias', 'backbone.blocks.8.mlp.fc1.weight', 'backbone.blocks.8.mlp.fc1.bias', 'backbone.blocks.8.mlp.fc2.weight', 'backbone.blocks.8.mlp.fc2.bias', 'backbone.blocks.9.norm1.weight', 'backbone.blocks.9.norm1.bias', 'backbone.blocks.9.attn.qkv.weight', 'backbone.blocks.9.attn.qkv.bias', 'backbone.blocks.9.attn.proj.weight', 'backbone.blocks.9.attn.proj.bias', 'backbone.blocks.9.norm2.weight', 'backbone.blocks.9.norm2.bias', 'backbone.blocks.9.mlp.fc1.weight', 'backbone.blocks.9.mlp.fc1.bias', 'backbone.blocks.9.mlp.fc2.weight', 'backbone.blocks.9.mlp.fc2.bias', 'backbone.blocks.10.norm1.weight', 'backbone.blocks.10.norm1.bias', 'backbone.blocks.10.attn.qkv.weight', 'backbone.blocks.10.attn.qkv.bias', 'backbone.blocks.10.attn.proj.weight', 'backbone.blocks.10.attn.proj.bias', 'backbone.blocks.10.norm2.weight', 'backbone.blocks.10.norm2.bias', 'backbone.blocks.10.mlp.fc1.weight', 'backbone.blocks.10.mlp.fc1.bias', 'backbone.blocks.10.mlp.fc2.weight', 'backbone.blocks.10.mlp.fc2.bias', 'backbone.blocks.11.norm1.weight', 'backbone.blocks.11.norm1.bias', 'backbone.blocks.11.attn.qkv.weight', 'backbone.blocks.11.attn.qkv.bias', 'backbone.blocks.11.attn.proj.weight', 'backbone.blocks.11.attn.proj.bias', 'backbone.blocks.11.norm2.weight', 'backbone.blocks.11.norm2.bias', 'backbone.blocks.11.mlp.fc1.weight', 'backbone.blocks.11.mlp.fc1.bias', 'backbone.blocks.11.mlp.fc2.weight', 'backbone.blocks.11.mlp.fc2.bias', 'neck.lateral_convs.0.conv.weight', 'neck.lateral_convs.0.conv.bias', 'neck.lateral_convs.1.conv.weight', 'neck.lateral_convs.1.conv.bias', 'neck.lateral_convs.2.conv.weight', 'neck.lateral_convs.2.conv.bias', 'neck.lateral_convs.3.conv.weight', 'neck.lateral_convs.3.conv.bias', 'neck.convs.0.conv.weight', 'neck.convs.0.conv.bias', 'neck.convs.1.conv.weight', 'neck.convs.1.conv.bias', 'neck.convs.2.conv.weight', 'neck.convs.2.conv.bias', 
'neck.convs.3.conv.weight', 'neck.convs.3.conv.bias', 'decode_head.conv_seg.weight', 'decode_head.conv_seg.bias', 'decode_head.psp_modules.0.1.conv.weight', 'decode_head.psp_modules.0.1.bn.weight', 'decode_head.psp_modules.0.1.bn.bias', 'decode_head.psp_modules.0.1.bn.running_mean', 'decode_head.psp_modules.0.1.bn.running_var', 'decode_head.psp_modules.0.1.bn.num_batches_tracked', 'decode_head.psp_modules.1.1.conv.weight', 'decode_head.psp_modules.1.1.bn.weight', 'decode_head.psp_modules.1.1.bn.bias', 'decode_head.psp_modules.1.1.bn.running_mean', 'decode_head.psp_modules.1.1.bn.running_var', 'decode_head.psp_modules.1.1.bn.num_batches_tracked', 'decode_head.psp_modules.2.1.conv.weight', 'decode_head.psp_modules.2.1.bn.weight', 'decode_head.psp_modules.2.1.bn.bias', 'decode_head.psp_modules.2.1.bn.running_mean', 'decode_head.psp_modules.2.1.bn.running_var', 'decode_head.psp_modules.2.1.bn.num_batches_tracked', 'decode_head.psp_modules.3.1.conv.weight', 'decode_head.psp_modules.3.1.bn.weight', 'decode_head.psp_modules.3.1.bn.bias', 'decode_head.psp_modules.3.1.bn.running_mean', 'decode_head.psp_modules.3.1.bn.running_var', 'decode_head.psp_modules.3.1.bn.num_batches_tracked', 'decode_head.bottleneck.conv.weight', 'decode_head.bottleneck.bn.weight', 'decode_head.bottleneck.bn.bias', 'decode_head.bottleneck.bn.running_mean', 'decode_head.bottleneck.bn.running_var', 'decode_head.bottleneck.bn.num_batches_tracked', 'decode_head.lateral_convs.0.conv.weight', 'decode_head.lateral_convs.0.bn.weight', 'decode_head.lateral_convs.0.bn.bias', 'decode_head.lateral_convs.0.bn.running_mean', 'decode_head.lateral_convs.0.bn.running_var', 'decode_head.lateral_convs.0.bn.num_batches_tracked', 'decode_head.lateral_convs.1.conv.weight', 'decode_head.lateral_convs.1.bn.weight', 'decode_head.lateral_convs.1.bn.bias', 'decode_head.lateral_convs.1.bn.running_mean', 'decode_head.lateral_convs.1.bn.running_var', 'decode_head.lateral_convs.1.bn.num_batches_tracked', 'decode_head.lateral_convs.2.conv.weight', 'decode_head.lateral_convs.2.bn.weight', 'decode_head.lateral_convs.2.bn.bias', 'decode_head.lateral_convs.2.bn.running_mean', 'decode_head.lateral_convs.2.bn.running_var', 'decode_head.lateral_convs.2.bn.num_batches_tracked', 'decode_head.fpn_convs.0.conv.weight', 'decode_head.fpn_convs.0.bn.weight', 'decode_head.fpn_convs.0.bn.bias', 'decode_head.fpn_convs.0.bn.running_mean', 'decode_head.fpn_convs.0.bn.running_var', 'decode_head.fpn_convs.0.bn.num_batches_tracked', 'decode_head.fpn_convs.1.conv.weight', 'decode_head.fpn_convs.1.bn.weight', 'decode_head.fpn_convs.1.bn.bias', 'decode_head.fpn_convs.1.bn.running_mean', 'decode_head.fpn_convs.1.bn.running_var', 'decode_head.fpn_convs.1.bn.num_batches_tracked', 'decode_head.fpn_convs.2.conv.weight', 'decode_head.fpn_convs.2.bn.weight', 'decode_head.fpn_convs.2.bn.bias', 'decode_head.fpn_convs.2.bn.running_mean', 'decode_head.fpn_convs.2.bn.running_var', 'decode_head.fpn_convs.2.bn.num_batches_tracked', 'decode_head.fpn_bottleneck.conv.weight', 'decode_head.fpn_bottleneck.bn.weight', 'decode_head.fpn_bottleneck.bn.bias', 'decode_head.fpn_bottleneck.bn.running_mean', 'decode_head.fpn_bottleneck.bn.running_var', 'decode_head.fpn_bottleneck.bn.num_batches_tracked', 'auxiliary_head.conv_seg.weight', 'auxiliary_head.conv_seg.bias', 'auxiliary_head.convs.0.conv.weight', 'auxiliary_head.convs.0.bn.weight', 'auxiliary_head.convs.0.bn.bias', 'auxiliary_head.convs.0.bn.running_mean', 'auxiliary_head.convs.0.bn.running_var', 'auxiliary_head.convs.0.bn.num_batches_tracked'])

Maybe the problem is that I do not know how to run inference with it for the segmentation task. Can you please give me some hints on how to solve this, or provide some code to handle it? Thanks in advance.
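As a quick sanity check, sketched under the assumption that the checkpoint path below is replaced with yours: the first dimension of decode_head.conv_seg.weight tells you how many classes the fine-tuned head actually predicts (it should be 2 for a binary task).

import torch

# Inspect the fine-tuned mmsegmentation checkpoint; the path is a placeholder.
ckpt = torch.load('work_dirs/my_binary_run/latest.pth', map_location='cpu')
state_dict = ckpt['state_dict']

# The segmentation head's output channels equal the number of classes.
num_classes = state_dict['decode_head.conv_seg.weight'].shape[0]
print('decode_head predicts', num_classes, 'classes')

# Unlike the released DOFA_ViT_base_e100.pt encoder weights, every key here
# is prefixed with 'backbone.', 'neck.', 'decode_head.', or 'auxiliary_head.'.
print(sorted({k.split('.')[0] for k in state_dict}))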

xiong-zhitong commented 1 day ago

Hi, here is a simple example for model inference:

import torch
import matplotlib.pyplot as plt
from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import init_model, inference_model, show_result_pyplot

# config and checkpoint produced by the fine-tuning run
config_file = "configs/dofa_vit_segmunich.py"
checkpoint_file = "work_dirs/dofa_vit_segmunich/dofa_vit_segmunich.pth"

# build the model from the config and load the fine-tuned weights
model = init_model(config_file, checkpoint_file, device='cpu')

# test a single image
img = 'data/SegMunich/img_dir/val/645.tif'
if not torch.cuda.is_available():
    # SyncBN layers cannot run on CPU; convert them to regular BN
    model = revert_sync_batchnorm(model)
result = inference_model(model, img)

The result looks like this:

<SegDataSample(

    META INFORMATION
    img_shape: (256, 256)
    img_path: 'data/SegMunich/img_dir/val/645.tif'
    scale_factor: (2.0, 2.0)
    ori_shape: (128, 128)

    DATA FIELDS
    seg_logits: <PixelData(

            META INFORMATION

            DATA FIELDS
            data: tensor([[[ -6.8653,  -7.1571,  -7.7406,  ...,  -9.8622,  -8.7326,  -8.1678],
                         [ -7.1263,  -7.4104,  -7.9786,  ..., -10.1654,  -9.0342,  -8.4687],
                         [ -7.6483,  -7.9171,  -8.4546,  ..., -10.7719,  -9.6375,  -9.0703],
                         ...,
                         [ -7.4796,  -7.7309,  -8.2336,  ...,  -8.7707,  -7.7690,  -7.2682],
                         [ -6.8476,  -7.1284,  -7.6898,  ...,  -8.1029,  -7.2424,  -6.8122],
                         [ -6.5316,  -6.8271,  -7.4180,  ...,  -7.7690,  -6.9791,  -6.5842]],

                        [[ -1.9556,  -1.8963,  -1.7776,  ...,  -1.8071,  -2.0109,  -2.1128],
                         [ -1.9140,  -1.8806,  -1.8137,  ...,  -1.6550,  -1.8652,  -1.9703],
                         [ -1.8308,  -1.8492,  -1.8860,  ...,  -1.3508,  -1.5738,  -1.6853],
                         ...,
                         [ -1.2921,  -1.2279,  -1.0995,  ..., -11.2941, -10.0507,  -9.4290],
                         [ -1.4592,  -1.3533,  -1.1413,  ..., -10.1313,  -9.0501,  -8.5095],
                         [ -1.5428,  -1.4159,  -1.1622,  ...,  -9.5500,  -8.5498,  -8.0497]],

                        [[ -1.4377,  -1.3470,  -1.1656,  ...,  -1.4771,  -1.7191,  -1.8401],
                         [ -1.3096,  -1.2289,  -1.0676,  ...,  -1.3770,  -1.5953,  -1.7044],
                         [ -1.0534,  -0.9928,  -0.8716,  ...,  -1.1768,  -1.3476,  -1.4330],
                         ...,
                         [ -4.4227,  -4.4194,  -4.4128,  ...,  -9.2482,  -8.7694,  -8.5300],
                         [ -4.3288,  -4.2667,  -4.1426,  ...,  -8.8028,  -8.4594,  -8.2878],
                         [ -4.2818,  -4.1904,  -4.0075,  ...,  -8.5800,  -8.3045,  -8.1667]],

                        ...,

                        [[ -2.2751,  -2.1984,  -2.0449,  ...,  -3.3098,  -3.4472,  -3.5159],
                         [ -2.1042,  -2.0341,  -1.8939,  ...,  -3.2709,  -3.3439,  -3.3804],
                         [ -1.7626,  -1.7057,  -1.5920,  ...,  -3.1932,  -3.1373,  -3.1093],
                         ...,
                         [ -1.2615,  -1.2366,  -1.1867,  ...,  -9.2775,  -8.9401,  -8.7714],
                         [ -1.6097,  -1.5516,  -1.4353,  ...,  -8.8873,  -8.5578,  -8.3931],
                         [ -1.7838,  -1.7091,  -1.5596,  ...,  -8.6922,  -8.3667,  -8.2039]],

                        [[ -2.2086,  -2.2202,  -2.2434,  ...,  -3.4590,  -3.4219,  -3.4033],
                         [ -2.1867,  -2.2193,  -2.2843,  ...,  -3.5102,  -3.4289,  -3.3883],
                         [ -2.1430,  -2.2174,  -2.3663,  ...,  -3.6127,  -3.4431,  -3.3583],
                         ...,
                         [ -2.2061,  -2.4063,  -2.8068,  ..., -10.7627, -10.1027,  -9.7727],
                         [ -2.2716,  -2.3901,  -2.6269,  ...,  -9.8571,  -9.3015,  -9.0237],
                         [ -2.3044,  -2.3819,  -2.5370,  ...,  -9.4044,  -8.9009,  -8.6492]],

                        [[ -6.5529,  -6.5778,  -6.6276,  ...,  -8.4227,  -8.1410,  -8.0002],
                         [ -6.6344,  -6.6729,  -6.7499,  ...,  -8.3211,  -8.0194,  -7.8685],
                         [ -6.7974,  -6.8631,  -6.9945,  ...,  -8.1179,  -7.7760,  -7.6051],
                         ...,
                         [ -4.3621,  -4.3266,  -4.2556,  ...,  -9.5519,  -9.1874,  -9.0052],
                         [ -4.4509,  -4.3350,  -4.1034,  ...,  -9.0364,  -8.7521,  -8.6100],
                         [ -4.4952,  -4.3393,  -4.0273,  ...,  -8.7787,  -8.5345,  -8.4124]]])
        ) at 0x7f3d1c4c7ee0>
    pred_sem_seg: <PixelData(

            META INFORMATION

            DATA FIELDS
            data: tensor([[[3, 3, 3,  ..., 3, 3, 3],
                         [3, 3, 3,  ..., 3, 3, 3],
                         [3, 3, 3,  ..., 3, 3, 3],
                         ...,
                         [3, 3, 3,  ..., 4, 4, 4],
                         [3, 3, 3,  ..., 4, 4, 4],
                         [3, 3, 3,  ..., 4, 4, 4]]])
        ) at 0x7f3d1c4c7d90>
) at 0x7f3d1c4c7d00>

I think "pred_sem_seg" is what you want.
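If you need a hard binary mask from that result, a small post-processing sketch follows, assuming class index 1 is your foreground class:

import numpy as np

# `result` is the SegDataSample returned by inference_model above.
pred = result.pred_sem_seg.data.squeeze(0).cpu().numpy()  # (H, W) array of class indices

# With a 2-class head the values are already 0/1; otherwise map the
# foreground class index (assumed to be 1 here) to a binary mask.
binary_mask = (pred == 1).astype(np.uint8)
print(binary_mask.shape, np.unique(binary_mask))

show_result_pyplot(model, img, result) from mmseg.apis can also draw the prediction as an overlay, although for multispectral .tif input you may first need an RGB rendering of the image.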