huytuong010101 opened 2 years ago
Hi,
the full CAT-Net takes a 24-channel input (3 for RGB and 21 for DCT), while the DCT stream alone takes a 21-channel input.
Because of this input difference, you should uncomment this line: https://github.com/mjkwon2021/CAT-Net/blob/26aa8b06c9a0fa1fe427dbd77b8d327c92b87cba/tools/infer.py#L76
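For context, a minimal sketch of the input-channel difference described above (tensor names and sizes are illustrative, not CAT-Net's actual variables):

```python
import torch

rgb = torch.zeros(1, 3, 512, 512)   # 3-channel RGB stream
dct = torch.zeros(1, 21, 512, 512)  # 21-channel DCT volume stream

# Full CAT-Net concatenates both streams: 3 + 21 = 24 channels.
full_input = torch.cat([rgb, dct], dim=1)
# The DCT-only model takes just the 21-channel DCT volume.
dct_only_input = dct

print(full_input.shape[1], dct_only_input.shape[1])  # 24 21
```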
@CauchyComplete It works for me. Thank you so much <3
Glad it helped ☺️
I get this error. Please ignore the line numbers.
```
<Splicing.data.dataset_arbitrary.arbitrary object at 0x7fa645dcb160>(116) crop_size=None, grid_crop=True, blocks=('RGB', 'DCTvol', 'qtable'), mode=arbitrary, read_from_jpeg=True, class_weight=tensor([1., 1.])
=> Cannot load pretrained RGB
=> Cannot load pretrained DCT
=> loading model from ./output/splicing_dataset/CAT_full/DCT_only_v2.pth.tar
Traceback (most recent call last):
  File "tools/infer.py", line 240, in
......
size mismatch for last_layer.0.weight: copying a param with shape torch.Size([672, 672, 1, 1]) from checkpoint, the shape in current model is torch.Size([360, 360, 1, 1]).
size mismatch for last_layer.0.bias: copying a param with shape torch.Size([672]) from checkpoint, the shape in current model is torch.Size([360]).
size mismatch for last_layer.1.weight: copying a param with shape torch.Size([672]) from checkpoint, the shape in current model is torch.Size([360]).
size mismatch for last_layer.1.bias: copying a param with shape torch.Size([672]) from checkpoint, the shape in current model is torch.Size([360]).
size mismatch for last_layer.1.running_mean: copying a param with shape torch.Size([672]) from checkpoint, the shape in current model is torch.Size([360]).
size mismatch for last_layer.1.running_var: copying a param with shape torch.Size([672]) from checkpoint, the shape in current model is torch.Size([360]).
size mismatch for last_layer.3.weight: copying a param with shape torch.Size([2, 672, 1, 1]) from checkpoint, the shape in current model is torch.Size([2, 360, 1, 1]).
```
@Vadim2S
It's hard to debug your personalized code, especially when you haven't specified your changes.
As far as I can tell from the message `=> Cannot load pretrained RGB` and the setting `blocks=('RGB', 'DCTvol', 'qtable')`, you are trying to use the full model. But the checkpoint you are loading, `DCT_only_v2.pth.tar`, should be used only when you want the DCT stream alone.
Also, next time you ask a question, please be specific and try not to just copy-paste your error message.
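To illustrate the diagnosis above, here is a minimal sketch (not CAT-Net's actual code; the layer is a stand-in) of comparing a checkpoint's parameter shapes against the current model before loading, which surfaces mismatches like the ones in the traceback:

```python
import torch
import torch.nn as nn

# Stand-ins: a 360-channel head like the DCT-only model's, versus a
# 672-channel checkpoint like the full model's. Real layer names differ.
model = nn.Sequential(nn.Conv2d(360, 360, kernel_size=1))
checkpoint = {
    "0.weight": torch.zeros(672, 672, 1, 1),
    "0.bias": torch.zeros(672),
}

model_state = model.state_dict()
mismatches = [
    name for name, param in checkpoint.items()
    if name in model_state and model_state[name].shape != param.shape
]
for name in mismatches:
    print(f"size mismatch for {name}: checkpoint {tuple(checkpoint[name].shape)} "
          f"vs model {tuple(model_state[name].shape)}")
```

If this loop prints anything, the checkpoint was trained with a different architecture (here, full model vs. DCT-only), and the fix is to match the config to the checkpoint rather than force the load.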
Hi, thank you for this perfect project <3 I ran
`python tools/infer.py`
with these arguments: `args = argparse.Namespace(cfg='experiments/CAT_DCT_only.yaml', opts=['TEST.MODEL_FILE', 'output/splicing_dataset/CAT_DCT_only/DCT_only_v2.pth.tar', 'TEST.FLIP_TEST', 'False', 'TEST.NUM_SAMPLES', '0'])`
to use CAT_DCT_only, but I got this error. Can you help me fix it? Thank you