orrzohar / PROB

[CVPR 2023] Official Pytorch code for PROB: Probabilistic Objectness for Open World Object Detection
Apache License 2.0

Error on fine-tuning #35

Closed · sf-pear closed this 1 year ago

sf-pear commented 1 year ago

Hi again Orr,

I'm having another issue: I can't run the fine-tuning part of the tasks. So far I have trained tasks 1 and 2, but fine-tuning for task 2 is giving me an error. It seems like the script only runs with the flag --freeze_prob_model. Do you know why this might be happening?

# train task 2
PY_ARGS=${@:1}
python -u main_open_world.py \
    --output_dir "${EXP_DIR}/t2" --dataset fathomnet --PREV_INTRODUCED_CLS 10 --CUR_INTRODUCED_CLS 2\
    --train_set 'task2_train' --test_set 'all_test' --epochs 51\
    --model_type 'prob' --obj_loss_coef 8e-4 --obj_temp 1.3 --freeze_prob_model\
    --wandb_name "${WANDB_NAME}_t2"\
    --exemplar_replay_selection --exemplar_replay_max_length 1743 --exemplar_replay_dir ${WANDB_NAME}\
    --exemplar_replay_prev_file "task1_train_ft.txt" --exemplar_replay_cur_file "task2_train_ft.txt"\
    --pretrain "${EXP_DIR}/t1/checkpoint0040.pth" --lr 2e-5\
    ${PY_ARGS}

# fine tune task 2
PY_ARGS=${@:1}
python -u main_open_world.py \
    --output_dir "${EXP_DIR}/t2_ft" --dataset fathomnet --PREV_INTRODUCED_CLS 10 --CUR_INTRODUCED_CLS 2 \
    --train_set "${WANDB_NAME}/task2_train_ft" --test_set 'all_test' --epochs 111 --lr_drop 40\
    --model_type 'prob' --obj_loss_coef 8e-4 --obj_temp 1.3\
    --wandb_name "${WANDB_NAME}_t2_ft"\
    --pretrain "${EXP_DIR}/t2/checkpoint0050.pth"\
    ${PY_ARGS}
Dataset OWDetection
    Number of datapoints: 1669
    Root location: /home/sabrina/code/PROB/data/OWOD
    [['test'], Compose(
    <datasets.transforms.RandomResize object at 0x7f053807ef50>
    Compose(
    <datasets.transforms.ToTensor object at 0x7f053807ef80>
    <datasets.transforms.Normalize object at 0x7f053807f0a0>
)
)]
Initialized from the pre-training model
<All keys matched successfully>
Start training from epoch 51 to 111
Traceback (most recent call last):
  File "/home/sabrina/code/PROB/main_open_world.py", line 475, in <module>
    main(args)
  File "/home/sabrina/code/PROB/main_open_world.py", line 335, in main
    train_stats = train_one_epoch(
  File "/home/sabrina/code/PROB/engine.py", line 41, in train_one_epoch
    prefetcher = data_prefetcher(data_loader, device, prefetch=True)
  File "/home/sabrina/code/PROB/datasets/data_prefetcher.py", line 21, in __init__
    self.preload()
  File "/home/sabrina/code/PROB/datasets/data_prefetcher.py", line 25, in preload
    self.next_samples, self.next_targets = next(self.loader)
  File "/home/sabrina/mambaforge/envs/prob/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/home/sabrina/mambaforge/envs/prob/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1347, in _next_data
    return self._process_data(data)
  File "/home/sabrina/mambaforge/envs/prob/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1373, in _process_data
    data.reraise()
  File "/home/sabrina/mambaforge/envs/prob/lib/python3.10/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/sabrina/mambaforge/envs/prob/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/sabrina/mambaforge/envs/prob/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/sabrina/mambaforge/envs/prob/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/sabrina/code/PROB/datasets/torchvision_datasets/open_world.py", line 328, in __getitem__
    img, target = self.transforms[-1](img, target)
  File "/home/sabrina/code/PROB/datasets/transforms.py", line 275, in __call__
    image, target = t(image, target)
  File "/home/sabrina/code/PROB/datasets/transforms.py", line 233, in __call__
    return self.transforms2(img, target)
  File "/home/sabrina/code/PROB/datasets/transforms.py", line 275, in __call__
    image, target = t(image, target)
  File "/home/sabrina/code/PROB/datasets/transforms.py", line 207, in __call__
    return resize(img, target, size, self.max_size)
  File "/home/sabrina/code/PROB/datasets/transforms.py", line 125, in resize
    scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
RuntimeError: The size of tensor a (0) must match the size of tensor b (4) at non-singleton dimension 0
orrzohar commented 1 year ago

Hi @sf-pear,

Are you saying that this does not happen when the flag is set?

The error here is common with data-loading issues (specifically, when an image does not have any GT detections in it). I don't see any reason it would have anything to do with whether or not you update the prob_obj head.

Looking at it here, you are training on a different dataset in t2_ft (as expected), so the error may be due to that rather than to the freeze_prob_model flag.

Best, Orr

sf-pear commented 1 year ago

No, the freeze prob flag works. Task 2 training runs with --freeze_prob_model:

# train task 2
PY_ARGS=${@:1}
python -u main_open_world.py \
    --output_dir "${EXP_DIR}/t2" --dataset fathomnet --PREV_INTRODUCED_CLS 10 --CUR_INTRODUCED_CLS 2\
    --train_set 'task2_train' --test_set 'all_test' --epochs 51\
    --model_type 'prob' --obj_loss_coef 8e-4 --obj_temp 1.3 --freeze_prob_model\
    --wandb_name "${WANDB_NAME}_t2"\
    --exemplar_replay_selection --exemplar_replay_max_length 1743 --exemplar_replay_dir ${WANDB_NAME}\
    --exemplar_replay_prev_file "task1_train_ft.txt" --exemplar_replay_cur_file "task2_train_ft.txt"\
    --pretrain "${EXP_DIR}/t1/checkpoint0040.pth" --lr 2e-5\
    ${PY_ARGS}

but fine-tuning won't work:

# fine tune task 2
PY_ARGS=${@:1}
python -u main_open_world.py \
    --output_dir "${EXP_DIR}/t2_ft" --dataset fathomnet --PREV_INTRODUCED_CLS 10 --CUR_INTRODUCED_CLS 2 \
    --train_set "${WANDB_NAME}/task2_train_ft" --test_set 'all_test' --epochs 111 --lr_drop 40\
    --model_type 'prob' --obj_loss_coef 8e-4 --obj_temp 1.3\
    --wandb_name "${WANDB_NAME}_t2_ft"\
    --pretrain "${EXP_DIR}/t2/checkpoint0050.pth"\
    ${PY_ARGS}

While trying to figure out why the error was happening, I set the --freeze_prob_model flag for fine-tuning task 2 as well. With the flag set, fine-tuning runs, but without it, it won't. Does that make sense?

I tried to find the issue in transforms.py, but I'm not sure what is triggering it. I also double-checked my dataset, and I have no images with fewer than one annotation.

orrzohar commented 1 year ago

Hi @sf-pear,

I really doubt that this has to do with the freeze_prob_model flag.

The "transforms" is part of the data loader, so this happens during dataloading and is unrelated to model inference.

I have had issues like this when some images have no ground-truth detections in them, and this may be the case here. If you look at the following line:

    scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])

then boxes is just [] (an empty tensor), while torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) still has 4 elements, so the elementwise multiplication fails with the size mismatch you see.
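
For reference, this failure mode is easy to reproduce in isolation (a minimal sketch, not PROB code; the ratio values are made up):

import torch

# Minimal repro of the mismatch described above: an image that ends up with no
# boxes yields an empty 1-D tensor, which cannot broadcast against the
# 4-element ratio tensor.
boxes = torch.tensor([])                        # what the target contains here
ratios = torch.as_tensor([1.5, 1.5, 1.5, 1.5])  # [ratio_width, ratio_height, ratio_width, ratio_height]
scaled_boxes = boxes * ratios
# RuntimeError: The size of tensor a (0) must match the size of tensor b (4)
# at non-singleton dimension 0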

If you want to make sure this is the case, then:

  1. Set the number of workers to 0.
  2. Set a breakpoint (import ipdb; ipdb.set_trace()) by wrapping the failing line in a try/except:

try:
    scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
except RuntimeError:
    import ipdb; ipdb.set_trace()

This will allow you to see the variables when this fails. But I am fairly confident that what I said is what is happening.

Now, the root cause of this issue is usually the data not being set up properly: you effectively have some images in t2_ft that have no T1 or T2 objects. To troubleshoot, you first need to find which image causes this (using ipdb, this should be easy) and then see whether it really doesn't have any T1 or T2 objects. Do the associated labels (the .xml file) look OK?
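
If it helps, here is a rough way to scan a split for images whose annotations contain none of the currently known classes (a hypothetical helper; the paths, file layout, and function name are assumptions about your setup, not PROB code):

import os
import xml.etree.ElementTree as ET

def images_without_known_objects(split_file, ann_dir, known_classes):
    """Return the image ids in split_file whose .xml has no known-class object."""
    with open(split_file) as f:
        image_ids = [line.strip() for line in f if line.strip()]
    flagged = []
    for image_id in image_ids:
        tree = ET.parse(os.path.join(ann_dir, image_id + '.xml'))
        names = {obj.findtext('name') for obj in tree.findall('object')}
        if not names & known_classes:  # no overlap with the known classes
            flagged.append(image_id)
    return flagged

# Example usage (adjust paths and class names to your dataset):
# known = {'Urchin', 'Sea star'}  # plus the rest of your T1 + T2 class names
# print(images_without_known_objects('data/OWOD/ImageSets/task2_train_ft.txt',
#                                    'data/OWOD/Annotations', known))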

Let me know how that goes, Best, Orr

sf-pear commented 1 year ago

Thanks Orr!

When debugging, I can see that one of the image_ids it fails on has 7 annotations, and both of its classes belong to task 1.

The target:

{'image_id': tensor([16531]), 'org_image_id': tensor([49., 54., 53., 51., 49.]), 'labels': tensor([], dtype=torch.int64), 'area': tensor([]), 'boxes': tensor([]), 'orig_size': tensor([1080, 1920]), 'size': tensor([1080, 1920]), 'iscrowd': tensor([], dtype=torch.uint8)}

Annotation file for 16531:

<annotation>
    <folder />
    <filename>16531.png</filename>
    <path>/home/sabrina/code/PROB/data/OWOD/ImageSets/16531.png</path>
    <source>
        <database />
    </source>
    <size>
        <width>1920</width>
        <height>1080</height>
        <depth />
    </size>
    <segmented />
    <object>
        <name>Urchin</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>847</xmin>
            <ymin>514</ymin>
            <xmax>889</xmax>
            <ymax>555</ymax>
        </bndbox>
    </object>
    <object>
        <name>Urchin</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>342</xmin>
            <ymin>763</ymin>
            <xmax>396</xmax>
            <ymax>801</ymax>
        </bndbox>
    </object>
    <object>
        <name>Sea star</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>1266</xmin>
            <ymin>348</ymin>
            <xmax>1305</xmax>
            <ymax>383</ymax>
        </bndbox>
    </object>
    <object>
        <name>Urchin</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>860</xmin>
            <ymin>464</ymin>
            <xmax>901</xmax>
            <ymax>505</ymax>
        </bndbox>
    </object>
    <object>
        <name>Urchin</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>781</xmin>
            <ymin>644</ymin>
            <xmax>824</xmax>
            <ymax>690</ymax>
        </bndbox>
    </object>
    <object>
        <name>Urchin</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>1464</xmin>
            <ymin>701</ymin>
            <xmax>1513</xmax>
            <ymax>741</ymax>
        </bndbox>
    </object>
    <object>
        <name>Sea star</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>1038</xmin>
            <ymin>299</ymin>
            <xmax>1076</xmax>
            <ymax>328</ymax>
        </bndbox>
    </object>
</annotation>

To me it looks like it should work. Do you see anything wrong with the annotation file? Not sure where to go from here.

Here are the classes for reference.

import itertools

UNK_CLASS = ["unknown"]

VOC_COCO_CLASS_NAMES = {}

T1_CLASS_NAMES = [
    'Urchin', 'Fish', 'Sea star', 'Anemone', 'Sea cucumber', 
    'Sea pen', 'Sea fan', 'Worm', 'Crab', 'Gastropod'
]

T2_CLASS_NAMES = [
    'Shrimp', 'Soft coral'
]

T3_CLASS_NAMES = [
    'Glass sponge', 'Feather star'
]

T4_CLASS_NAMES = [
    'Eel', 'Squat lobster', 'Barnacle', 'Stony coral', 'Black coral', 'Sea spider'
]

VOC_COCO_CLASS_NAMES["fathomnet"] = tuple(itertools.chain(T1_CLASS_NAMES, T2_CLASS_NAMES, T3_CLASS_NAMES, T4_CLASS_NAMES, UNK_CLASS))
orrzohar commented 1 year ago

Hi @sf-pear,

I believe I see the issue! This is happening on this image because it has only task 1 objects -- and we are really lucky it did!

You cannot have "train" in the name of the "ft" file.

This comes from how the dataloader parses the different files:

https://github.com/orrzohar/PROB/blob/10b6518f90495e07b7baf0d1bfa353f0e583eb8e/datasets/torchvision_datasets/open_world.py#L285-L290

As you can see, if you have "train" in the name of the image_set, you go down the logic of a training (and not fine-tuning) dataset -- and remove both previously known and unknown objects. But in fine-tuning, you should have access to ALL the objects. Since this image only has T1 objects, they were all removed, which is why you got the error.

To fix this, you could either edit the code above to be something more robust, or just change the file to match the style I had (e.g., t2_ft or anything that does not have "train" or "test" in the file name).
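
For intuition, here is a minimal sketch of the kind of split-name dispatch described above (variable names, signatures, and the exact conditions are illustrative paraphrases of the linked code, not a copy of open_world.py):

def filter_labels(labels, image_set, prev_intro_cls, cur_intro_cls):
    """Illustrative only: which label indices stay visible for a given split name."""
    num_known = prev_intro_cls + cur_intro_cls
    if 'train' in image_set:
        # training split: drop previously known and unknown objects, keeping
        # only the classes introduced in the current task
        return [lab for lab in labels if prev_intro_cls <= lab < num_known]
    if 'test' in image_set:
        # evaluation split: handled separately in the real code; left as-is here
        return labels
    # anything else (e.g. a fine-tuning split): keep all currently known classes
    return [lab for lab in labels if lab < num_known]

# With PREV_INTRODUCED_CLS=10 and CUR_INTRODUCED_CLS=2, an image annotated only
# with T1 classes (e.g. Urchin=0, Sea star=2) loses every box on a "*train*" split:
print(filter_labels([0, 2, 0, 0], 'task2_train_ft', 10, 2))  # -> []
print(filter_labels([0, 2, 0, 0], 't2_ft', 10, 2))           # -> [0, 2, 0, 0]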

Best, Orr

sf-pear commented 1 year ago

Brilliant! Just changed the filename and it is running. Thank you so much!