fschmid56 / EfficientAT

This repository aims at providing efficient CNNs for Audio Tagging. We provide AudioSet pre-trained models ready for downstream training and extraction of audio embeddings.
MIT License

How to use mn10_as as a pre-trained model, and fine-tune on a new dataset #14

Closed. jmren168 closed this issue 10 months ago.

jmren168 commented 1 year ago

Hi,

I'm trying to use mn10_as as a pre-trained model and fine-tune it on my own dataset (3 classes, 50 10-second clips per class, sampling rate: 16 kHz).

I prepared my dataset in a DCASE20-like format, where only filename and scene_label are used, and then modified ex_dcase20.py into ex_my_dataset.py. The snippets below show where I made changes, but errors occurred, and I hope you can give me some hints. Many thanks.
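
For reference, a tab-separated meta.csv in this DCASE-like layout could look like the following (the filenames and labels are just placeholders):

filename	scene_label
audio/classA_001.wav	classA
audio/classB_001.wav	classB
audio/classC_001.wav	classC

fold1_train.csv and fold1_evaluate.csv then list the filename column of the corresponding split.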

In ex_my_dataset.py,

def train(args):
    # Train Models for Acoustic Scene Classification

    # logging is done using wandb
    wandb.init(
        project="my_dataset",
        notes="Fine-tune Models for Acoustic Scene Classification.",
        tags=[ "Acoustic Scene Classification", "Fine-Tuning"],
        config=args,
        name=args.experiment_name
    )

    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')

    # model to preprocess waveform into mel spectrograms
    mel = AugmentMelSTFT(n_mels=args.n_mels,
                         sr=args.resample_rate,
                         win_length=args.window_size,
                         hopsize=args.hop_size,
                         n_fft=args.n_fft,
                         freqm=args.freqm,
                         timem=args.timem,
                         fmin=args.fmin,
                         fmax=args.fmax,
                         fmin_aug_range=args.fmin_aug_range,
                         fmax_aug_range=args.fmax_aug_range
                         )
    mel.to(device)

    # load prediction model
    pretrained_name = args.pretrained_name
    if pretrained_name:
        model = get_mobilenet(width_mult=NAME_TO_WIDTH(pretrained_name), pretrained_name=pretrained_name,
                              head_type=args.head_type, se_dims=args.se_dims, num_classes=3)
    else:
        model = get_mobilenet(width_mult=args.model_width, head_type=args.head_type, se_dims=args.se_dims,
                              num_classes=3)

parser.add_argument('--pretrained_name', type=str, default='mn10_as')
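
As a quick sanity check that the pretrained model really ends in a fresh 3-class head, something like the following can be run inside ex_my_dataset.py (it reuses the get_mobilenet and NAME_TO_WIDTH imports of the script; the dummy input shape and the handling of a possible (logits, features) return value are assumptions):

import torch

model = get_mobilenet(width_mult=NAME_TO_WIDTH('mn10_as'), pretrained_name='mn10_as',
                      num_classes=3)
model.eval()
dummy_mel = torch.zeros(1, 1, 128, 1000)  # (batch, channel, mel bins, time frames) - assumed shape
with torch.no_grad():
    out = model(dummy_mel)
logits = out[0] if isinstance(out, tuple) else out  # the models may return (logits, features)
print(logits.shape)  # expected: torch.Size([1, 3])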

In datasets/my_dataset.py

sr = 16000
resample_rate = sr

dataset_config = {
    "dataset_name": "my_dataset",
    "meta_csv": os.path.join(dataset_dir, "meta.csv"),
    "train_files_csv": os.path.join(dataset_dir, "evaluation_setup", "fold1_train.csv"),
    "test_files_csv": os.path.join(dataset_dir, "evaluation_setup", "fold1_evaluate.csv")
}

class BasicDCASE22Dataset(TorchDataset):

    def __init__(self, meta_csv, sr=sr, cache_path=None):
        """
        @param meta_csv: meta csv file for the dataset
        @param sr: specify sampling rate
        @param cache_path: specify cache path to store resampled waveforms
        return: waveform and label
        """
        df = pd.read_csv(meta_csv, sep="\t")
        le = preprocessing.LabelEncoder()
        self.labels = torch.from_numpy(le.fit_transform(df[['scene_label']].values.reshape(-1)))
        #self.devices = le.fit_transform(df[['source_label']].values.reshape(-1))
        #self.cities = le.fit_transform(df['identifier'].apply(lambda loc: loc.split("-")[0]).values.reshape(-1))
        self.files = df[['filename']].values.reshape(-1)
        self.sr = sr
        if cache_path is not None:
            self.cache_path = os.path.join(cache_path, dataset_config["dataset_name"] + f"_r{self.sr}", "files_cache")
            os.makedirs(self.cache_path, exist_ok=True)
        else:
            self.cache_path = None

    def __getitem__(self, index):
        if self.cache_path:
            cpath = os.path.join(self.cache_path, str(index) + ".pt")
            try:
                sig = torch.load(cpath)
            except FileNotFoundError:
                sig, _ = librosa.load(os.path.join(dataset_dir, self.files[index]), sr=self.sr, mono=True)
                sig = torch.from_numpy(sig[np.newaxis])
                torch.save(sig, cpath)
        else:
            sig, _ = librosa.load(os.path.join(dataset_dir, self.files[index]), sr=self.sr, mono=True)
            sig = torch.from_numpy(sig[np.newaxis])
        #return sig, self.labels[index], self.devices[index], self.cities[index]
        return sig, self.labels[index]

    def __len__(self):
        return len(self.files)

class SimpleSelectionDataset(TorchDataset):
    """A dataset that selects a subsample from a dataset based on a set of sample ids.
        Supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __init__(self, dataset, available_indices):
        """
        @param dataset: dataset to load data from
        @param available_indices: available indices of samples for 'training', 'testing'
        return: x, label, index
        """
        self.available_indices = available_indices
        self.dataset = dataset

    def __getitem__(self, index):
        #x, label, device, city = self.dataset[self.available_indices[index]]
        x, label = self.dataset[self.available_indices[index]]
        #return x, label, device, city, self.available_indices[index]
        return x, label, self.available_indices[index]

    def __len__(self):
        return len(self.available_indices)
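
For completeness, here is a hypothetical helper (not part of the repository; it only uses the dataset_config, sr and the two classes defined above) showing how the training subset can be built from fold1_train.csv:

import numpy as np
import pandas as pd

def get_training_subset(cache_path=None):
    # wrap the full dataset and keep only the files listed in fold1_train.csv
    base = BasicDCASE22Dataset(dataset_config["meta_csv"], sr=sr, cache_path=cache_path)
    train_files = pd.read_csv(dataset_config["train_files_csv"], sep="\t")["filename"].values.reshape(-1)
    train_indices = np.where(np.isin(base.files, train_files))[0]
    return SimpleSelectionDataset(base, train_indices)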

Error messages:

  File C:\Anaconda3\envs\EfficientAT\lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File d:\users\2023-efficientat-main\ex_my_dataset.py:241
    train(args)

  File d:\users\2023-efficientat-main\ex_my_dataset.py:100 in train
    for batch in pbar:

  File C:\Anaconda3\envs\EfficientAT\lib\site-packages\tqdm\std.py:1195 in __iter__
    for obj in iterable:

  File C:\Anaconda3\envs\EfficientAT\lib\site-packages\torch\utils\data\dataloader.py:435 in __iter__
    return self._get_iterator()

  File C:\Anaconda3\envs\EfficientAT\lib\site-packages\torch\utils\data\dataloader.py:381 in _get_iterator
    return _MultiProcessingDataLoaderIter(self)

  File C:\Anaconda3\envs\EfficientAT\lib\site-packages\torch\utils\data\dataloader.py:1034 in __init__
    w.start()

  File C:\Anaconda3\envs\EfficientAT\lib\multiprocessing\process.py:121 in start
    self._popen = self._Popen(self)

  File C:\Anaconda3\envs\EfficientAT\lib\multiprocessing\context.py:224 in _Popen
    return _default_context.get_context().Process._Popen(process_obj)

  File C:\Anaconda3\envs\EfficientAT\lib\multiprocessing\context.py:336 in _Popen
    return Popen(process_obj)

  File C:\Anaconda3\envs\EfficientAT\lib\multiprocessing\popen_spawn_win32.py:93 in __init__
    reduction.dump(process_obj, to_child)

  File C:\Anaconda3\envs\EfficientAT\lib\multiprocessing\reduction.py:60 in dump
    ForkingPickler(file, protocol).dump(obj)

AttributeError: Can't pickle local object 'get_roll_func.<locals>.roll_func'
jmren168 commented 1 year ago

Following Method 1 from a related Stack Overflow question (https://stackoverflow.com/questions/72766345/attributeerror-cant-pickle-local-object-in-multiprocessing), I modified audiodatasets.py as follows, and it works for me.

from torch.utils.data import Dataset
import torch
import numpy as np

# roll waveform (over time)
# defined at module level so that the DataLoader worker processes can pickle it
def roll_func(b):
    shift_range = 10000
    axis = 1
    x = torch.as_tensor(b[0])
    others = b[1:]
    sf = int(np.random.randint(-shift_range, shift_range + 1))
    return (x.roll(sf, axis), *others)

def get_roll_func(axis=1, shift=None, shift_range=10000):
    # the arguments are ignored in this workaround; roll_func uses its own fixed values
    return roll_func

class PreprocessDataset(Dataset):
    """A base preprocessing dataset representing a preprocessing step of a Dataset preprocessed on the fly.
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __init__(self, dataset, preprocessor):
        self.dataset = dataset
        if not callable(preprocessor):
            print("preprocessor: ", preprocessor)
            raise ValueError('preprocessor should be callable')
        self.preprocessor = preprocessor

    def __getitem__(self, index):
        return self.preprocessor(self.dataset[index])

    def __len__(self):
        return len(self.dataset)
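
For context: on Windows the DataLoader workers are started with the spawn method, which has to pickle the dataset (including the preprocessor), and a function defined inside get_roll_func cannot be pickled. Moving roll_func to module level, as above, avoids that. An alternative quick workaround, if you prefer not to touch audiodatasets.py, is to disable worker processes when building the loader (slower, but nothing needs to be pickled; the variable names are just placeholders for whatever is used in ex_my_dataset.py):

dl = DataLoader(dataset=train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
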
jmren168 commented 1 year ago

BTW, it might be better to add wandb.finish() at the end of the training/test part so that wandb uploads the results to the cloud server. Otherwise, you may need to run the same code again to trigger the upload.
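
A minimal sketch of what I mean (just the placement of the call, reusing the structure of train() above):

def train(args):
    wandb.init(project="my_dataset", config=args, name=args.experiment_name)
    ...  # training and evaluation loop
    wandb.finish()  # ensures the run is synced even if the interpreter keeps running (e.g. in Spyder)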

fschmid56 commented 1 year ago

Hi, I resolved the issue in audiodatasets.py similarly to what you suggested: I used partial(roll_func, axis=axis, shift=shift, shift_range=shift_range) so that the arguments can still be passed to roll_func.
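
In other words, roughly like this (a sketch of the idea; the body of roll_func and the default shift_range follow the snippet above):

from functools import partial

import numpy as np
import torch

# module-level function, so DataLoader worker processes can pickle it
def roll_func(b, axis=1, shift=None, shift_range=10000):
    x = torch.as_tensor(b[0])
    others = b[1:]
    sf = shift if shift is not None else int(np.random.randint(-shift_range, shift_range + 1))
    return (x.roll(sf, axis), *others)

def get_roll_func(axis=1, shift=None, shift_range=10000):
    # partial objects are picklable as long as roll_func itself is defined at module level
    return partial(roll_func, axis=axis, shift=shift, shift_range=shift_range)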

fschmid56 commented 1 year ago

Regarding wandb.finish(): the docs say: "This is used when creating multiple runs in the same process. We automatically call this method when your script exits."

Could you explain in which scenarios adding an additional wandb.finish() would be beneficial?