isaaccorley / torchseg

Segmentation models with pretrained backbones. PyTorch.
MIT License

Unexpected keyword argument 'output_stride' when using UnetPlusPlus #14

Closed: namKolorfuL closed this issue 8 months ago

namKolorfuL commented 8 months ago

I tried using UnetPlusPlus with the same config as the Unet example in readme.md:

```python
import torchseg

model = torchseg.UnetPlusPlus(
    "maxvit_small_tf_224",
    in_channels=3,
    classes=2,
    encoder_weights=True,
    encoder_depth=5,
    decoder_channels=(256, 128, 64, 32, 16),
    encoder_params={"img_size": 256}
)
```

but got the following error. Can someone help fix this?

```text
Traceback (most recent call last):
  File "/home/nam/nam/torchseg/test.py", line 49, in <module>
    model = torchseg.UnetPlusPlus(
  File "/home/nam/nam/torchseg/torchseg/decoders/unetplusplus/model.py", line 76, in __init__
    self.encoder = get_encoder(
  File "/home/nam/nam/torchseg/torchseg/encoders/__init__.py", line 65, in get_encoder
    encoder = TimmEncoder(
  File "/home/nam/nam/torchseg/torchseg/encoders/timm.py", line 53, in __init__
    self.model = timm.create_model(name, **params)
  File "/home/nam/anaconda3/envs/testtorchseg/lib/python3.9/site-packages/timm/models/_factory.py", line 117, in create_model
    model = create_fn(
  File "/home/nam/anaconda3/envs/testtorchseg/lib/python3.9/site-packages/timm/models/maxxvit.py", line 2276, in maxvit_small_tf_224
    return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs)
  File "/home/nam/anaconda3/envs/testtorchseg/lib/python3.9/site-packages/timm/models/maxxvit.py", line 1816, in _create_maxxvit
    return build_model_with_cfg(
  File "/home/nam/anaconda3/envs/testtorchseg/lib/python3.9/site-packages/timm/models/_builder.py", line 400, in build_model_with_cfg
    model = model_cls(cfg=model_cfg, **kwargs)
  File "/home/nam/anaconda3/envs/testtorchseg/lib/python3.9/site-packages/timm/models/maxxvit.py", line 1152, in __init__
    cfg = _overlay_kwargs(cfg, **kwargs)
  File "/home/nam/anaconda3/envs/testtorchseg/lib/python3.9/site-packages/timm/models/maxxvit.py", line 1123, in _overlay_kwargs
    cfg = replace(
  File "/home/nam/anaconda3/envs/testtorchseg/lib/python3.9/dataclasses.py", line 1284, in replace
    return obj.__class__(**changes)
TypeError: __init__() got an unexpected keyword argument 'output_stride'
```
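The last frames show the extra keyword reaching `dataclasses.replace` on the model's config object. The stdlib-only sketch below reproduces that failure mode. `ModelCfg` and `overlay_kwargs` are hypothetical, simplified stand-ins, not timm's actual MaxViT config or `_overlay_kwargs`. All the sketch relies on is that `replace()` rebuilds the dataclass with every keyword it is given, so a keyword that is not a field raises `TypeError`.

```python
from dataclasses import dataclass, replace


# Hypothetical stand-in for a model config dataclass. The real MaxViT
# config in timm has different fields; what matters is that it has no
# field named output_stride.
@dataclass
class ModelCfg:
    embed_dim: tuple = (64, 128, 256, 512)
    depths: tuple = (2, 2, 5, 2)


def overlay_kwargs(cfg, **kwargs):
    # Simplified: every extra keyword is treated as a config override and
    # passed straight to dataclasses.replace().
    return replace(cfg, **kwargs)


cfg = ModelCfg()
print(overlay_kwargs(cfg, depths=(2, 2, 2, 2)))  # known field: accepted

try:
    overlay_kwargs(cfg, output_stride=32)  # not a field of ModelCfg
    error_message = None
except TypeError as err:
    error_message = str(err)
    print("TypeError:", error_message)
```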
isaaccorley commented 8 months ago

@namKolorfuL I'll release a fix for this this weekend, but in the meantime, setting encoder_output_stride=None in UnetPlusPlus should work.

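The reason is that maxvit doesn't necessarily have a stride component. UnetPlusPlus defaults to encoder_output_stride=32, which is forwarded to timm.create_model as output_stride. The maxvit config has no such field, so model construction fails with the TypeError above. I need to change that default.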
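One way the default could be changed is to forward `output_stride` only when the caller actually sets it. This assumes torchseg omits the keyword when it is `None`, which is what the suggested `encoder_output_stride=None` workaround implies. The sketch uses a hypothetical helper, `build_encoder_kwargs`, which is not part of torchseg or timm. It shows only the kwargs-filtering idea, not the fix that was actually released.

```python
from typing import Any, Dict, Optional


def build_encoder_kwargs(output_stride: Optional[int] = None, **params: Any) -> Dict[str, Any]:
    """Hypothetical helper: collect keyword arguments for timm.create_model,
    adding output_stride only when a stride was explicitly requested."""
    kwargs = dict(params)
    if output_stride is not None:
        kwargs["output_stride"] = output_stride
    return kwargs


# With a None default, encoders such as maxvit never receive the keyword.
print(build_encoder_kwargs(img_size=256))
# Encoders that do support it can still be given an explicit stride.
print(build_encoder_kwargs(16, img_size=256))
```

With a `None` default, a user who never mentions the stride gets whatever the encoder supports natively, and only an explicit request can trigger the error.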