nayanjha16 opened this issue 1 year ago
I had the same problem when using AISHELL3: the size of x changes as well, and I ran into a loss mismatch when the size of x was [batch, 1]. I think something is wrong with the input. I hope someone can find out why.
In model/modules.py, at the line `x = x + pitch_embedding` (line 131), the dimension of pitch_embedding differs from the dimension of x.
This can be verified from the following logs:

```
printing the size of x torch.Size([16, 25, 256])
printing the type of pitch_embedding <class 'torch.Tensor'>
printing the size of pitch_embedding 16
printing the size of pitch_embedding torch.Size([16, 64, 256])
```
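A minimal reproduction of the failing addition, using zero tensors with the shapes from the logs (the values are placeholders). PyTorch can only add two tensors elementwise if each dimension either matches or is 1, so 25 vs 64 in dimension 1 raises a RuntimeError:

```python
import torch

# Shapes taken from the logs: x is [batch, seq_len, hidden], while the pitch
# embedding was built from a target with 64 positions instead of 25.
x = torch.zeros(16, 25, 256)
pitch_embedding = torch.zeros(16, 64, 256)

try:
    x = x + pitch_embedding
    error_message = None
except RuntimeError as err:
    # Broadcasting fails: dim 1 is 25 vs 64 and neither of them is 1.
    error_message = str(err)

print(error_message)
```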
Investigating further in get_pitch_embedding(), where the pitch embedding is computed, I found that the prediction generated by `prediction = self.pitch_predictor(x, mask)` matches the batch and sequence dimensions of x. Only at `embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))` does the dimension of embedding (which later becomes pitch_embedding) diverge from the dimension of x.
The logged dimensions are:

```
from within get_pitch_embedding fn ()
the size of x is torch.Size([16, 25, 256])
the shape of prediction is torch.Size([16, 25])
printing the shape of target torch.Size([16, 64])
printing the shape after bucketization
```
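These shapes suggest the mismatch is already present in `target` and is not introduced by bucketization. `torch.bucketize` returns one bin index per input element, so its output has the same shape as `target`, and `nn.Embedding` appends the hidden dimension to whatever index tensor it receives. The sketch below uses placeholder bin edges and sizes (not the repository's configured values) to show that a [16, 64] target necessarily yields a [16, 64, 256] embedding:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder sizes and bin edges, for illustration only.
n_bins, hidden = 256, 256
pitch_bins = torch.linspace(-2.0, 2.0, n_bins - 1)  # n_bins - 1 edges -> n_bins buckets
pitch_embedding = nn.Embedding(n_bins, hidden)

target = torch.randn(16, 64)                   # same shape as the logged target
indices = torch.bucketize(target, pitch_bins)  # elementwise: shape stays [16, 64]
embedding = pitch_embedding(indices)           # adds hidden dim: [16, 64, 256]

print(indices.shape, embedding.shape)
```

So the embedding can only line up with x if the target holds one pitch value per position of x, i.e. shape [16, 25] in this batch.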
The code for get_pitch_embedding() from model/modules.py (with my debug prints) is below:
```python
import torch

def get_pitch_embedding(self, x, target, mask, control):
    print("from within get_pitch_embedding fn () ")
    print("the size of x is {0} ".format(x.size()))
    prediction = self.pitch_predictor(x, mask)
    print("the shape of prediction is {0}".format(prediction.size()))
    if target is not None:
        # A ground-truth pitch target is provided: embed the target values
        print("printing the shape of target {0}".format(target.size()))
        print("printing the shape after bucketization")
        print(torch.bucketize(target, self.pitch_bins).shape)
        embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
        print('------if--------')
    else:
        # No target: embed the predicted pitch, scaled by the control factor
        print('-----else-------')
        prediction = prediction * control
        embedding = self.pitch_embedding(
            torch.bucketize(prediction, self.pitch_bins)
        )
    print(prediction.shape, embedding.shape)
    return prediction, embedding
```
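To fail earlier with a clearer message, a guard like the following could be called at the start of the `if target is not None:` branch. `check_pitch_target` is a hypothetical helper, not part of the repository; it assumes x is [batch, seq_len, hidden] and that the pitch target must be [batch, seq_len], which is what the prediction shape in the logs implies.

```python
import torch

def check_pitch_target(x: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical guard: raise if the pitch target does not line up with x.

    x is assumed to be [batch, seq_len, hidden]; the target is expected to be
    [batch, seq_len] so that its embedding can be added to x.
    """
    if target.shape != x.shape[:2]:
        raise ValueError(
            f"pitch target shape {tuple(target.shape)} does not match "
            f"x batch/sequence shape {tuple(x.shape[:2])}"
        )
```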
Note: The language being used is Hindi. I have updated the phone set in text/symbols.py accordingly.
Note: I suspect there is a problem with how the target is generated. Can someone please point out where this value is generated?
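One way to narrow down where the target goes wrong is to compare per-sample lengths before batching. Assuming both tensors are padded to the longest item in the batch, a [16, 64] target next to a [16, 25, 256] x means at least one sample has 64 pitch values while no sample has more than 25 phonemes. The sketch below is hypothetical: it assumes the pitch target is meant to be phoneme-level (one value per phoneme) and that samples can be iterated as (name, phoneme_ids, pitch) tuples, which is not necessarily how the dataset stores them.

```python
def find_length_mismatches(samples):
    """Hypothetical check: list samples whose pitch length differs from the
    number of phonemes.

    `samples` is an iterable of (name, phoneme_ids, pitch) tuples (an assumed
    layout). Returns (name, n_phonemes, n_pitch) for every mismatching sample.
    """
    mismatches = []
    for name, phoneme_ids, pitch in samples:
        if len(pitch) != len(phoneme_ids):
            mismatches.append((name, len(phoneme_ids), len(pitch)))
    return mismatches
```

Any sample this reports would be a candidate for a mismatch between the phoneme sequence produced by the updated symbol set and the pitch values computed during preprocessing.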