zhangce01 / HiKER-SGG

[CVPR 2024] Code for HiKER-SGG: Hierarchical Knowledge Enhanced Robust Scene Graph Generation
https://zhangce01.github.io/HiKER-SGG/
MIT License

Get the graph #10

Closed: laopidu closed this issue 1 month ago

laopidu commented 1 month ago

Hi,

I have encountered an issue that I am unable to resolve on my own and was hoping you might be able to provide some guidance.

I want to use a pre-trained model to run SGG on images, both from the dataset and from outside it, and get a graph that represents the objects in each image and the relationships between them. Which scripts and pre-trained models can I use for this?

Thank you in advance for your time and assistance. I appreciate your help with this matter.

zhangce01 commented 1 month ago

The output scene graphs will be available in the form of a dictionary pred_entry. You can then refer to this link to map the predicted indices to actual objects and predicates.

laopidu commented 1 month ago

Thank you very much for your answer and guidance! This is very helpful to me.

I'd like to ask a follow-up question: could you please provide more details on how to actually obtain this dictionary? Are there specific steps or code I should use to generate and access the pred_entry dictionary and its fields, such as pred_classes and pred_rel_inds?

I appreciate your continued assistance with this.

zhangce01 commented 1 month ago

We generate the pred_entry dictionary in val_batch, which contains all the inference results for a batch of images. If you can run the code for training in hikersgg_predcls_train.ipynb, you may simply print out this dictionary to see what it contains. For deploying the model on your own data, you need to first pre-train a model on the VG dataset using our code (e.g., hikersgg_predcls_train.ipynb), then re-implement the dataloader to load your own images and use the code in hikersgg_predcls_test.ipynb to perform inference.

laopidu commented 1 month ago

Thanks for the guidance!

After running hikersgg_predcls_train, I found that the matching accuracy of the entries written to pred_entries.jsonl is not very high. Have the hyperparameters in the code (such as the number of epochs) been changed from the ones you used?

This is the code I use to write out pred_entry:

    # Open a file to write the pred_entry data
    with open('pred_entries.jsonl', 'a') as f:
        with autocast():
            det_res = detector[b]
        if conf.num_gpus == 1:
            det_res = [det_res]

        for i, (boxes_i, objs_i, obj_scores_i, rels_i, pred_scores_i) in enumerate(det_res):
            gt_entry = {
                'gt_classes': val.gt_classes[batch_num + i].copy(),
                'gt_relations': val.relationships[batch_num + i].copy(),
                'gt_boxes': val.gt_boxes[batch_num + i].copy(),
            }
            assert np.all(objs_i[rels_i[:, 0]] > 0) and np.all(objs_i[rels_i[:, 1]] > 0)

            pred_entry = {
                'pred_boxes': boxes_i.tolist(),  # Convert numpy array to list
                'pred_classes': objs_i.tolist(),
                'pred_rel_inds': rels_i.tolist(),
                'obj_scores': obj_scores_i.tolist(),
                'rel_scores': pred_scores_i.tolist(),
            }

            # Write the pred_entry to the file
            json.dump(pred_entry, f)
            f.write('\n')  # Add a newline for each entry

            eval_entry(conf.mode, gt_entry, pred_entry, evaluator, evaluator_multiple_preds,
                       evaluator_list, evaluator_multiple_preds_list)

and this is part of the printed output:

{'pred_boxes': [[0.0, 0.0, 591.421875, 412.203125], [56.078125, 226.625, 590.84375, 441.6875], [0.0, 0.0, 164.765625, 398.328125], [323.75, 207.546875, 589.6875, 398.90625], [310.453125, 45.09375, 367.6875, 313.921875], [444.578125, 0.0, 591.421875, 109.265625], [273.453125, 194.25, 330.109375, 379.25], [175.171875, 182.6875, 219.109375, 376.359375], [90.1875, 8.671875, 148.0, 141.640625], [530.71875, 252.0625, 590.265625, 373.46875], [269.984375, 217.375, 330.6875, 291.375], [352.65625, 235.296875, 410.46875, 305.828125], [286.171875, 272.296875, 322.015625, 367.109375], [177.484375, 211.59375, 217.953125, 287.90625], [179.796875, 283.28125, 212.75, 371.15625], [286.171875, 357.859375, 322.015625, 378.671875], [272.296875, 209.28125, 294.84375, 239.921875], [248.015625, 234.71875, 268.25, 261.890625], [235.875, 228.359375, 256.109375, 259.0], [330.6875, 233.5625, 363.0625, 250.90625], [56.078125, 241.65625, 584.484375, 435.90625], [178.0625, 360.75, 216.796875, 380.40625], [236.453125, 234.71875, 266.515625, 260.734375]], 'pred_classes': [136, 114, 22, 124, 30, 145, 78, 78, 115, 26, 111, 26, 87, 111, 87, 112, 3, 11, 11, 54, 124, 120, 11], 'pred_rel_inds': [[7, 1], [6, 1], [7, 20], [6, 20], [9, 3], [11, 3], [6, 16], [9, 20], [0, 20], [6, 15], [11, 20], [6, 12], [7, 21], [6, 10], [0, 1], [7, 14], [18, 20], [7, 13], [17, 20], [6, 3], [22, 20], [18, 1], [2, 20], [22, 1], [16, 6], [17, 1], [0, 3], [7, 3], [2, 3], [8, 2], [2, 1], [9, 1], [0, 2], [11, 1], [7, 18], [2, 0], [7, 22], [6, 17], [0, 7], [4, 19], [1, 20], [4, 1], [4, 20], [6, 22], [7, 17], [19, 4], [2, 7], [7, 0], [0, 6], [6, 18], [4, 3], [6, 19], [4, 6], [1, 3], [1, 2], [17, 3], [7, 6], [8, 20], [5, 20], [4, 0], [6, 0], [3, 20], [18, 3], [1, 0], [5, 3], [2, 5], [11, 4], [20, 0], [6, 7], [3, 1], [22, 3], [2, 6], [20, 2], [2, 8], [17, 16], [8, 0], [3, 11], [3, 0], [20, 1], [5, 1], [0, 11], [22, 16], [20, 9], [8, 1], [8, 7], [22, 7], [22, 0], [0, 9], [3, 9], [17, 0], [20, 7], [18, 7], [4, 2], [18, 0], [16, 7], 
[4, 11], [7, 2], [18, 17], [20, 4], [0, 4], [20, 6], [17, 2], [17, 18], [4, 7], [15, 6], [17, 7], [20, 3], [14, 13], [0, 5], [11, 2], [11, 0], [20, 18], [7, 15], [20, 11], [5, 0], [5, 2], [22, 2], [6, 4], [20, 22], [1, 7], [22, 18], [18, 2], [9, 0], [21, 7], [20, 17], [3, 2], [8, 3], [13, 14], [7, 16], [0, 8], [11, 6], [14, 7], [18, 16], [1, 22], [1, 9], [13, 7], [18, 22], [17, 22], [1, 18], [10, 12], [2, 4], [1, 6], [12, 15], [1, 17], [17, 6], [0, 19], [14, 21], [2, 18], [1, 4], [12, 10], [2, 11], [10, 16], [22, 17], [3, 4], [12, 6], [19, 16], [19, 20], [0, 18], [2, 22], [9, 7], [0, 22], [22, 6], [1, 11], [16, 10], [19, 0], [3, 7], [8, 4], [2, 17], [7, 4], [0, 17], [19, 3], [7, 8], [8, 6], [3, 6], [8, 5], [18, 6], [10, 6], [9, 2], [2, 9], [19, 1], [3, 19], [1, 5], [11, 7], [17, 11], [6, 11], [13, 21], [5, 7], [6, 2], [4, 9], [4, 5], [11, 19], [11, 17], [6, 21], [22, 11], [10, 19], [7, 9], [11, 22], [16, 20], [0, 16], [18, 11], [7, 12], [17, 4], [5, 6], [19, 11], [22, 4], [11, 18], [6, 9], [1, 19], [16, 22], [9, 5], [9, 22], [9, 6], [11, 5], [16, 1], [9, 18], [7, 10], [3, 5], [5, 9], [18, 4], [4, 22], [19, 6], [4, 17], [5, 4], [15, 12], [10, 15], [16, 0], [3, 17], [20, 5], [9, 17], [5, 11], [8, 18], [4, 18], [21, 15], [15, 21], [3, 18], [3, 22], [16, 17], [6, 13], [8, 22], [1, 8], [8, 9], [15, 20], [7, 19], [16, 18], [17, 19], [14, 15], [11, 9], [8, 17], [9, 11], [22, 19], [6, 14], [17, 9], [14, 12], [4, 8], [7, 11], [18, 8], [8, 11], [18, 9], [13, 10], [6, 5], [22, 9], [22, 8], [20, 19], [1, 16], [5, 8], [7, 5], [17, 8], [6, 8], [12, 19], [9, 4], [12, 21], [19, 17], [12, 14], [4, 16], [20, 8], [21, 14], [11, 16], [19, 10], [21, 13], [3, 8], [18, 19], [19, 22], [8, 16], [15, 1], [19, 7], [13, 19], [20, 15], [16, 19], [0, 15], [15, 10], [15, 7], [21, 20], [10, 13], [17, 5], [16, 12], [16, 3], [19, 18], [1, 15], [20, 16], [16, 4], [22, 5], [10, 20], [16, 15], [19, 15], [18, 5], [13, 12], [10, 0], [14, 19], [5, 19], [14, 10], [15, 19], [11, 8], [5, 16], [5, 22], [12, 
20], [10, 14], [5, 18], [5, 17], [15, 0], [14, 20], [10, 17], [13, 0], [10, 21], [16, 11], [20, 21], [10, 7], [13, 15], [12, 0], [9, 19], [8, 19], [12, 16], [21, 6], [21, 1], [21, 19], [0, 10], [21, 12], [0, 21], [15, 14], [4, 15], [13, 6], [16, 8], [10, 1], [13, 20], [2, 19], [12, 1], [1, 21], [14, 0], [19, 12], [14, 6], [12, 7], [19, 5], [10, 4], [9, 8], [21, 0], [19, 2], [14, 1], [12, 13], [0, 13], [15, 16], [10, 22], [15, 3], [9, 16], [19, 9], [13, 18], [20, 12], [20, 10], [13, 1], [14, 18], [19, 21], [2, 16], [16, 21], [0, 12], [19, 8], [16, 9], [4, 12], [10, 18], [2, 21], [14, 22], [12, 4], [12, 17], [1, 12], [21, 10], [13, 22], [4, 10], [20, 14], [14, 17], [16, 13], [21, 18], [16, 5], [3, 16], [12, 22], [15, 22], [0, 14], [13, 16], [15, 4], [15, 17], [13, 17], [15, 18], [15, 13], [12, 18], [20, 13], [1, 10], [21, 22], [3, 15], [19, 13], [12, 3], [1, 14], [16, 14], [2, 15], [10, 3], [21, 17], [8, 21], [16, 2], [8, 13], [21, 3], [17, 10], [5, 15], [1, 13], [8, 15], [21, 8], [22, 10], [21, 16], [14, 16], [13, 8], [4, 21], [14, 3], [21, 4], [19, 14], [18, 21], [18, 10], [22, 15], [13, 3], [17, 15], [10, 11], [3, 21], [11, 15], [3, 12], [3, 10], [2, 13], [18, 13], [22, 13], [18, 15], [2, 14], [17, 21], [15, 11], [22, 21], [14, 8], [5, 21], [8, 14], [15, 5], [9, 15], [11, 10], [13, 4], [8, 10], [12, 11], [10, 5], [17, 13], [9, 21], [21, 2], [5, 10], [13, 2], [2, 10], [15, 9], [15, 8], [17, 12], [14, 4], [10, 8], [15, 2], [4, 14], [21, 9], [4, 13], [22, 12], [2, 12], [10, 2], [5, 12], [21, 5], [21, 11], [14, 2], [3, 13], [11, 21], [10, 9], [18, 12], [3, 14], [8, 12], [13, 5], [12, 5], [14, 11], [12, 8], [18, 14], [13, 11], [12, 2], [5, 13], [12, 9], [14, 9], [22, 14], [17, 14], [13, 9], [11, 12], [9, 10], [5, 14], [14, 5], [9, 13], [9, 14], [9, 12], [11, 13], [11, 14]], 'obj_scores': [1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0, 
1000.0], 'rel_scores': [[0.5414498448371887, 0.00021737918723374605, 0.0003916021669283509, 7.875756273278967e-05, 0.0005632897955365479, 1.0285631105944049e-05, 0.012644953094422817, 3.288299558334984e-05, 1.7406885035597952e-06, 9.38242905590414e-08, 8.839080692268908e-06, 1.8517586795496754e-05, 4.558180535241263e-06, 2.2304925550997723e-06, 3.811560964095406e-05...
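A record like the one above can be read back from the file one line at a time. The following is a minimal sketch, assuming only what the snippet above shows: one JSON object per line, with the keys pred_boxes, pred_classes, pred_rel_inds, obj_scores and rel_scores. The helper name load_pred_entries and the stand-in values are hypothetical and are not part of the HiKER-SGG code.

```python
import json
import os
import tempfile


def load_pred_entries(path):
    """Read one pred_entry dict per non-empty line of a JSON Lines file."""
    entries = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                entries.append(json.loads(line))
    return entries


# Stand-in record with the same keys as the dump above (values are made up).
sample = {
    "pred_boxes": [[0.0, 0.0, 10.0, 10.0], [2.0, 2.0, 8.0, 8.0]],
    "pred_classes": [78, 114],
    "pred_rel_inds": [[0, 1], [1, 0]],
    "obj_scores": [1000.0, 1000.0],
    "rel_scores": [[0.1, 0.9], [0.8, 0.2]],
}

# Write it the same way as the snippet above (append mode, one entry per line).
path = os.path.join(tempfile.mkdtemp(), "pred_entries.jsonl")
with open(path, "a") as f:
    json.dump(sample, f)
    f.write("\n")

entries = load_pred_entries(path)
first = entries[0]
# Number of objects, number of object pairs, number of score rows.
print(len(first["pred_classes"]), len(first["pred_rel_inds"]), len(first["rel_scores"]))
```

Because the file is opened in append mode, entries from repeated runs accumulate in the same file unless it is deleted between runs.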

It shows 23 boxes and 23 objects, but a very large number of pred_rel_inds. Are they all predicted relations among these 23 objects? How do I correctly combine them into sentences without the index numbers? And what do [7, 1] and the other entries in pred_rel_inds mean?

zhangce01 commented 1 month ago

Yes, they are all the predicted relations among these 23 objects. In our setting, we predict 100 triplets for each image, but you can select the top predictions yourself. For [7, 1], look up indices 7 and 1 (0-based) in pred_classes: they are 78 (man) and 114 (sidewalk). So the triplet is man-(relation)-sidewalk, and you can find the predicted probabilities for the corresponding relation/predicate in rel_scores.
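As a rough illustration of that lookup, the sketch below resolves a pair from pred_rel_inds to subject and object names. The IND_TO_CLASS table is a hypothetical stand-in that holds only the two names mentioned in this thread; the full index-to-name mapping comes from the repository's dataset files and is not reproduced here.

```python
# Hypothetical partial lookup: only the two class names given in this thread.
IND_TO_CLASS = {78: "man", 114: "sidewalk"}

# First eight entries of pred_classes from the output above.
pred_classes = [136, 114, 22, 124, 30, 145, 78, 78]


def pair_to_names(pair, pred_classes, ind_to_class):
    """Map a [subject_index, object_index] pair to (subject_name, object_name)."""
    subj_idx, obj_idx = pair
    subj_cls = pred_classes[subj_idx]  # indices are 0-based
    obj_cls = pred_classes[obj_idx]
    # Fall back to a placeholder when the class name is not in the table.
    subj = ind_to_class.get(subj_cls, f"class_{subj_cls}")
    obj = ind_to_class.get(obj_cls, f"class_{obj_cls}")
    return subj, obj


print(pair_to_names([7, 1], pred_classes, IND_TO_CLASS))  # ('man', 'sidewalk')
print(pair_to_names([6, 1], pred_classes, IND_TO_CLASS))  # ('man', 'sidewalk')
```

Both [7, 1] and [6, 1] resolve to man and sidewalk because pred_classes holds 78 at indices 6 and 7; the two pairs refer to different boxes of the same class.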

laopidu commented 1 month ago

[7, 1] and [6, 1] give the objects, so where can I get the corresponding relation/predicate?

zhangce01 commented 1 month ago

It should be in rel_scores, which contains the predicted probabilities for predicates.

Maybe you can try to print out the shape of rel_scores.

laopidu commented 1 month ago

rel_scores has a shape of [1, 506, 51], and all of the numbers are smaller than 0.001. What did I do wrong?

zhangce01 commented 1 month ago

I assume you also have 506 object pairs in pred_rel_inds (23 x 22, i.e., every ordered pair of distinct objects). For each pair, you will obtain 51 predicted probabilities, one per predicate (50 relations and a special no-relation predicate). You can use argmax to identify the highest-probability predicate for each object pair. Then, you can sort the pairs by these probabilities (excluding the first entry, the no-relation predicate) to get the top predictions.
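The argmax-and-sort step can be sketched with NumPy as follows. This is a minimal sketch, not the repository's evaluation code: it assumes column 0 of rel_scores is the no-relation predicate (as described above) and that a leading batch dimension such as [1, 506, 51] can simply be flattened. The function name top_triplets and the toy scores (3 pairs, 4 predicate columns) are made up for illustration.

```python
import numpy as np


def top_triplets(pred_rel_inds, rel_scores, k=5):
    """Return the k highest-scoring (subject_idx, predicate_idx, object_idx, score) tuples."""
    rel_inds = np.asarray(pred_rel_inds)
    scores = np.asarray(rel_scores, dtype=float)
    # Collapse any leading batch dimension, e.g. (1, 506, 51) -> (506, 51).
    scores = scores.reshape(-1, scores.shape[-1])
    # Skip column 0, assumed here to be the no-relation predicate.
    fg = scores[:, 1:]
    best_pred = fg.argmax(axis=1) + 1  # +1 restores the original column index
    best_score = fg.max(axis=1)
    # Sort pairs by their best foreground score, highest first.
    order = np.argsort(-best_score)[:k]
    return [
        (int(rel_inds[i, 0]), int(best_pred[i]), int(rel_inds[i, 1]), float(best_score[i]))
        for i in order
    ]


# Toy input: 3 object pairs, 4 predicate columns (column 0 = no-relation).
toy_inds = [[7, 1], [6, 1], [7, 20]]
toy_scores = [
    [0.54, 0.01, 0.40, 0.05],
    [0.10, 0.70, 0.15, 0.05],
    [0.90, 0.02, 0.03, 0.05],
]

top = top_triplets(toy_inds, toy_scores, k=2)
print(top)  # [(6, 1, 1, 0.7), (7, 2, 1, 0.4)]
```

The predicate index in each tuple still has to be mapped to a predicate name with the same kind of index-to-name table used for the object classes.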

laopidu commented 1 month ago

Do these data reflect the model's understanding and prediction of the current image being processed, rather than pointing to a specific image file? If not, where is the path to the image?

zhangce01 commented 1 month ago

We currently use a dataloader to load images from the Visual Genome dataset, so this output corresponds to a specific image in the VG dataset. If you want to test our model on your own dataset, please rewrite the dataloader, where you should use your image paths to load images. For testing with a few examples, simply replace b (which currently represents a batch of image inputs) in det_res = detector[b] with your image.

laopidu commented 1 month ago

After changing the code several times, I still can't run hikersgg_predcls_test. I changed the directory for b to point to the first thirty images of the VG dataset, so that I could get the pred_entry of each image. When I run hikersgg_predcls_test as is, without modification, it evaluates 26446 images.

Here is the result

======================predcls  recall with constraint============================
R@20: 0.010330
R@50: 0.019812
R@100: 0.027930
======================predcls  recall without constraint============================
R@20: 0.010531
R@50: 0.026971
R@100: 0.051066

======================predcls  mean recall with constraint============================
mR@20:  0.003870488730891438
mR@50:  0.012130850294951583
mR@100:  0.016847753934387925

======================predcls  mean recall without constraint============================
mR@20:  0.004409096666683437
mR@50:  0.014361383805287625
mR@100:  0.026735767829460525

zhangce01 commented 1 month ago

I don't fully understand your question. It seems that you have already successfully evaluated the method on the full VG test set, but the results are not satisfactory. Please make sure that you are loading the weights of a model trained with hikersgg_predcls_train.

Also, could you please contact me directly by email at cezhang@cs.cmu.edu about your issues?

laopidu commented 1 month ago

Thank you for your continued help. I sent you an email earlier, but for some reason you may not have received it. I have sent a new email about my issues and hope to get your reply! My email is laopidu@163.com.