jrzaurin / pytorch-widedeep

A flexible package for multimodal deep learning that combines tabular data with text and images using Wide and Deep models in PyTorch
Apache License 2.0
1.3k stars, 190 forks

Fix dropout layer being created on forward pass #190

Closed: BrunoBelucci closed this 1 year ago

BrunoBelucci commented 1 year ago

Fixes #189. Instantiates the dropout layer in __init__ and keeps a dropout_p (dropout probability) attribute that can be passed to F.scaled_dot_product_attention when flash attention is used.
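The change can be sketched as follows. This is a minimal, hypothetical attention module, not pytorch-widedeep's actual code: the class name SelfAttentionSketch, its constructor arguments and the use_flash switch are assumptions, and it assumes PyTorch 2.0 or later for F.scaled_dot_product_attention. One general PyTorch point explains why the placement matters: an nn.Dropout constructed inside forward is a new, unregistered module whose training flag defaults to True, so model.eval() never reaches it and it keeps dropping activations at inference time. Building it once in __init__ registers it as a submodule that follows train()/eval().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttentionSketch(nn.Module):
    """Hypothetical multi-head self-attention block illustrating the fix:
    the dropout module is built once in __init__, and the raw probability
    is also kept as dropout_p for the fused attention path."""

    def __init__(self, embed_dim: int, n_heads: int, dropout: float, use_flash: bool):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.use_flash = use_flash
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)
        # Registered submodule: its training flag follows model.train()/model.eval()
        self.dropout = nn.Dropout(dropout)
        # Plain float for F.scaled_dot_product_attention, which takes a probability
        self.dropout_p = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, e = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        q, k, v = (
            t.reshape(b, s, self.n_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        if self.use_flash:
            # SDPA applies dropout whenever dropout_p > 0, regardless of
            # module mode, so pass 0.0 outside training
            attn = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout_p if self.training else 0.0
            )
        else:
            # Explicit attention: the pre-built dropout module is applied to the weights
            scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
            weights = self.dropout(scores.softmax(dim=-1))
            attn = weights @ v
        # (batch, heads, seq, head_dim) -> (batch, seq, embed)
        return self.out(attn.transpose(1, 2).reshape(b, s, e))
```

Both branches now honour eval mode: the explicit path through the registered self.dropout module, and the fused path through the self.training check on dropout_p. Keeping the float separately is a design convenience, since F.scaled_dot_product_attention accepts a probability rather than a module.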