a-antoniades / Neuroformer

MIT License

Question about modality inputs and predictions #1

Closed colehurwitz closed 7 months ago

colehurwitz commented 9 months ago

First of all, congrats on the cool paper and method!

I was a bit confused about how modalities were inputted to and predicted from the Neuroformer. How does modality prediction work for the Neuroformer? During training, is the Neuroformer trained to reconstruct the modalities that are inputted (multiple heads)? During inference time, is the modality that you want to predict excluded from the inputs and then predicted by the Neuroformer somehow? How are missing modalities handled if I only have a subset of the modalities used for training during inference time? I couldn't figure this out from the code or paper. Any help would be much appreciated!

a-antoniades commented 9 months ago

Thanks for your kind words 😁.

You're right. If you look at the Modalities and Task Configuration portion of the readme.md, you'll see how to specify whether a certain modality should be used as input or predicted (set the Predict parameter to True/False).

Predict = False. Neuroformer constructs cross-attention layers for that modality and adds it as input.

Predict = True. Neuroformer constructs projection heads for that modality and optimizes it using the objective of your choice ('regression' or 'classification').
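To make the two cases concrete, here is a minimal sketch of how a per-modality predict flag could split modalities into inputs and decoding targets. It is not taken from the Neuroformer codebase: the dictionary layout, the modality names (`visual`, `speed`, `stimulus_id`) and the helper `split_modalities` are all assumptions for illustration. Only the predict flag and the two objective names come from the answer above.

```python
# Hypothetical config: keys and modality names are illustrative only,
# not Neuroformer's actual schema (see the readme.md for the real format).
modalities = {
    "visual": {"predict": False},
    "speed": {"predict": True, "objective": "regression"},
    "stimulus_id": {"predict": True, "objective": "classification"},
}


def split_modalities(modalities):
    """Split modalities into inputs (predict=False) and targets (predict=True).

    Returns a list of input names and a dict mapping target name -> objective.
    """
    inputs, targets = [], {}
    for name, cfg in modalities.items():
        if cfg.get("predict", False):
            objective = cfg.get("objective")
            if objective not in ("regression", "classification"):
                raise ValueError(f"{name}: unsupported objective {objective!r}")
            targets[name] = objective  # would get a projection head
        else:
            inputs.append(name)  # would get cross-attention layers
    return inputs, targets


inputs, targets = split_modalities(modalities)
print("inputs:", inputs)
print("targets:", targets)
```

In this sketch each modality lands in exactly one group, matching the description above: a modality is either fed in as input or decoded, depending on the flag.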

You can jointly pretrain + optimize as many additional modalities as you want. If you alternatively want to first pretrain, and then finetune, you can follow instructions in Finetuning.

Additionally, you can skip generative pretraining altogether and just optimize directly on the decoding objective for the modalities you want (by not setting --resume).

colehurwitz commented 9 months ago

Cool! Thanks for answering :) So if I set predict=True, the modality won't be used as an input during training but rather as something that is decoded during training?

a-antoniades commented 9 months ago

Yes! If you're using a custom dataset, just remember to put your data in the correct format as specified in the readme.md, and specify the correct temporal resolution of your bins (see the configs I included for an example) so the dataloader can index your data correctly.
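As a rough illustration of why the bin resolution matters, the sketch below maps a time window to sample indices using a bin width `dt`. This is an assumption about how time-based indexing typically works, not the actual Neuroformer dataloader; `window_slice` and its parameters are hypothetical names.

```python
def window_slice(values, dt, t_start, window):
    """Return the samples of `values` that fall in [t_start, t_start + window).

    `values` is assumed to hold one sample per bin of width `dt` seconds,
    starting at t = 0.
    """
    start = int(round(t_start / dt))
    stop = int(round((t_start + window) / dt))
    return values[start:stop]


# Hypothetical behavior trace: one sample per 0.05 s bin, value == bin index.
speed = list(range(100))

correct = window_slice(speed, dt=0.05, t_start=1.0, window=0.2)
wrong = window_slice(speed, dt=0.1, t_start=1.0, window=0.2)
print(correct)  # bins 20..23
print(wrong)    # bins 10..11: a mis-specified dt selects the wrong samples
```

With the wrong `dt`, the same time window silently returns different samples, which is the kind of misalignment the config value is meant to prevent.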

By default, if a holdout set is also included, the trainer will save the weights of the best model for each of the objectives (and all of them combined).
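The bookkeeping this implies might look roughly like the sketch below, which tracks the best holdout loss per objective plus a combined score. It is not the trainer's actual code: the class `BestTracker`, the objective names, and the choice to combine objectives by summing their losses are all assumptions.

```python
class BestTracker:
    """Track the lowest holdout loss seen for each objective and for their sum."""

    def __init__(self):
        self.best = {}

    def update(self, losses):
        """Record new holdout losses; return the keys whose best value improved.

        `losses` maps objective name -> holdout loss. A "total" entry
        (assumed here to be the plain sum) is tracked alongside them.
        """
        scored = dict(losses)
        scored["total"] = sum(losses.values())
        improved = []
        for key, value in scored.items():
            if value < self.best.get(key, float("inf")):
                self.best[key] = value
                improved.append(key)  # a real trainer would save weights here
        return improved


tracker = BestTracker()
first = tracker.update({"speed": 1.0, "stimulus_id": 2.0})
second = tracker.update({"speed": 0.5, "stimulus_id": 2.5})
print(first)
print(second)
```

In the second call only `speed` improves, so under these assumptions only the checkpoint for that objective would be overwritten while the others keep their earlier best weights.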