BiomedSciAI / fuse-med-ml

A python framework accelerating ML based discovery in the medical field by encouraging code reuse. Batteries included :)
Apache License 2.0

Additional model - `CrossAttentionTransformerEncoder` #251

Closed by SagiPolaczek 1 year ago

SagiPolaczek commented 1 year ago

Ready for initial CR 🥳

Example of use:

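```python
from fuse.dl.models.model_wrapper import ModelWrapSeqToDict

# CrossAttentionTransformer is the model added in this PR; model_params holds its constructor kwargs.
torch_model = CrossAttentionTransformer(**model_params)
bb_model = ModelWrapSeqToDict(
    model=torch_model,
    # Both token sequences are passed to the model positionally, in this order.
    model_inputs=("data.drug.tokenized", "data.target.tokenized"),
    model_outputs=("model.backbone_features",),
    # Average over dim 1 (assumed to be the sequence dimension) to get one feature vector per sample.
    post_forward_processing_function=lambda x: x.mean(dim=1),
)
```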
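To illustrate the wrapping pattern in the example above (read inputs from a batch dict by key, call the model positionally, post-process, write the result back under an output key), here is a minimal pure-Python sketch. `SeqToDictWrapper`, `mean_over_sequence`, and `toy_model` are hypothetical names for illustration only; this is not fuse-med-ml's implementation, and the flat dict with dotted keys is a simplifying assumption.

```python
from typing import Any, Callable, Dict, Optional, Sequence


class SeqToDictWrapper:
    """Hypothetical sketch of the seq-to-dict pattern (not the library's code)."""

    def __init__(
        self,
        model: Callable[..., Any],
        model_inputs: Sequence[str],
        model_outputs: Sequence[str],
        post_forward_processing_function: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self.model = model
        self.model_inputs = tuple(model_inputs)
        self.model_outputs = tuple(model_outputs)
        self.post = post_forward_processing_function

    def __call__(self, batch_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Gather the model's positional arguments in the declared order.
        args = [batch_dict[key] for key in self.model_inputs]
        outputs = self.model(*args)
        if self.post is not None:
            outputs = self.post(outputs)
        # Assumption: a single output key receives the whole result.
        if len(self.model_outputs) == 1:
            outputs = (outputs,)
        for key, value in zip(self.model_outputs, outputs):
            batch_dict[key] = value
        return batch_dict


def mean_over_sequence(x):
    """Pure-Python stand-in for x.mean(dim=1) on a [batch][seq][dim] nested list."""
    return [
        [sum(token[d] for token in sample) / len(sample) for d in range(len(sample[0]))]
        for sample in x
    ]


def toy_model(drug_tokens, target_tokens):
    # Hypothetical model: one 2-d feature per drug token -> [batch][seq][dim].
    return [
        [[float(tok), float(len(tgt))] for tok in drug]
        for drug, tgt in zip(drug_tokens, target_tokens)
    ]


wrapper = SeqToDictWrapper(
    model=toy_model,
    model_inputs=("data.drug.tokenized", "data.target.tokenized"),
    model_outputs=("model.backbone_features",),
    post_forward_processing_function=mean_over_sequence,
)
batch = {"data.drug.tokenized": [[1, 2, 3]], "data.target.tokenized": [[7, 8]]}
result = wrapper(batch)
print(result["model.backbone_features"])  # [[2.0, 2.0]]
```

The mean over the sequence dimension collapses variable-length token outputs into one fixed-size vector per sample, which is what a downstream head usually expects.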

Also added a simple unit test to (mostly) ensure the model stays valid over time.
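The comment does not describe the model internals, so as rough orientation only, here is a pure-Python sketch of the core cross-attention step: each position of one sequence (assumed here to be the drug tokens) attends over the other sequence (assumed to be the target tokens) via scaled dot-product attention, softmax(QKᵀ/√d)·V. This is a simplified illustration, not the PR's code: a full transformer encoder typically also has learned Q/K/V projections, multiple heads, residual connections, and feed-forward layers, all omitted here.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head scaled dot-product cross-attention, no learned projections.

    queries: [len_q][d], e.g. drug token embeddings (assumption)
    keys, values: [len_kv][d], e.g. target token embeddings (assumption)
    Returns [len_q][d_v]: each query row is a convex mix of the value rows.
    """
    d = len(queries[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for q in queries:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        weights = softmax(scores)
        out.append([
            sum(w * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))
        ])
    return out


drug = [[1.0, 0.0]]
target_keys = [[1.0, 0.0], [0.0, 1.0]]
target_values = [[10.0, 0.0], [0.0, 10.0]]
attended = cross_attention(drug, target_keys, target_values)
print(attended)
```

Because the output length follows the query sequence, the result has one row per drug token, which matches the idea of pooling over dim 1 afterwards.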