LRpz closed this issue 7 months ago.
Hi, the error message you're encountering indicates a mismatch in the shapes of the model parameters. The weight tensor in the checkpoint you are trying to load has shape torch.Size([64, 1, 3, 3, 3]), whereas the model defined in your sample.py expects torch.Size([64, 3, 3, 3, 3]). Assuming this is a Conv3d layer, whose weight is laid out as (out_channels, in_channels, kD, kH, kW) when groups=1, your checkpoint was trained with 1 input channel while your inference model is built with 3. Please update your inference script 'sample.py' to use the same input channel configuration as your training script.
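To find which parameters disagree before calling `load_state_dict`, you can compare the parameter shapes from the checkpoint with those of the freshly built model. The sketch below works on plain name-to-shape mappings so it does not depend on this repository's code. It assumes a Conv3d weight layout of (out_channels, in_channels / groups, kD, kH, kW), and the parameter name `init_conv.weight` in the example is hypothetical, not taken from the codebase.

```python
from typing import List, Mapping, Sequence, Tuple


def conv3d_in_channels(weight_shape: Sequence[int], groups: int = 1) -> int:
    """Recover the in_channels a Conv3d layer was built with from its weight shape.

    Conv3d weights are laid out as (out_channels, in_channels // groups, kD, kH, kW).
    """
    if len(weight_shape) != 5:
        raise ValueError(f"expected a 5-D Conv3d weight shape, got {tuple(weight_shape)}")
    return weight_shape[1] * groups


def find_shape_mismatches(
    checkpoint_shapes: Mapping[str, Sequence[int]],
    model_shapes: Mapping[str, Sequence[int]],
) -> List[Tuple[str, Tuple[int, ...], Tuple[int, ...]]]:
    """Return (name, checkpoint_shape, model_shape) for every shared parameter whose shape differs."""
    mismatches = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        # Only compare parameters present in both; missing keys are a separate problem.
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches.append((name, tuple(ckpt_shape), tuple(model_shape)))
    return mismatches


if __name__ == "__main__":
    # Shapes taken from the error message; the key name is hypothetical.
    ckpt = {"init_conv.weight": (64, 1, 3, 3, 3)}
    model = {"init_conv.weight": (64, 3, 3, 3, 3)}
    for name, c, m in find_shape_mismatches(ckpt, model):
        print(name, "checkpoint in_channels:", conv3d_in_channels(c),
              "model in_channels:", conv3d_in_channels(m))
```

With PyTorch, the two mappings could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}` from `torch.load(path, map_location="cpu")` and from `model.state_dict()`, assuming the checkpoint file stores a plain state dict (some training scripts wrap it in an outer dictionary instead).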
Hi,
Thank you very much for your codebase; it is very clean! I was able to train your model successfully on the whole_head dataset you provided, using your 'train.py' script.
Running inference ('sample.py') with your pretrained model 'model_128.pt' works, but I get an error when I try to load a model produced by my own training.
Loading it returns: