naver-ai / egtr

[CVPR 2024 Best paper award candidate] EGTR: Extracting Graph from Transformer for Scene Graph Generation
https://arxiv.org/abs/2404.02072
Apache License 2.0

about pretrained #6

Closed xfufu0724 closed 1 month ago

xfufu0724 commented 1 month ago

Hello, I encountered some issues when loading the checkpoint. When I train the EGTR model, which .ckpt file should be used for 'pretrained' in the command line: python train_egtr.py --data_path dataset/visual_genome --output_path $OUTPUT_PATH --pretrained $PRETRAINED_PATH --memo $MEMO ?

Looking forward to your response. Thank you.

xfufu0724 commented 1 month ago

@jinbae

jinbae commented 1 month ago

We used a pre-trained object detector to train EGTR, and our pre-trained object detectors are publicly available. Please refer to the README.


Download the pre-trained object detector and set the downloaded path as PRETRAINED_PATH.
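As a minimal sketch of the step above, the training command quoted in this thread can be assembled with the downloaded detector path filled in. The helper name build_train_command and the example paths are hypothetical; only the script name and flags come from the command in the question.

```python
import shlex


def build_train_command(pretrained_path, output_path, memo,
                        data_path="dataset/visual_genome"):
    """Return the train_egtr.py invocation from this thread as an argv list."""
    return [
        "python", "train_egtr.py",
        "--data_path", data_path,
        "--output_path", output_path,
        "--pretrained", pretrained_path,  # path of the downloaded detector
        "--memo", memo,
    ]


if __name__ == "__main__":
    # All three values below are hypothetical placeholders.
    cmd = build_train_command(
        pretrained_path="path/to/downloaded_detector",
        output_path="path/to/output",
        memo="my_run",
    )
    # Print a shell-quoted command line that can be copied into a terminal.
    print(shlex.join(cmd))
```

Building the command as a list keeps paths with spaces intact if it is later passed to subprocess.run.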

xfufu0724 commented 1 month ago

I have downloaded the pre-trained object detector and extracted it into the code directory.

The 'pretrained' parameter seems to require loading a JSON file, so I tried to load the "epoch=03-validation_loss=1.71test26446V100-PCIE-32GB.json" file. But when I load the JSON file, the following error occurs: OSError: Unable to load weights from pytorch checkpoint file for './pretrained_detrSenseTimedeformable-detr/batch32epochs150_50lr1e-05_0.0001visual_genomefinetune/version_0/config.json' at './pretrained_detrSenseTimedeformable-detr/batch32epochs150_50lr1e-05_0.0001__visual_genomefinetune/version_0/config.json'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
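The error message above names config.json as the weights file, which may mean that a file path rather than a checkpoint directory reached the loader. The following is a minimal sketch for checking what a candidate PRETRAINED_PATH actually points to before training; inspect_pretrained_path is a hypothetical helper, not part of EGTR, and it makes no assumption about which files the training script expects.

```python
from pathlib import Path


def inspect_pretrained_path(path):
    """Report whether a candidate path is a file, a directory, or missing."""
    p = Path(path)
    if p.is_file():
        # A single file (for example a .json or .ckpt) was given.
        return {"kind": "file", "suffix": p.suffix, "entries": []}
    if p.is_dir():
        # A directory was given; list its contents so they can be checked
        # against what the README describes.
        entries = sorted(child.name for child in p.iterdir())
        return {"kind": "directory", "suffix": "", "entries": entries}
    return {"kind": "missing", "suffix": "", "entries": []}


if __name__ == "__main__":
    # Hypothetical path; replace with the value passed to --pretrained.
    print(inspect_pretrained_path("path/to/downloaded_detector"))
```

If the result is "file", trying the enclosing directory instead may be worth a test; if it is "missing", the relative path is probably being resolved from a different working directory.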

hanchuqi commented 1 week ago

@xfufu0724 I'm having the same issue, how did you fix it?

xfufu0724 commented 1 week ago

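@xfufu0724 I'm having the same issue, how did you fix it?

I don't remember the details exactly. See if you can find the code that sets from_tf; I think changing it to from_tf=True did the trick?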