KAIST-VICLab / FMA-Net

[CVPR 2024 Oral] Official repository of FMA-Net
https://kaist-viclab.github.io/fmanet-site/
MIT License

Provide a script to test the model #17

Status: Closed (DachunKai closed this issue 4 months ago)

DachunKai commented 4 months ago

I used the following script to test the model's forward function, but it fails with a size-mismatch error.

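import torch
from thop import profile  # thop's profile() counts MACs for a forward pass

from model import FMANet  # assumed import path; adjust to wherever FMANet is defined in the repo


class Config:
    """Minimal attribute container so the model can read options as config.<name>."""
    def __init__(self, config_dict):
        for key, value in config_dict.items():
            setattr(self, key, value)


def test_fmanet():
    config_dict = {
        'stage': 2,
        'scale': 4,
        'num_seq': 3,           # number of frames the model is configured for
        'in_channels': 3,
        'dim': 90,
        'ds_kernel_size': 20,
        'us_kernel_size': 5,
        'num_RDB': 12,
        'growth_rate': 18,
        'num_dense_layer': 4,
        'num_flow': 9,
        'num_FRMA': 4,
        'num_transformer_block': 2,
        'num_heads': 6,
        'LayerNorm_type': 'WithBias',
        'ffn_expansion_factor': 2.66,
        'bias': False,
    }
    config = Config(config_dict)

    net = FMANet(config).cuda()
    net.eval()

    t = 10  # frames in the test clip, although config.num_seq is 3
    x = torch.rand(1, 3, t, 180, 320).cuda()  # [B, C, T, H, W]

    macs, _ = profile(model=net, inputs=(x,), verbose=False)
    params = sum(p.numel() for p in net.parameters())
    print(f"MACs: {macs / 1e9:.2f} G, params: {params / 1e6:.2f} M")


if __name__ == "__main__":
    test_fmanet()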
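A size-mismatch error like this one can often be caught before the model runs by comparing the test clip's shape against the config. The sketch below is a hypothetical helper (`check_clip_shape` is not part of FMA-Net). It assumes the [B, C, T, H, W] layout used in the script above and assumes the model expects exactly `config.num_seq` frames; both are assumptions, not confirmed behavior of the repository.

```python
from types import SimpleNamespace


def check_clip_shape(shape, config):
    """Hypothetical pre-flight check, not part of FMA-Net.

    Assumes a [B, C, T, H, W] clip and that the model expects exactly
    config.num_seq frames. Returns a list of problems; empty means OK.
    """
    if len(shape) != 5:
        return [f"expected a 5-D [B, C, T, H, W] shape, got {len(shape)}-D"]
    problems = []
    _, c, t, _, _ = shape
    if c != config.in_channels:
        problems.append(f"channel axis is {c}, config.in_channels is {config.in_channels}")
    if t != config.num_seq:
        problems.append(f"temporal axis is {t}, config.num_seq is {config.num_seq}")
    return problems


# Stand-in for the Config object, with only the fields the check reads
cfg = SimpleNamespace(in_channels=3, num_seq=3)

print(check_clip_shape((1, 3, 10, 180, 320), cfg))
# → ['temporal axis is 10, config.num_seq is 3']
print(check_clip_shape((1, 3, 3, 180, 320), cfg))
# → []
```

Running such a check on the shape passed to `torch.rand` flags the 10-frame clip immediately instead of failing deep inside the network.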

Error:

F = rearrange(F, '(b n t) c h w -> b (n c) t h w', t=self.num_seq, n=self.num_flow)    # [B, C, T, H, W]
einops.EinopsError:  Error while processing rearrange-reduction pattern "(b n t) c h w -> b (n c) t h w".
 Input tensor shape: torch.Size([90, 10, 180, 320]). Additional info: {'t': 3, 'n': 9}.
 Shape mismatch, can't divide axis of length 90 in chunks of 27
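The numbers in the traceback can be reproduced with plain arithmetic. The pattern `(b n t) c h w` requires the leading axis to be a multiple of n × t = 9 × 3 = 27, but the tensor has 90 rows. Since 90 = 1 × 9 × 10, this suggests (an inference from the shapes, not confirmed in the source) that the leading axis was built from batch 1 × `num_flow` 9 × the 10 input frames, while the rearrange uses `t=self.num_seq` = 3. The hypothetical helper below mirrors only the shape bookkeeping of that `einops.rearrange` call, not its data movement.

```python
def rearrange_shape(in_shape, n, t):
    """Output shape of rearrange(F, '(b n t) c h w -> b (n c) t h w', n=n, t=t).

    Hypothetical helper for illustration: reproduces only the shape
    arithmetic, not what einops does to the data.
    """
    bnt, c, h, w = in_shape
    if bnt % (n * t) != 0:
        raise ValueError(f"can't divide axis of length {bnt} in chunks of {n * t}")
    b = bnt // (n * t)
    return (b, n * c, t, h, w)


num_flow, num_seq = 9, 3

# Shape from the traceback: leading axis 90 is not a multiple of 27
try:
    rearrange_shape((90, 10, 180, 320), n=num_flow, t=num_seq)
except ValueError as err:
    print(err)  # → can't divide axis of length 90 in chunks of 27

# If the leading axis were 1 x 9 x 3 = 27 (other axes kept from the traceback
# purely for illustration), the pattern would split cleanly
print(rearrange_shape((27, 10, 180, 320), n=num_flow, t=num_seq))
# → (1, 90, 3, 180, 320)
```

Under that reading, feeding a clip whose temporal length equals `num_seq` (here `t = 3`) is the first thing to try.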
GeunhyukYouk commented 4 months ago

Hi, has this issue been resolved?