soniajoseph / ViT-Prisma

ViT Prisma is a mechanistic interpretability library for Vision Transformers (ViTs).

Convert my own pre-trained model into HookedViT object #91

Closed Lancelottery closed 4 months ago

Lancelottery commented 4 months ago

Hi Sonia,

Thank you for sharing your work!

I would like to use ViT Prisma to visualize my own model (vit_cxr), which I fine-tuned from google/vit-base-patch16-224-in21k. The architecture is ViT-Base, and I have uploaded the fine-tuned model to HuggingFace.

I was playing around with the ViT Prisma Main Demo by importing vit_cxr from HuggingFace as a HookedViT using the following code:

model = HookedViT.from_pretrained(model_name="Lancelottery/cxr-race", is_timm=False)

However, after printing the inferred config, it raised the following ValueError:

'n_layers': 12, 'd_model': 768, 'd_head': 64, 'model_name': 'Lancelottery/cxr-race', 'n_heads': 12, 'd_mlp': 3072, 'activation_name': 'gelu', 'eps': 1e-12, 'original_architecture': ['ViTForImageClassification'], 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 16, 'image_size': 224, 'n_classes': None, 'n_params': None


ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/vit_prisma/prisma_tools/loading_from_pretrained.py in get_pretrained_state_dict(official_model_name, is_timm, is_clip, cfg, hf_model, dtype, **kwargs)
    287     )
--> 288     raise ValueError
    289

ValueError:

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
2 frames
/usr/local/lib/python3.10/dist-packages/vit_prisma/prisma_tools/loading_from_pretrained.py in get_pretrained_state_dict(official_model_name, is_timm, is_clip, cfg, hf_model, dtype, **kwargs)
    295
    296     except:
--> 297         raise ValueError(
    298             f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
    299         )

ValueError: Loading weights from the architecture is not currently supported: ['ViTForImageClassification'], generated from model name Lancelottery/cxr-race. Feel free to open an issue on GitHub to request this feature.

image.png: https://github.com/soniajoseph/ViT-Prisma/assets/77723432/e79cd90e-00b0-480c-bfc6-d3cf0c97833e

I found that as soon as I set is_timm=False, a ValueError is raised whenever I pass an hf_model:

image.png: https://github.com/soniajoseph/ViT-Prisma/assets/77723432/301fdcc1-b14b-4437-b1e9-01761d9da399

I was wondering whether ViT Prisma supports loading pretrained models with weights from HuggingFace. If not, is there an alternative way to import my pretrained model, for example from the pytorch_model.bin file?
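For reference, this is the kind of fallback I have in mind, assuming the checkpoint file can be read directly with torch.load (the file name and the dummy key below are hypothetical stand-ins; the real parameter names would then need to be mapped by hand onto HookedViT's naming scheme):

```python
import torch

# Hypothetical stand-in for the real checkpoint: save a tiny state dict
# the same way a pytorch_model.bin file is stored.
dummy = {"vit.embeddings.cls_token": torch.zeros(1, 1, 768)}
torch.save(dummy, "pytorch_model.bin")

# Load the raw weights and inspect parameter names and shapes --
# the first step toward a manual mapping into HookedViT.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
```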

Thank you in advance for your support and guidance!

soniajoseph commented 4 months ago

Hi Lance!

Thanks for bringing up this issue. We'll look into it as soon as we can.

Best,
Sonia


Lancelottery commented 4 months ago

Thank you, looking forward to your reply!

themachinefan commented 4 months ago

This should fix it: https://github.com/soniajoseph/ViT-Prisma/pull/92

Lancelottery commented 4 months ago

> This should fix it: #92

Thank you @themachinefan for your excellent contribution! I believe this implementation has the potential to be modified to accept custom classification tasks beyond the standard ImageNet dataset. I'm excited about the possibility of further refinements and adaptations to suit various use cases. Keep up the great work!
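As a sketch of what a custom classification head looks like on the HuggingFace side (this uses plain transformers classes rather than ViT Prisma's loader, and the 5-way label space is an arbitrary example):

```python
from transformers import ViTConfig, ViTForImageClassification

# ViT-Base config with a 5-class head instead of the standard
# ImageNet label space (num_labels is the only change from defaults).
config = ViTConfig(num_labels=5)

# Randomly initialized model; in practice the backbone weights would
# come from a fine-tuned checkpoint such as vit_cxr.
model = ViTForImageClassification(config)
print(model.classifier.out_features)
```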

soniajoseph commented 4 months ago

Thank you @themachinefan for the excellent and helpful contribution!