KrishnaDN / Attentive-Statistics-Pooling-for-Deep-Speaker-Embedding

Implementation of the paper "Attentive Statistics Pooling for Deep Speaker Embedding" in PyTorch
40 stars, 10 forks

Inference with 1 Sample error #3

Open: cademack opened this issue 2 years ago

cademack commented 2 years ago

Line 26 of Attention_Pooling.py contains: attention_weights = F.tanh(lin_out.bmm(v_view).squeeze())

Calling squeeze() without a dimension removes every dimension of size 1. When running inference on a single input, the batch dimension has size 1, so it is removed as well. I propose changing this line to:

attention_weights = F.tanh(lin_out.bmm(v_view).squeeze(2))

Simply specifying the dimension of this squeeze should fix the problem :).
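A short shape check illustrates the difference. This is a sketch, not the repository's code. It assumes lin_out has shape (batch, frames, hidden) and v_view has shape (batch, hidden, 1), which is what the bmm call implies. The sizes B, T, and H are hypothetical and chosen only for illustration.

```python
import torch

# Hypothetical sizes for illustration: one utterance (B = 1),
# T = 5 frames, hidden size H = 8.
B, T, H = 1, 5, 8

lin_out = torch.randn(B, T, H)                      # assumed shape (B, T, H)
v_view = torch.randn(H).expand(B, H).unsqueeze(2)   # assumed shape (B, H, 1)

scores = lin_out.bmm(v_view)  # (B, T, 1)

unspecified = scores.squeeze()  # removes ALL size-1 dims, so the batch dim is lost too
specified = scores.squeeze(2)   # removes only the trailing size-1 dim

print(unspecified.shape)  # torch.Size([5])
print(specified.shape)    # torch.Size([1, 5])
```

When B > 1, both calls return the same (B, T) tensor. squeeze(2) only behaves differently, and correctly, when the batch size is 1.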

XiaoshanHsj commented 3 months ago

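I also think the correct code should be: attention_weights = F.tanh(lin_out).bmm(v_view).squeeze()
That is, tanh should be applied to lin_out before the bmm with v_view, not to the result of the bmm.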
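In the paper, the frame-level attention score is e_t = v^T f(W h_t + b) + k. The nonlinearity f is applied to W h_t + b before the dot product with v, which agrees with the ordering in the second comment. The paper then normalizes the scores with a softmax over frames and pools the features into a weighted mean and a weighted standard deviation. The module below is a minimal sketch that combines both suggestions: tanh before the projection onto v, and squeeze(2) so that a batch of one still works. The names AttentiveStatsPool, in_dim, and attn_dim are hypothetical. This is not the repository's implementation, and using tanh for f is an assumption carried over from the original line.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentiveStatsPool(nn.Module):
    """Sketch of attentive statistics pooling (hypothetical names, not the repo's code)."""

    def __init__(self, in_dim: int, attn_dim: int = 128):
        super().__init__()
        self.linear = nn.Linear(in_dim, attn_dim)             # W h_t + b
        self.v = nn.Parameter(torch.randn(attn_dim) * 0.01)   # context vector v
        # The scalar bias k is omitted: softmax is invariant to adding a constant.

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: frame-level features, shape (B, T, D)
        lin_out = torch.tanh(self.linear(h))                  # f(W h_t + b), shape (B, T, A)
        v_view = self.v.expand(h.size(0), -1).unsqueeze(2)    # (B, A, 1)
        scores = lin_out.bmm(v_view).squeeze(2)               # (B, T); keeps the batch dim when B == 1
        alpha = F.softmax(scores, dim=1).unsqueeze(2)         # attention weights over frames, (B, T, 1)
        mean = (alpha * h).sum(dim=1)                         # weighted mean, (B, D)
        var = (alpha * h * h).sum(dim=1) - mean * mean        # weighted variance, (B, D)
        std = torch.sqrt(var.clamp(min=1e-9))                 # clamp guards against tiny negative values
        return torch.cat([mean, std], dim=1)                  # pooled embedding, (B, 2D)


pool = AttentiveStatsPool(in_dim=16, attn_dim=8)
single = pool(torch.randn(1, 20, 16))  # inference on a single sample
print(single.shape)  # torch.Size([1, 32])
```

Because the weights come from a softmax and sum to 1 over frames, a sequence whose frames are all identical pools to that frame as the mean, with a near-zero standard deviation.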