jaketae / ensemble-transformers

Ensembling Hugging Face transformers made easy
MIT License

Add model output class #3

Closed jaketae closed 1 year ago

jaketae commented 1 year ago

Context

  1. Previously, each ensemble class had its own forward method, which duplicated code across classes.
  2. As in HF's transformers, it makes more sense to build the interface around an output class that can hold different fields depending on the modality.
  3. We'd also like to have at least a minimal test suite.

Solution

  1. Factor out the duplicated code into the base class.
  2. Introduce an EnsembleModelOutput class that can hold different attributes depending on the modality. This class also handles the mean-pooling logic.
  3. Add basic tests for modality checks, ensemble coherence, and device movement.
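
To make the design concrete, here is a minimal sketch of how a shared forward method and an output class with mean pooling could fit together. It is illustrative only and is not the code in this PR: the field names (`logits`, `extras`), the `mean_pool` method, the `EnsembleBase` class, and the toy models are all assumptions, and plain Python lists stand in for tensors so the example runs without extra dependencies.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Sequence

Vector = List[float]


@dataclass
class EnsembleModelOutput:
    """Hypothetical output container; names are assumptions, not the library's API."""

    logits: List[Vector]  # one logits vector per member model
    extras: Dict[str, Any] = field(default_factory=dict)  # modality-specific fields

    def mean_pool(self, weights: Optional[Sequence[float]] = None) -> Vector:
        """Return the (optionally weighted) element-wise mean across member models."""
        n = len(self.logits)
        if n == 0:
            raise ValueError("no member outputs to pool")
        if weights is None:
            weights = [1.0] * n
        if len(weights) != n:
            raise ValueError("expected one weight per member model")
        total = sum(weights)
        dim = len(self.logits[0])
        return [
            sum(w * vec[i] for w, vec in zip(weights, self.logits)) / total
            for i in range(dim)
        ]


class EnsembleBase:
    """Hypothetical base class: forward lives here once instead of in every subclass."""

    def __init__(self, models: Sequence[Callable[[Vector], Vector]]):
        self.models = list(models)

    def forward(self, inputs: Vector, **extras: Any) -> EnsembleModelOutput:
        # Run every member model and collect the results in one output object.
        return EnsembleModelOutput(
            logits=[model(inputs) for model in self.models],
            extras=extras,
        )


# Toy stand-ins for real models (assumptions for illustration only).
def double(x: Vector) -> Vector:
    return [2.0 * v for v in x]


def quadruple(x: Vector) -> Vector:
    return [4.0 * v for v in x]


ensemble = EnsembleBase([double, quadruple])
output = ensemble.forward([1.0, 2.0], modality="text")
pooled = output.mean_pool()
weighted = output.mean_pool(weights=[3.0, 1.0])
```

Keeping the pooling logic on the output object means subclasses for different modalities only need to decide which fields to populate; they do not have to reimplement forward or the averaging.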