jrzaurin / pytorch-widedeep

A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch
Apache License 2.0
1.3k stars 190 forks

Feature fusion #214

Closed xylovezxy closed 5 months ago

xylovezxy commented 5 months ago

Hi, isn't simple concatenation hard to align after extracting the visual and textual features? Would it be possible to add some feature fusion modules?

jrzaurin commented 5 months ago

It is indeed possible. Let me add some guidance tomorrow 🙂

xylovezxy commented 5 months ago

That's really great!

jrzaurin commented 5 months ago

So, the WideDeep class lets you pass a deephead module that can be anything you want. Alternatively, any model component can be anything you want, as long as it exposes an output_dim property.

So you have a number of options, for example:

from pytorch_widedeep.models._base_wd_model_component import BaseWDModelComponent
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor  # if you need it

X_img = YourImageFeatureExtractor()  # I will assume 100 cont features
X_text = YourTextFeatureExtractor()  # I will assume 100 cont features

class YourTabularModel(BaseWDModelComponent):
    # Anything you want to do with those two tensors, bearing in mind that at
    # the moment (this will change soon) all models must receive just ONE
    # tensor. So, for example:

    def forward(self, X):
        X_img = X[:, :100]
        X_text = X[:, 100:]

        # any fusion method you prefer below

        return out

    @property
    def output_dim(self) -> int:
        # must return the dimension of the fused features
        ...

Then, proceed as usual:

tab_model = YourTabularModel(...)
my_wd_model = WideDeep(deeptabular=tab_model)

# proceed as usual (remember the trainer needs arrays as inputs)
trainer = Trainer(...)
...
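To make the "any fusion method you prefer" step concrete, here is a framework-free numpy sketch of two common fusion strategies such a forward pass could implement: plain concatenation and a sigmoid-gated weighted sum. The 100-feature split per modality, the batch size, and the fixed random gate weights are illustrative assumptions only, not part of pytorch-widedeep (in a real model the gate projection would be a learned layer).

```python
import numpy as np

rng = np.random.default_rng(0)

# toy batch: 100 image-derived and 100 text-derived features per row
X_img = rng.normal(size=(4, 100))
X_text = rng.normal(size=(4, 100))

# 1) simple concatenation: the single input tensor the component receives
fused_concat = np.concatenate([X_img, X_text], axis=1)  # shape (4, 200)

# 2) gated fusion: a gate (a fixed random projection here, a learned
#    linear layer in practice) weights each modality's contribution
W_gate = rng.normal(size=(200, 100))
gate = 1.0 / (1.0 + np.exp(-(fused_concat @ W_gate)))   # sigmoid, (4, 100)
fused_gated = gate * X_img + (1.0 - gate) * X_text      # shape (4, 100)

print(fused_concat.shape, fused_gated.shape)
```

With concatenation, output_dim would be 200; with the gated sum it would be 100, since each fused element is a convex combination of the two modalities.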

Hope this helps

xylovezxy commented 5 months ago

Thank you

jrzaurin commented 5 months ago

Check release 23. In particular, have a look at the README to see how you could fuse components :)