mlcommons / training

Reference implementations of MLPerf™ training benchmarks
https://mlcommons.org/en/groups/training
Apache License 2.0

All of my prediction results are <SOS> #588

Closed moseshu closed 1 year ago

moseshu commented 1 year ago

I used your RNNT model to train on my custom dataset, but every inference result is sos, which is my start token. Do you know where the problem is?

johntran-nv commented 1 year ago

Hi @moseshu, sorry, but this is not enough information for us to understand your question. Could you provide more details? In general, I don't think the MLCommons engineers have the bandwidth to help debug custom dataset issues, but if it's a problem with the benchmark code, we could help.

moseshu commented 1 year ago

Hello! My code is below. I copied it from your repo, but I changed the data processing and I am using torchaudio.transforms.RNNTLoss.

    audio_dim = 128
    train_audio_transforms = nn.Sequential(
        torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=audio_dim),
        torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
        torchaudio.transforms.TimeMasking(time_mask_param=100)
    )
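For readers following along, a rough sketch of the arithmetic behind these transforms may help. It assumes the torchaudio defaults for `MelSpectrogram` (`n_fft=400`, so `hop_length=200`, with centered frames); the helper names `mel_frame_count` and `max_masked_fraction` are hypothetical and not part of any library. It only estimates how many frames a clip produces and an upper bound on how much of it a time mask of 100 frames could cover.

```python
def mel_frame_count(num_samples: int, hop_length: int = 200) -> int:
    """Frames from a centered STFT: 1 + num_samples // hop_length.

    hop_length=200 is an assumption (the torchaudio default when n_fft=400).
    """
    return 1 + num_samples // hop_length


def max_masked_fraction(num_samples: int,
                        time_mask_param: int = 100,
                        hop_length: int = 200) -> float:
    """Upper bound on the fraction of frames one time mask could hide."""
    frames = mel_frame_count(num_samples, hop_length)
    # The mask width never exceeds time_mask_param or the clip length.
    return min(time_mask_param, frames) / frames


# A 1-second clip at 16 kHz yields 81 frames, so a 100-frame mask
# could in the worst case cover almost the whole utterance.
one_second = max_masked_fraction(16000)
# A 10-second clip yields 801 frames, so the bound is about 12%.
ten_seconds = max_masked_fraction(160000)
```

If the custom dataset contains many short clips, this bound suggests checking whether the time mask leaves enough signal for training; that is a possibility to rule out, not a confirmed cause.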

I'm not sure what went wrong in my processing. Is the problem in the training process or in the greedy decoding? When I use RNNTGreedyDecoder.decode(), the decoded result is 0 (zero), which is my padding value.

    model = RNNT(
        n_classes=4000,
        in_feats=128,
        enc_n_hid=128,
        enc_pre_rnn_layers=2,
        enc_post_rnn_layers=3,
        enc_stack_time_factor=1,
        enc_dropout=0.1,
        pred_dropout=0.2,
        joint_dropout=0.1,
        pred_n_hid=128,
        pred_rnn_layers=2,
        joint_n_hid=256
    )

My data processing is:

```python
import torch
import torch.nn as nn
import torchaudio

# `tokenizer`, `train_audio_transforms` and `valid_audio_transforms`
# are defined elsewhere in my script.


def data_collate(data, data_type="train"):
    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []
    txts = []
    for audio_file, text in data:
        waveform, sample_rate = torchaudio.load(audio_file, normalize=True)
        # (channel, n_mels, time) -> (time, n_mels); assumes mono audio
        if data_type == "train":
            spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
        else:
            spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
        # Prepend the start token "sos" before tokenizing
        text1 = ["sos " + text]
        label = torch.tensor(tokenizer.texts_to_sequences(text1)[0], dtype=torch.int32)
        spectrograms.append(spec)
        labels.append(label)
        input_lengths.append(spec.shape[0])
        label_lengths.append(len(label))
        txts.append(text)
    # pad_sequence pads with 0, so 0 is the padding value for both tensors
    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    input_lengths = torch.IntTensor(input_lengths)
    label_lengths = torch.IntTensor(label_lengths)
    return spectrograms, labels, input_lengths, label_lengths, txts
```
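Since the decoder returns the padding value, one thing that may be worth ruling out is a clash between the padding ID, the blank ID, and the real token IDs in the collated batch. The following is only a sketch under assumptions: `check_label_batch` is a hypothetical helper, the batch is passed as plain lists (for example via `labels.tolist()`), and `blank_id` must be whatever non-negative index the loss and the decoder were actually configured with (if `torchaudio.transforms.RNNTLoss` is left at its default `blank=-1`, that corresponds to `n_classes - 1`).

```python
def check_label_batch(labels, label_lengths, n_classes, blank_id, pad_id=0):
    """Return a list of human-readable problems found in a padded label batch.

    labels: list of equal-length rows of int token ids (padded).
    label_lengths: number of real (unpadded) tokens in each row.
    """
    problems = []
    for i, (row, length) in enumerate(zip(labels, label_lengths)):
        real, padding = row[:length], row[length:]
        # Everything past the declared length should be padding.
        if any(t != pad_id for t in padding):
            problems.append(f"row {i}: non-pad value after position {length}")
        # Real tokens must be valid class indices.
        if any(t < 0 or t >= n_classes for t in real):
            problems.append(f"row {i}: token id outside [0, {n_classes})")
        # The blank symbol should never appear as a training target.
        if blank_id in real:
            problems.append(f"row {i}: blank id {blank_id} appears as a target")
    return problems


clean = check_label_batch([[1, 5, 7, 0], [1, 3, 0, 0]], [3, 2],
                          n_classes=10, blank_id=9)
clash = check_label_batch([[1, 0, 7, 0]], [3], n_classes=10, blank_id=0)
```

An empty list means none of these particular inconsistencies were found; it does not show that the training or decoding setup is otherwise correct.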

matthew-frank commented 1 year ago

This doesn't seem to be code from the RNNT reference. Closing. If you have a reproducible bug with the RNNT reference code, please open a new bug with a clear set of steps for reproducing the problem.