google-research / head2toe


Could you please share your model conversion file? #8

Closed: nuaajeff closed this issue 2 years ago

nuaajeff commented 2 years ago

Thanks for your great work, and it is very kind of you to share the code. When I ran it, I found that if I wanted to use a custom model or different pre-training parameters (e.g., an ImageNet-21k pretrained ViT-B/16), I needed to convert the model file from JAX to TensorFlow, and the process is complicated and frustrating. Could you please share your model conversion file (especially for the ViT models), so that this process would not be so hard?

By the way, I found there are 4 extra groups of features, i.e., {'cls_embedded', 'encoded_sequence', 'position_embedded_input', 'root_output_with_cls'}, while the appendix says: "Additionally we use patched and embedded image (i.e. tokenized image input), pre-logits (final CLS-embedding) and logits." What do 'encoded_sequence', 'cls_embedded', and 'root_output_with_cls' represent? I notice all of them have 768 dims, but the logits should have 1000 dims. Thank you in advance!
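(For context, the conversion being asked about is roughly of this shape; a hypothetical sketch that glosses over the parameter-name remapping and kernel transposes that make it painful in practice:)

```python
import numpy as np
import tensorflow as tf

# Hypothetical sketch: copy a flat {name: array} dict of JAX/Flax
# parameters into matching TF variables. A real conversion also needs
# name remapping and weight transposes, which is the hard part.
def copy_params(jax_params_flat, tf_variables_by_name):
    for name, value in jax_params_flat.items():
        tf_variables_by_name[name].assign(np.asarray(value))
```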

nuaajeff commented 2 years ago

I would also like to know why "with tf.device('/CPU:0'):" appears in the functions "_optimize_finetune" and "_calculate_scores". After I replaced the model with my own ViT-B/16 model, I got the following error around those statements, and the program stopped:

```
OP_REQUIRES failed at xla_ops.cc:287 : INVALID_ARGUMENT: Trying to access resource Resource-203-at-0x560cd7f8d710 located in device /job:localhost/replica:0/task:0/device:GPU:0 from device /job:localhost/replica:0/task:0/device:CPU:0
```

After I replaced "with tf.device('/CPU:0'):" with "with tf.device('/GPU:0'):", the program continued.
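For concreteness, the change I made was along these lines (a minimal sketch, not the actual head2toe code; the real statements live in _optimize_finetune and _calculate_scores):

```python
import tensorflow as tf

# Minimal sketch of the device change described above; the embedding
# variable here is a stand-in for the tensors created in those functions.
def store_embeddings(features, device='/GPU:0'):  # was '/CPU:0'
    with tf.device(device):
        return tf.Variable(features, trainable=False)

embeddings = store_embeddings(tf.random.normal([4, 768]))
print(embeddings.device)
```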

Thank you in advance again! : )

evcu commented 2 years ago

We use the following (see the sketch below for where each one sits in the forward pass):

1. 'root_output_with_cls' -> input patch embeddings with the CLS token added.
2. 'position_embedded_input' -> input of the transformer encoder with the position information added.
3. 'encoded_sequence' -> output of the transformer encoder after layer norm (includes all tokens).
4. 'cls_embedded' -> only the class embedding of the encoded_sequence (prelogits).
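For orientation, here is a rough sketch (hypothetical PyTorch-style code, not the actual head2toe model) of where each feature is taken in a standard ViT forward pass; patch_embed, cls_token, pos_embed, encoder, and final_ln stand in for the real modules:

```python
import torch

def vit_feature_points(images, patch_embed, cls_token, pos_embed, encoder, final_ln):
    # patch_embed maps [B, 3, H, W] images to [B, N, 768] patch embeddings
    tokens = patch_embed(images)
    # prepend the CLS token -> 'root_output_with_cls'
    root_output_with_cls = torch.cat(
        [cls_token.expand(tokens.shape[0], -1, -1), tokens], dim=1)
    # add position embeddings -> 'position_embedded_input'
    position_embedded_input = root_output_with_cls + pos_embed
    # run the encoder and its final layer norm -> 'encoded_sequence'
    encoded_sequence = final_ln(encoder(position_embedded_input))
    # take only the CLS token -> 'cls_embedded' (the prelogits)
    cls_embedded = encoded_sequence[:, 0]
    return {
        'root_output_with_cls': root_output_with_cls,
        'position_embedded_input': position_embedded_input,
        'encoded_sequence': encoded_sequence,
        'cls_embedded': cls_embedded,
    }
```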

I don't think we have logits in the checkpoint we open-sourced, so that must be a mistake in the write-up; I'll fix it if we upload a new version. Though if I remember correctly, having the logits or not didn't change the results much for the ResNet.

We store the embeddings on the CPU due to memory: in our setup we had more CPU memory than GPU memory. The error you are getting could be due to limited CPU memory; if the GPU works for you, you can use that.

nuaajeff commented 2 years ago
```python
import torch.nn as nn

# Transformer MLP block; the class wrapper and import are added here
# for context, since the original comment quoted only the self.net lines.
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )
```

This is a PyTorch-style MLP as used in a transformer. As the paper states, two of the feature sets are extracted from the MLP: "(3-4) features after MLP layers (and after gelu activation function)". Could you please point out which two layers the features are extracted from? Thanks a lot!

evcu commented 2 years ago

After GELU and after Linear (both before the dropout).
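In PyTorch terms that corresponds to the following (a hypothetical sketch using forward hooks on the FeedForward module quoted above):

```python
import torch

# Sketch: capture the two features named above with forward hooks,
# i.e. the nn.GELU output and the second nn.Linear output, each taken
# before its following Dropout.
feats = {}

def save(name):
    def hook(module, inputs, output):
        feats[name] = output.detach()
    return hook

mlp = FeedForward(dim=768, hidden_dim=3072)
mlp.net[1].register_forward_hook(save('after_gelu'))    # nn.GELU
mlp.net[3].register_forward_hook(save('after_linear'))  # second nn.Linear

_ = mlp.net(torch.randn(1, 197, 768))
print(feats['after_gelu'].shape)    # torch.Size([1, 197, 3072])
print(feats['after_linear'].shape)  # torch.Size([1, 197, 768])
```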