wenet-e2e / wenet

Production First and Production Ready End-to-End Speech Recognition Toolkit
https://wenet-e2e.github.io/wenet/
Apache License 2.0

After updating torch to 2.3.0+cu121, torchaudio fails in func tar_file_and_group of wenet/dataset/datapipes.py #2531

Closed housebaby closed 2 months ago

housebaby commented 4 months ago

[two screenshots attached]

Dataset loading fails before training starts. However, when I put the loading code in a standalone script such as test.py, the tar file is parsed and loaded successfully:

```python
#!/bin/python
import io
import logging
import tarfile

import torchaudio

AUDIO_FORMAT_SETS = set(['WAV', 'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])

# Tar list; each line is a shard path like
# /data/private/data/shard/train_v17/train_v17/tr/shards_000002632.tar
ifile = "data//train_v17/tr/data.list"
f = open(ifile, 'r')
line = f.readline()
j = 0  # number of audio files loaded across all shards
while line:
    # stream = tarfile.open("/data/private/data/shard/train_v17/train_v17/tr/shards_000002632.tar", mode="r|*")
    stream = tarfile.open(line.strip(), mode="r|*")
    i = 0  # number of audio files loaded from this shard
    prev_prefix = None
    example = {}  # fields of the sample currently being assembled
    valid = True
    for tarinfo in stream:
        if i > 0:
            break  # only check the first audio file of each shard
        name = tarinfo.name
        pos = name.rfind('.')
        print(i, tarinfo)
        assert pos > 0
        prefix, postfix = name[:pos], name[pos + 1:]
        if prev_prefix is not None and prefix != prev_prefix:
            # A new prefix means the previous sample is complete.
            example['key'] = prev_prefix
            if valid:
                print(example['key'])
                # yield example
            example = {}
            valid = True
        with stream.extractfile(tarinfo) as file_obj:
            print(tarinfo, file_obj, postfix)
            try:
                if postfix == 'txt':
                    example['txt'] = file_obj.read().decode('utf8').strip()
                    print(example['txt'])
                elif postfix in AUDIO_FORMAT_SETS:
                    if 0:
                        # Disabled alternative: dump to disk, then load by path.
                        with open("tmp.wav", 'wb') as tmp:
                            tmp.write(file_obj.read())
                        waveform, sample_rate = torchaudio.load("tmp.wav")
                    else:
                        # Load from an in-memory copy of the tar member.
                        waveform, sample_rate = torchaudio.load(io.BytesIO(file_obj.read()))
                        i = i + 1
                        j = j + 1
                        print(sample_rate, waveform.shape)
                    example['wav'] = waveform
                    example['sample_rate'] = sample_rate
                else:
                    example[postfix] = file_obj.read()
            except Exception as ex:
                valid = False
                logging.warning('error to parse {}: {}'.format(name, ex))
                print('fail to parse {}'.format(name))
        prev_prefix = prefix
    if prev_prefix is not None:
        example['key'] = prev_prefix
        # yield example
        print(example['key'])
    line = f.readline()
    stream.close()
f.close()

print(j)
```

It worked, and torchaudio loaded the audio successfully. [screenshot attached]

srdfjy commented 4 months ago

Does the error occur when reading shards over HTTP? Also, I don't see the exception details in your post.
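One way to get the exception details into the log is to record the full traceback in the `except` branch instead of only the file name. The sketch below is not wenet's code: `make_demo_tar`, `fake_decode`, and `parse` are hypothetical names, and the decoder is a stand-in that always raises so the example runs without torchaudio or real shards.

```python
import io
import logging
import tarfile


def make_demo_tar() -> bytes:
    """Build a tiny in-memory tar with one text member and one fake audio member."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, payload in [("utt1.txt", b"hello"), ("utt1.wav", b"not real audio")]:
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    return buf.getvalue()


def fake_decode(data: bytes):
    """Hypothetical stand-in for the audio decoder; it always fails."""
    raise RuntimeError("demo decode failure")


def parse(tar_bytes: bytes, logger: logging.Logger) -> dict:
    """Read the tar as a stream and log a full traceback for any member that fails."""
    example = {}
    with tarfile.open(fileobj=io.BytesIO(tar_bytes), mode="r|*") as stream:
        for tarinfo in stream:
            name = tarinfo.name
            postfix = name[name.rfind(".") + 1:]
            with stream.extractfile(tarinfo) as file_obj:
                try:
                    if postfix == "txt":
                        example["txt"] = file_obj.read().decode("utf8").strip()
                    else:
                        example["wav"] = fake_decode(file_obj.read())
                except Exception:
                    # logger.exception logs at ERROR level and appends the traceback.
                    logger.exception("error to parse %s", name)
    return example


# Capture log output in memory so it can be inspected or pasted into a report.
log_buffer = io.StringIO()
logger = logging.getLogger("tar_demo")
logger.addHandler(logging.StreamHandler(log_buffer))
logger.propagate = False

example = parse(make_demo_tar(), logger)
print(log_buffer.getvalue())
```

With this pattern the log carries the exception type, message, and stack for the failing member, which is the information a plain `'error to parse {}'.format(name)` message leaves out.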