TorchDSP / torchsig

TorchSig is an open-source signal processing machine learning toolkit based on the PyTorch data handling pipeline.
MIT License

XCiT not working with different num_classes #222

Closed Benjamin-Etheredge closed 1 month ago

Benjamin-Etheredge commented 1 year ago

Describe the bug: Building XCiT with num_classes != 53 raises an AttributeError because the model has no classifier attribute.

To Reproduce

from torchsig.models.iq_models.xcit.xcit import xcit_nano
xcit_nano(num_classes=10)

Expected behavior: The call should create a model.

Screenshots

Traceback (most recent call last):
  File "/workspaces/torchsig/failure.py", line 3, in <module>
    xcit_nano(num_classes=10)
  File "/workspaces/torchsig/torchsig/models/iq_models/xcit/xcit.py", line 140, in xcit_nano
    mdl.classifier.in_features,  # type: ignore
    ^^^^^^^^^^^^^^
  File "/home/vscode/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'XCiT' object has no attribute 'classifier'
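The error comes from torch.nn.Module.__getattr__, which resolves only registered parameters, buffers, and submodules and raises AttributeError for any other name. The sketch below shows the same failure on a small stand-in model. ToyXCiT and try_classifier are hypothetical names, and the assumption that the output layer is called grouper is taken from the workaround later in this thread rather than from the real XCiT source.

```python
import torch.nn as nn


class ToyXCiT(nn.Module):
    """Hypothetical stand-in for XCiT: its output layer is named `grouper`,
    not `classifier` (assumed layout, not copied from torchsig)."""

    def __init__(self, embed_dim: int = 16, num_classes: int = 53):
        super().__init__()
        self.grouper = nn.Conv1d(embed_dim, num_classes, kernel_size=1)


def try_classifier(model: nn.Module) -> str:
    try:
        # nn.Module.__getattr__ only finds registered parameters, buffers,
        # and submodules; any other name raises AttributeError.
        return str(model.classifier.in_features)
    except AttributeError as exc:
        return f"AttributeError: {exc}"


mdl = ToyXCiT()
print(try_classifier(mdl))
print(hasattr(mdl, "classifier"), hasattr(mdl, "grouper"))  # False True
```

A hasattr check like the one above is one way to guard head-replacement code against a model whose output layer has a different name.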


Additional context: N/A

MutazAbueisheh commented 6 months ago

Hey!

I had the same problem and solved it by replacing:

if num_classes != 53:
    mdl.classifier = nn.Linear(
        mdl.classifier.in_features,  # type: ignore
        num_classes,
    )

with:

if num_classes != 53:
    mdl.grouper = nn.Conv1d(
        in_channels=mdl.grouper.in_channels,
        out_channels=num_classes,
        kernel_size=1,
    )

Now it works with a number of classes other than the original 53.
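The workaround swaps the 1x1 Conv1d output head for a new one with num_classes output channels while keeping its input channels. A minimal sketch of that idea on a stand-in model follows; ToyXCiT, its forward pass (head applied to token features, then averaged over the sequence), and the replace_head helper are all hypothetical and are not the real torchsig XCiT code.

```python
import torch
import torch.nn as nn


class ToyXCiT(nn.Module):
    """Hypothetical stand-in for XCiT with a 1x1 Conv1d output head named
    `grouper`; the real model's forward pass is not reproduced here."""

    def __init__(self, embed_dim: int = 16, num_classes: int = 53):
        super().__init__()
        self.grouper = nn.Conv1d(embed_dim, num_classes, kernel_size=1)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, embed_dim, seq_len) -> logits: (batch, num_classes)
        return self.grouper(tokens).mean(dim=-1)


def replace_head(mdl: nn.Module, num_classes: int) -> nn.Module:
    """Swap the output head for one with `num_classes` outputs, mirroring
    the grouper workaround above (hypothetical helper name)."""
    if num_classes != mdl.grouper.out_channels:
        mdl.grouper = nn.Conv1d(
            in_channels=mdl.grouper.in_channels,  # keep the feature width
            out_channels=num_classes,             # new number of classes
            kernel_size=1,
        )
    return mdl


mdl = replace_head(ToyXCiT(), num_classes=10)
logits = mdl(torch.randn(4, 16, 32))
print(logits.shape)  # torch.Size([4, 10])
```

Comparing against the head's current out_channels instead of the hard-coded 53 keeps the check correct if the default class count ever changes. The new head is randomly initialized, so it still needs training.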

MattCarrickPL commented 1 month ago

This issue is almost a year old, and with significant code changes and upcoming signal additions, it is unclear how relevant the problem still is. Closing for now.