SHI-Labs / Compact-Transformers

Escaping the Big Data Paradigm with Compact Transformers, 2021 (Train your Vision Transformers in 30 mins on CIFAR-10 with a single GPU!)
https://arxiv.org/abs/2104.05704
Apache License 2.0

Trouble with model function call in examples/main.py for CIFAR10 #82

Open BKJackson opened 7 months ago

BKJackson commented 7 months ago

It seems the model function call in examples/main.py is failing with the error message:

Traceback (most recent call last):
  File "/mnt/code/Compact_Transformers/examples/main.py", line 279, in <module>
    main()
  File "/mnt/code/Compact_Transformers/examples/main.py", line 127, in main
    model = models.__dict__[args.model]('', False, False, img_size=img_size,
TypeError: cct_2_3x2_32() got multiple values for argument 'img_size'

I'm having trouble identifying how the function parameters need to be changed.

Here's the function call inside the main.py file:

 model = models.__dict__[args.model]('', False, False, img_size=img_size,
                                        num_classes=num_classes,
                                        positional_embedding=args.positional_embedding,
                                        n_conv_layers=args.conv_layers,
                                        kernel_size=args.conv_size,
                                        patch_size=args.patch_size)

I'm calling the examples/main.py script with the following command:

!python examples/main.py \
       --model cct_2_3x2_32 \
       --conv-size 3 \
       --conv-layers 2 \
       --print-freq -1 \
       --epochs 30 \
       --workers 2 \
       ../cifar10

Any suggestions on how to fix this?

stevenwalton commented 7 months ago

Hi, thanks for bringing this up. There is an error in the code. I traced the issue with pdb and found that if you look at kwargs here you'll see that there are redundant arguments. This is annoying to fix, so for now consider all the model arguments broken.
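
To make the collision concrete, here's a minimal sketch of how it happens, using hypothetical, simplified signatures (not the actual registry code): the three positional arguments from main.py bind to the first three parameters, so the img_size slot is already filled when the keyword arrives.

    # Hypothetical, simplified signature for illustration only.
    def cct_2_3x2_32(pretrained=False, progress=False, img_size=32, **kwargs):
        pass

    # '' -> pretrained, False -> progress, False -> img_size,
    # then the keyword tries to fill img_size a second time:
    cct_2_3x2_32('', False, False, img_size=32)
    # TypeError: cct_2_3x2_32() got multiple values for argument 'img_size'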

There are a few pretty trivial fixes, though. The easiest, imo, is changing the call in main.py. If you look above, you'll see that there are hardcoded values in the cct_2 call, which is where the redundant values come from. We can remove kernel_size and n_conv_layers, and we also want to make sure pretrained and progress are passed as keyword arguments. Removing the empty string is optional, but I don't know why it is there. So use this:

    model = models.__dict__[args.model](pretrained=False, progress=False,
                                        img_size=img_size,
                                        num_classes=num_classes,
                                        positional_embedding=args.positional_embedding,
                                        patch_size=args.patch_size)

Unfortunately, this error doesn't look to be an isolated one. For now, I think it is best to call the main CCT class or the constructor function _cct directly.
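
Something like this should work (off the top of my head, so double-check the signature in src/cct.py; the parameter names and defaults below are assumptions and may not match your version exactly):

    # Rough sketch, assuming the CCT signature in src/cct.py. The
    # cct_2_3x2_32 naming means 2 transformer layers, a 3x3 conv kernel,
    # 2 conv layers, and 32x32 inputs.
    from src.cct import CCT

    model = CCT(img_size=32,
                num_classes=10,
                embedding_dim=128,   # assumed width of the cct_2 variant
                num_layers=2,
                num_heads=2,
                mlp_ratio=1,
                kernel_size=3,
                n_conv_layers=2,
                positional_embedding='learnable')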

Sorry about this, we will resolve this issue.

stevenwalton commented 7 months ago

@alihassanijr I believe this was introduced in 5b21843. What was the reasoning behind this? The same issue will exist in the evaluate example.

I'm not sure what the best solution is, but I think all the model registries are poorly constructed. The default values should be placed in the function signature, not in the function call; that avoids the redundant-argument collisions and gives a bit more flexibility if someone chooses to change an argument (see the sketch below). But it means touching all the cct_*_** and cct_* functions (everything after line 120...).
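
Something like this (illustrative only, with simplified stand-in functions, not the actual registry code):

    # With defaults in the signature, a caller's keyword replaces the
    # default instead of colliding with a hardcoded value in the inner call.

    def cct_2(kernel_size=3, n_conv_layers=1, **kwargs):
        # Stub standing in for the real cct_2 constructor.
        return (kernel_size, n_conv_layers, kwargs)

    # Current pattern: hardcoded defaults in the inner call, so
    # broken_variant(kernel_size=5) raises "got multiple values for
    # keyword argument 'kernel_size'".
    def broken_variant(*args, **kwargs):
        return cct_2(kernel_size=3, n_conv_layers=2, *args, **kwargs)

    # Proposed pattern: defaults in the signature, forwarded explicitly,
    # so fixed_variant(kernel_size=5) simply overrides the default.
    def fixed_variant(kernel_size=3, n_conv_layers=2, **kwargs):
        return cct_2(kernel_size=kernel_size, n_conv_layers=n_conv_layers,
                     **kwargs)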

Was there some better logic that you were trying to implement that I'm missing or was this a mistake?

alihassanijr commented 7 months ago

@stevenwalton Thanks for looking into this. You're right, the model signatures don't match. It's been a long time since I looked at this, but I agree that the commit you referenced should be reverted. I think someone might have tried evaluating on cct_2 directly and failed because of the mismatched signatures between cct_* and cct_*_**.

alihassanijr commented 7 months ago

Yup, got it. It was #68 addressing #67. We should be fine if that's reverted.