Dylan-H-Wang / SLF-RPM

Official PyTorch implementation of AAAI-22: Self-supervised Representation Learning Framework for Remote Physiological Measurement using Spatiotemporal Augmentation Loss (SLF-RPM)
https://arxiv.org/abs/2107.07695
Apache License 2.0

Questions about train.py and test.py #13

Open BugMaker2002 opened 8 months ago

BugMaker2002 commented 8 months ago

Hello, can you tell me what this code in the test.py file does? And why are the model architectures used in train.py and test.py different? In train.py you use a network with ndim=2048, but in test.py you seem to switch to a network with n_class=1, so loading the pre-trained model directly raises an error because the parameter dimensions do not match. I would have expected the usual approach: use the same network architecture for the training and validation sets, have the validation model load the weights trained on the training set directly, and then train an MLP head with nn.Linear(2048, 1). Why did you not do that?

Dylan-H-Wang commented 8 months ago

Since the model was trained using DDP, its parameters are saved under a `module.` prefix. We need to remove this prefix before loading the pre-trained model weights.
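The key filtering described above can be sketched in plain Python (key names follow the MoCo-style convention this repo uses; the dict values here are placeholders, not real tensors):

```python
# Minimal sketch of the filtering/renaming applied to a DDP checkpoint
# before loading: keep only the query-encoder backbone, drop its
# projection MLP, and strip the "module." prefix added by DDP.
def strip_ddp_prefix(state_dict):
    cleaned = {}
    for k, v in state_dict.items():
        # Keep encoder_q weights, excluding its projection head (fc).
        if k.startswith("module.encoder_q") and not k.startswith("module.encoder_q.fc"):
            cleaned[k[len("module."):]] = v
    return cleaned

ckpt = {
    "module.encoder_q.conv1.weight": "w1",
    "module.encoder_q.fc.0.weight": "w2",   # projection head: discarded
    "module.queue": "q",                    # SSL bookkeeping: discarded
}
print(strip_ddp_prefix(ckpt))
# → {'encoder_q.conv1.weight': 'w1'}
```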

The difference between the train and test models is common practice in SSL (see SimCLR or MoCo for details). Their backbones are the same; only the subsequent MLP part differs. That is why we discard encoder_q.fc when loading the pre-trained weights. For fine-tuning on downstream tasks, we discard the MLP after the backbone, append a new MLP, and train it on the dataset.
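A hedged sketch of the fine-tuning step: the pre-trained backbone is kept and a fresh head is attached (the 2048-d feature size and single heart-rate output follow the discussion above; the backbone here is a stand-in, not the repo's real encoder):

```python
import torch
import torch.nn as nn

# Stand-in for the pre-trained encoder; in practice its weights would
# be loaded from the SSL checkpoint (with the projection MLP discarded).
backbone = nn.Sequential(
    nn.Flatten(),
    nn.Linear(8, 2048),
)

# Fine-tuning model: frozen-or-finetuned backbone plus a new head
# trained from scratch on the downstream regression target.
model = nn.Sequential(
    backbone,
    nn.Linear(2048, 1),
)

out = model(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 1])
```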

BugMaker2002 commented 8 months ago

But I printed out all the keys in the state_dict and found none beginning with "module", which means the first if statement, `if k.startswith("module.encoder_q") and not k.startswith("module.encoder_q.fc"):`, will never be executed. Does that mean this statement is redundant? Also, what does DDP stand for? Thank you very much for answering my questions.

Dylan-H-Wang commented 8 months ago

It is not redundant; it serves as a useful sanity check.

You can look up PyTorch DDP (DistributedDataParallel) yourself.
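For context, DDP is `torch.nn.parallel.DistributedDataParallel`. Both it and `nn.DataParallel` store the wrapped model as `.module`, which is why every state_dict key gains a `module.` prefix. A tiny illustration (DataParallel is used here only because it needs no process-group setup):

```python
import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)  # DDP wraps keys the same way

print(list(net.state_dict()))      # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']
```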

BugMaker2002 commented 8 months ago

Thank you very much