Context
Previously, each ensemble class implemented its own forward method, which duplicated the same code across classes.
Following the approach of HF's transformers, it makes more sense to build the interface around an output class that can hold different fields depending on the modality.
We'd also like to have at least a minimal test suite.
Solution
Factor out the duplicated code into the base class.
Introduce an EnsembleModelOutput class that holds different attributes depending on the modality and also handles the mean-pooling logic.
Add basic tests for modality checks, ensemble coherence, and device movement.
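As a rough illustration of the output-class idea, here is a minimal sketch, not the code in this PR. The field names (text_embeds, image_embeds), the modalities helper, and the mean_pool signature are all hypothetical, plain Python lists stand in for tensors, and mean pooling is assumed to mean averaging across ensemble members.

```python
from dataclasses import dataclass, fields
from typing import List, Optional

Vector = List[float]


@dataclass
class EnsembleModelOutput:
    """Hypothetical container: one list of per-member embeddings per modality.

    A field left as None means that modality was not produced.
    """

    text_embeds: Optional[List[Vector]] = None   # assumed field name
    image_embeds: Optional[List[Vector]] = None  # assumed field name

    def modalities(self) -> List[str]:
        """Return the names of the fields that are actually populated."""
        return [f.name for f in fields(self) if getattr(self, f.name) is not None]

    def mean_pool(self, name: str) -> Vector:
        """Average the embeddings of all ensemble members for one modality."""
        embeds = getattr(self, name)
        if embeds is None:
            raise ValueError(f"modality {name!r} is not present in this output")
        n = len(embeds)
        # zip(*embeds) groups the i-th component of every member's embedding.
        return [sum(column) / n for column in zip(*embeds)]


# Two ensemble members, each producing a 2-dimensional text embedding.
output = EnsembleModelOutput(text_embeds=[[1.0, 2.0], [3.0, 4.0]])
pooled = output.mean_pool("text_embeds")
```

Keeping the pooling on the output object means a shared forward method in the base class can return the same type for every modality, and callers only need to check which fields are populated.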