jrzaurin / pytorch-widedeep

A flexible package for multimodal deep learning to combine tabular data with text and images using Wide and Deep models in PyTorch
Apache License 2.0
1.26k stars, 186 forks

Tabular: typo in attribute name #204

Closed: phantom-duck closed this issue 2 months ago

phantom-duck commented 2 months ago

https://github.com/jrzaurin/pytorch-widedeep/blob/b1bf2faf0f134d8258f6a56c6857d14c24f0a852/pytorch_widedeep/models/tabular/transformers/saint.py#L291

https://github.com/jrzaurin/pytorch-widedeep/blob/b1bf2faf0f134d8258f6a56c6857d14c24f0a852/pytorch_widedeep/models/tabular/resnet/tab_resnet.py#L403

I believe the attribute intended at both lines is `self.mlp_hidden_dims`.

A minimal reproducible example is the following:

```python
from pytorch_widedeep.models import SAINT

SAINT(
    column_idx={"a": 0},
    continuous_cols=["a"],
    mlp_hidden_dims=[8, 4],
    mlp_linear_first=True,
)
```

which fails with the error `AttributeError: 'SAINT' object has no attribute 'mlp_hidden_dim'`.
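The failure mode can be reproduced without the library. The sketch below is a hypothetical, simplified stand-in: the classes `BuggyHead` and `FixedHead` and the `input_dim` parameter are invented for illustration and are not the pytorch-widedeep source. The constructor stores the argument as `self.mlp_hidden_dims`, so reading `self.mlp_hidden_dim` afterwards raises `AttributeError`, while reading the stored name works.

```python
from typing import List, Optional


class BuggyHead:
    """Hypothetical stand-in for a model with an optional MLP head (not library code)."""

    def __init__(self, input_dim: int, mlp_hidden_dims: Optional[List[int]] = None):
        # The argument is stored under its plural name...
        self.mlp_hidden_dims = mlp_hidden_dims
        self.head_dims: Optional[List[int]] = None
        if self.mlp_hidden_dims is not None:
            # ...but this line reads the singular name, which was never assigned,
            # so Python raises AttributeError only when this branch runs.
            self.head_dims = [input_dim] + self.mlp_hidden_dim


class FixedHead:
    """Same hypothetical stand-in, reading the attribute that was actually stored."""

    def __init__(self, input_dim: int, mlp_hidden_dims: Optional[List[int]] = None):
        self.mlp_hidden_dims = mlp_hidden_dims
        self.head_dims: Optional[List[int]] = None
        if self.mlp_hidden_dims is not None:
            self.head_dims = [input_dim] + self.mlp_hidden_dims


if __name__ == "__main__":
    print(FixedHead(input_dim=16, mlp_hidden_dims=[8, 4]).head_dims)  # [16, 8, 4]
    print(BuggyHead(input_dim=16).head_dims)  # None: the misspelled line never runs
    try:
        BuggyHead(input_dim=16, mlp_hidden_dims=[8, 4])
    except AttributeError as err:
        print(f"AttributeError: {err}")
```

Python only raises when the misspelled attribute is actually read, so a typo like this stays hidden until the code path that touches it runs.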

jrzaurin commented 2 months ago

@phantom-duck

s***t, thanks! :) I will release a fix asap

jrzaurin commented 2 months ago

@phantom-duck

It is fixed and merged to master, so if you install from git you should not have a problem:

```shell
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
```
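After reinstalling, one way to confirm which release of the package is active is to query its distribution metadata with the standard-library `importlib.metadata` module. This is a small sketch: the helper name `installed_version` is hypothetical, and the thread does not say which version number contains the fix, so the snippet only reports what is installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "pytorch-widedeep") -> Optional[str]:
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        # The distribution is not installed in the current environment.
        return None


if __name__ == "__main__":
    v = installed_version()
    print(v if v is not None else "pytorch-widedeep is not installed")
```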

I will publish to PyPI ASAP

jrzaurin commented 2 months ago

@phantom-duck

It is now published to PyPI. THANKS for opening the issue!