IcarusWizard / MAE

PyTorch implementation of Masked Autoencoder
MIT License

Ask some questions #19

Closed DM0815 closed 1 year ago

DM0815 commented 1 year ago

Dear Dr. Zhang, if there are no labels in my training set, can I remove the cls_token? If I do, will it influence the training process?

IcarusWizard commented 1 year ago

Hi, the cls_token represents the global information of the input image. It is called cls_token (aka class token) since it was first used to do classification.

For MAE pretraining, I don't think the cls_token is that important. Also, the MAE algorithm itself doesn't require you to have labels in your dataset. Removing it just makes your model different from most ViT models in the field.

For classifier training, having a cls_token is recommended. If you remove it, you need a different way to get a representation for the classifier head. Maybe mean-pooling the image tokens can do the job?
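For illustration, a minimal sketch of both options (the shapes and the cls-token-at-index-0 layout are assumptions; adapt to your model):

```python
import torch

# tokens: encoder output, shape (batch, num_tokens, dim); assuming the
# cls_token sits at index 0, followed by the image tokens.
tokens = torch.randn(8, 1 + 196, 192)  # dummy batch for illustration

# Option 1: use the cls_token as the global representation.
cls_repr = tokens[:, 0]                # (batch, dim)

# Option 2: mean-pool the image tokens (works even without a cls_token).
mean_repr = tokens[:, 1:].mean(dim=1)  # (batch, dim)
```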

DM0815 commented 1 year ago

Thanks for the information. To be honest, I just regard MAE as an embedding method for my tasks. I will add a linear layer at the end of the model to output the final embedding, then use a KNN algorithm on the final embedding to get an adjacency matrix, followed by clustering (using algorithms such as Louvain or Leiden). But I have a question: if I get classifier results from the MAE model, will they be different from the ones I get through the method above?
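Roughly what I have in mind for the clustering step (a sketch only; the library choices and hyperparameters here are illustrative, not my exact code):

```python
import numpy as np
import networkx as nx
from sklearn.neighbors import kneighbors_graph

# embeddings: (n_samples, dim) array taken from the model's final linear
# layer; dummy data here so the snippet runs on its own.
embeddings = np.random.randn(500, 64)

# Build a symmetric kNN adjacency matrix (k is a hyperparameter to tune).
adj = kneighbors_graph(embeddings, n_neighbors=15, mode="connectivity")
adj = adj.maximum(adj.T)  # symmetrize

# Louvain community detection on the kNN graph (requires networkx >= 2.8;
# Leiden is available separately via the leidenalg package).
graph = nx.from_scipy_sparse_array(adj)
communities = nx.algorithms.community.louvain_communities(graph, seed=0)

labels = np.empty(len(embeddings), dtype=int)
for cid, nodes in enumerate(communities):
    labels[list(nodes)] = cid
```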

BTW, as for my input data: it is just two matrices of the same shape (rows represent samples, columns represent features). I regard the two matrices as two channels of a picture, but the patches are one-dimensional in my dataset. Do you have any suggestions or hints for me? @IcarusWizard Thanks!
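Concretely, I am patchifying like this (a sketch of my setup; `patch_size`, `emb_dim`, and `num_features` are placeholder values):

```python
import torch
import torch.nn as nn

# Assumed input layout: (batch, 2, num_features), the two matrices stacked
# as channels. The hyperparameters below are hypothetical.
patch_size, emb_dim, num_features = 16, 192, 256

# 1-D analogue of the usual Conv2d patch embedding in ViT/MAE: each patch
# covers `patch_size` consecutive features across both channels.
patchify = nn.Conv1d(in_channels=2, out_channels=emb_dim,
                     kernel_size=patch_size, stride=patch_size)

x = torch.randn(8, 2, num_features)
tokens = patchify(x).transpose(1, 2)  # (batch, num_features // patch_size, emb_dim)
```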

DM0815 commented 1 year ago

In addition, my goal is to get the labels of the input examples and visualize them with UMAP or t-SNE. I think the features used for visualization are very important; otherwise the distinction between communities will not be very clear in the visualization results.
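The visualization step I plan is roughly this (a sketch using the umap-learn package; `embeddings` and `labels` stand in for the outputs of the clustering step above):

```python
import numpy as np
import matplotlib.pyplot as plt
import umap  # from the umap-learn package

# Dummy stand-ins for the embeddings and cluster labels, so the snippet
# runs on its own.
embeddings = np.random.randn(500, 64)
labels = np.random.randint(0, 8, size=500)

coords = umap.UMAP(n_components=2, random_state=0).fit_transform(embeddings)
plt.scatter(coords[:, 0], coords[:, 1], c=labels, s=5, cmap="tab20")
plt.title("UMAP of MAE embeddings")
plt.show()
```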

IcarusWizard commented 1 year ago

Regarding your setting, if I understand correctly, you are training on tabular data with a 2-D feature vector for each entry. I think the tabular setting will generally be a challenge for MAE, since one core assumption is that the information in the input data (images) is redundant. That is why MAE can reconstruct the data at a high masking ratio of 75%. For tabular data, I presume the entries are more independent of each other? In that case, you should not use a high masking ratio in general. Or maybe your data is intercorrelated? Then you can try a higher masking ratio.
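For reference, the per-sample random masking at the heart of MAE looks roughly like this (a sketch after He et al.; shapes and the suggested ratio are assumptions, not this repo's exact code):

```python
import torch

def random_masking(tokens: torch.Tensor, mask_ratio: float):
    """Per-sample random masking as in MAE: keep a random subset of tokens.

    tokens: (batch, num_tokens, dim). Returns the visible tokens plus the
    shuffle indices needed to restore the original token order.
    """
    B, N, D = tokens.shape
    num_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=tokens.device)  # one score per token
    shuffle = noise.argsort(dim=1)                  # random permutation
    keep = shuffle[:, :num_keep]                    # tokens that stay visible
    visible = torch.gather(tokens, 1,
                           keep.unsqueeze(-1).expand(-1, -1, D))
    return visible, shuffle

# With weakly correlated tabular features, a small ratio (e.g. 0.25) may be
# a safer starting point than the 0.75 used for images.
visible, _ = random_masking(torch.randn(8, 64, 192), mask_ratio=0.25)
```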

I cannot tell which option is better (using the cls_token as the embedding or mean-pooling the tokens) for your dataset. If I were you, I would just try both and see which one works better.

DM0815 commented 1 year ago

Thank you so much for your reply; I will do that. The data is interconnected in a sense. Thanks.

DM0815 commented 1 year ago

Excuse me, I'm sorry to bother you, but I have several questions. I added a layer at the end of the decoder and used its outputs for clustering and visualization, but the cluster figures are not good: many points are not clustered together. I suspect the number of encoder/decoder layers and attention heads influences the results. To be honest, I'm not sure how to choose the model's hyperparameters, such as the number of layers and heads. Should I just try different combinations? Second, after 20 epochs the loss decreased to 0.02, so maybe the model is overfitting, but when I adjust the mask ratio and dropout, the effect on the results is not very clear, and the cluster figures are still not good. Is it possible that I extracted the output in the wrong place? (code screenshot attached)

IcarusWizard commented 1 year ago

I am not sure what you intend to do. Are you looking for a representation vector for your data? Then you should take the output of the encoder as the representation, not the decoder.
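Something like this (a sketch only; it assumes `model.encoder(img)` returns `(features, backward_indexes)` with features shaped `(num_tokens, batch, dim)` as in this repo's model.py, so double-check the actual signature before use):

```python
import torch

def extract_embeddings(model, images):
    """Sketch: take representations from the encoder, not the decoder."""
    model.eval()
    with torch.no_grad():
        features, _ = model.encoder(images)   # token features from the encoder
        cls_repr = features[0]                # cls_token representation, (batch, dim)
        mean_repr = features[1:].mean(dim=0)  # mean over image tokens, (batch, dim)
    return cls_repr, mean_repr
```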

DM0815 commented 1 year ago

Get it. Thanks for your reply.