TideDancer / interspeech21_emotion

99 stars 20 forks

how to incorporate not only waveform as input #13

Closed gitgaviny closed 2 years ago

gitgaviny commented 2 years ago

Dear @TideDancer ,

Thank you very much for providing the code for the Interspeech paper!

I'm a beginner with the Hugging Face training API, and I find it difficult to control the model's input. For example, if I want an additional feature-extractor branch that takes MFCCs as input rather than the raw waveform (or uses manual transcriptions or video input as a multimodal system), is it possible to train it together with this wav2vec2 model? Since the lengths of the labels are consistent, we can simply use [:-1] or [:-2] to build a multi-task learning model, but this is not applicable to variable-length input features. Hope to get your reply.

gitgaviny commented 2 years ago

Or could you please tell me how to save the features of the hidden states (as torchvision.models._utils.IntermediateLayerGetter does) so that I can use them as the input of another custom layer? Directly using torch.save can only process batch by batch; is there an API that can help me with that? I tried loading the pre-trained model (pytorch_model.bin) into my own model and training without transformers.Trainer so that I could use custom inputs, but the model tended to classify all inputs (IEMOCAP, raw waveform) into one class. So I'm interested in how to combine this model with my own dataset class.

TideDancer commented 2 years ago

Hello @gavinyuan1 , thanks for your interest.

  1. Using other features like MFCC: I think the best way is to look at the Hugging Face datasets package: https://huggingface.co/docs/datasets/audio_process and https://huggingface.co/docs/datasets/process. The input data extractor and batch processor are defined there. You can build a customized one to process the data and feed it into the model.
  2. Obtaining an intermediate layer's output / hidden features: As far as I know, the easiest way is to use PyTorch hooks: https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html. Basically, you can define a forward hook to extract any variable's value at run time. There are also tutorials such as https://medium.com/analytics-vidhya/pytorch-hooks-5909c7636fb. Hope this addresses your questions.
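On point 1, a minimal sketch of a custom batch processor (collator) that handles variable-length features like MFCCs, which is the part the question identifies as tricky. This is not the repo's code; the function and key names (`collate_features`, `"feats"`, `"input_features"`) are hypothetical, and only plain PyTorch is assumed:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_features(batch):
    """Hypothetical collator: pads variable-length feature sequences into one batch.

    Each item is a dict with 'feats' (a T_i x D tensor, e.g. MFCC frames)
    and 'label' (an int class index).
    """
    feats = [item["feats"] for item in batch]
    # Keep the true lengths so the model can mask out the padding later.
    lengths = torch.tensor([f.shape[0] for f in feats])
    # Zero-pad every sequence to the longest one in the batch: (B, T_max, D).
    padded = pad_sequence(feats, batch_first=True)
    labels = torch.tensor([item["label"] for item in batch])
    return {"input_features": padded, "lengths": lengths, "labels": labels}

# Example: two MFCC-like sequences of different lengths (13 coefficients each).
batch = [
    {"feats": torch.randn(50, 13), "label": 0},
    {"feats": torch.randn(80, 13), "label": 2},
]
out = collate_features(batch)
print(out["input_features"].shape)  # torch.Size([2, 80, 13])
print(out["lengths"].tolist())      # [50, 80]
```

A collator like this can be passed as `collate_fn` to a `torch.utils.data.DataLoader`, or adapted into a `data_collator` for the Hugging Face `Trainer`.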

gitgaviny commented 2 years ago

Thank you for your reply! I will close this issue.