Oufattole / meds-torch

MIT License

All collate functions should be the same and static data pre-pending should happen in getitem. #94

Closed: mmcdermott closed this 1 month ago
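The change proposed in the title can be sketched roughly as follows. This is a minimal, dependency-free sketch under stated assumptions, not the actual meds-torch code: the class `PrependingDataset`, the record fields (`static_codes`, `dynamic_codes`, `times`) and the `collate` function are hypothetical. The idea is that `__getitem__` does the static-data prepending per sample, so every model variant can share one collate function that only pads.

```python
from typing import Dict, List

# Hypothetical per-subject record: static codes have no timestamp,
# dynamic codes each come with a time value.
Record = Dict[str, list]


class PrependingDataset:
    """Hypothetical dataset that prepends static data inside __getitem__."""

    def __init__(self, records: List[Record], do_prepend_static_data: bool = True):
        self.records = records
        self.do_prepend_static_data = do_prepend_static_data

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict[str, list]:
        rec = self.records[idx]
        static, dynamic, times = rec["static_codes"], rec["dynamic_codes"], rec["times"]
        if self.do_prepend_static_data:
            # Static tokens go first with a placeholder time of 0.0 and
            # static_mask = True, so an encoder can skip their time embedding.
            codes = static + dynamic
            times = [0.0] * len(static) + times
            static_mask = [True] * len(static) + [False] * len(dynamic)
        else:
            # Mirrors the current behaviour described in the thread:
            # static data is simply dropped.
            codes = list(dynamic)
            static_mask = [False] * len(dynamic)
        return {"code": codes, "time": list(times), "static_mask": static_mask}


def collate(batch: List[Dict[str, list]], pad_code: int = 0) -> Dict[str, list]:
    """Single shared collate: right-pads every field to the longest sequence."""
    max_len = max(len(item["code"]) for item in batch)
    out: Dict[str, list] = {"code": [], "time": [], "static_mask": [], "mask": []}
    for item in batch:
        n_pad = max_len - len(item["code"])
        out["code"].append(item["code"] + [pad_code] * n_pad)
        out["time"].append(item["time"] + [0.0] * n_pad)
        out["static_mask"].append(item["static_mask"] + [False] * n_pad)
        # mask marks real (non-padding) positions
        out["mask"].append([True] * len(item["code"]) + [False] * n_pad)
    return out
```

Because prepending happens per sample, the collate function no longer needs to know whether static data was prepended; it just pads whatever `__getitem__` returns. In a real pipeline it would be passed as `torch.utils.data.DataLoader(dataset, collate_fn=collate)` and would build tensors instead of lists.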

mmcdermott commented 1 month ago

This is blocked by:

mmcdermott commented 1 month ago

@Oufattole, what is static_mask in the output of your dataset function?

mmcdermott commented 1 month ago

And @Oufattole, what happens to the static data if do_prepend_static_data is False?

Oufattole commented 1 month ago

static_mask is only used by the triplet encoder when you prepend static data. It masks out the time embeddings in the triplet encoder on this line.

And static data is ignored if do_prepend_static_data is False.
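To make "masks out the time embeddings" concrete, here is a rough, hypothetical illustration; it does not reproduce the linked line or the real meds-torch triplet encoder, which uses learned torch modules. It assumes each token's embedding is the sum of a code embedding, a value embedding and a time embedding, and that `static_mask` zeroes the time term for static tokens, which have no meaningful timestamp. The function name `encode_triplets` is made up for this sketch.

```python
from typing import List

Vector = List[float]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise sum of two equal-length vectors."""
    return [x + y for x, y in zip(a, b)]


def encode_triplets(
    code_emb: List[Vector],
    value_emb: List[Vector],
    time_emb: List[Vector],
    static_mask: List[bool],
) -> List[Vector]:
    """Hypothetical triplet encoder: sums code, value and time embeddings per token.

    Where static_mask is True, the time embedding is replaced by zeros, so
    static tokens carry no (meaningless) time information.
    """
    out = []
    for c, v, t, is_static in zip(code_emb, value_emb, time_emb, static_mask):
        t_used = [0.0] * len(t) if is_static else t
        out.append(add(add(c, v), t_used))
    return out
```

With torch tensors the same effect is usually a single multiply, e.g. scaling the time embeddings by the inverted mask before summing.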

mmcdermott commented 1 month ago

> static_mask is only used by the triplet encoder when you prepend static data. It masks out the time embeddings in the triplet encoder on this line.
>
> And static data is ignored if do_prepend_static_data is False

Is there an issue tracking the fact that static data shouldn't be ignored in that case? Or should we eliminate that as an option altogether and just always prepend for now? Or something else?

Oufattole commented 1 month ago

The two approaches we have are essentially:

  1. Early fusion: Prepending static data (using static_mask in triplet encoder)
  2. Late fusion: Adding static data to the batch for handling by input_encoder/sequence_model

I'd suggest keeping both options but renaming them to "early_static_fusion" and "late_static_fusion" to make the architectural choice clearer. Also, as an FYI, I don't actually implement late fusion yet; right now I just ignore the static tokens.
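The two options can be contrasted in a small sketch. Everything here is hypothetical: the `StaticFusion` enum simply uses the names proposed above, `forward` and `mean_pool` are made-up helpers, and the late-fusion path (pooling the static embeddings and concatenating them onto the sequence model's output) is just one common way to do late fusion, which, per the comment above, is not implemented yet.

```python
from enum import Enum
from typing import Callable, List

Vector = List[float]


class StaticFusion(Enum):
    # Hypothetical option names, taken from the renaming suggested above.
    EARLY = "early_static_fusion"  # prepend static tokens to the input sequence
    LATE = "late_static_fusion"    # combine static data after the sequence model


def mean_pool(vectors: List[Vector]) -> Vector:
    """Average a non-empty list of equal-length vectors."""
    if not vectors:
        raise ValueError("mean_pool needs at least one vector")
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def forward(
    static_tokens: List[Vector],
    dynamic_tokens: List[Vector],
    sequence_model: Callable[[List[Vector]], Vector],
    fusion: StaticFusion,
) -> Vector:
    if fusion is StaticFusion.EARLY:
        # Early fusion: static tokens become part of the input sequence,
        # so the sequence model sees them alongside the dynamic events.
        return sequence_model(static_tokens + dynamic_tokens)
    # Late fusion: the sequence model sees only dynamic events; the static
    # data is pooled and concatenated onto its output representation.
    seq_repr = sequence_model(dynamic_tokens)
    return seq_repr + mean_pool(static_tokens)
```

In practice the sequence model would be a torch module and the concatenation a `torch.cat` along the feature dimension; plain lists are used here only to keep the sketch dependency-free. Note that late fusion only works if the batch actually carries the static data through to the model.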

mmcdermott commented 1 month ago

Right now, I don't think you even pass the static data into the batch for option 2, actually. But we can ignore this for now.