YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Running on multiple GPUs / Adding a new metric / Using AST as Feature Extractor #22

Closed · jvel07 closed this issue 2 years ago

jvel07 commented 2 years ago

Hi, nice work!

Was wondering whether there exists a parameter for specifying the number of GPUs to use for training?

YuanGongND commented 2 years ago

Hi there,

We use torch.nn.DataParallel, which uses all GPUs visible in your environment. We use SLURM, so it is easy to set the gres number for each job (see our run.sh).
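For reference, in an sbatch script the GPU count is requested with a gres directive; an illustrative example (not the exact contents of our run.sh):

#SBATCH --gres=gpu:4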

If you do not use SLURM, you can do the following:

import os
# Enumerate GPUs in PCI bus order so the IDs below match nvidia-smi.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only these GPUs to the process (and hence to PyTorch).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3,4"

Remember to put this at the very beginning of your script, before import torch.
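For context, a minimal sketch of the wrapping itself (traintest.py in this repo does an equivalent wrap; the nn.Linear below is just a stand-in module):

import torch
import torch.nn as nn

audio_model = nn.Linear(128, 2)  # stand-in for the real AST model
# DataParallel replicates the module on every visible GPU and splits each
# mini-batch across them, so CUDA_VISIBLE_DEVICES controls how many are used.
if not isinstance(audio_model, nn.DataParallel):
    audio_model = nn.DataParallel(audio_model)
if torch.cuda.is_available():
    audio_model = audio_model.cuda()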

-Yuan

jvel07 commented 2 years ago

Thanks. I added the snippet to run.py and ast_models.py (before importing torch). I actually have 4 x 2080 Ti devices available, but I am getting an "out of memory" error.

I am using my own dataset, which consists of only 225 wavs. I intend to use the pre-trained AST model for downstream tasks (I have already followed all the steps described in the README).

Here's the error trace:

Traceback (most recent call last):
  File "../../src/run.py", line 111, in <module>
    train(audio_model, train_loader, val_loader, args)
  File "/home/egasj/PycharmProjects/ast/src/traintest.py", line 128, in train
    audio_output = audio_model(audio_input)
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 157, in forward
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 174, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 44, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 36, in scatter
    res = scatter_map(inputs)
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 23, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 19, in scatter_map
    return Scatter.apply(target_gpus, None, dim, obj)
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/_functions.py", line 93, in forward
    outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
  File "/home/egasj/.conda/envs/ast/lib/python3.7/site-packages/torch/nn/parallel/comm.py", line 189, in scatter
    return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
RuntimeError: CUDA error: out of memory

YuanGongND commented 2 years ago

Thanks for the information. There are two things to check:

  1. Whether all four of your 2080 Ti GPUs are actually being used - you can check GPU usage with nvidia-smi.
  2. Whether 4 x 2080 Ti is sufficient for your task. The total number of samples in your dataset matters less (although I wonder whether 225 samples is too small a dataset). What really matters is: 1) the mini-batch size: GPU memory usage grows linearly with the batch size, so you can try a smaller batch size; note that with a smaller batch size you should also use a smaller learning rate, or use gradient accumulation (see the sketch after this list). 2) The input sequence length: GPU memory usage grows quadratically with the input sequence length. One simple way to shorten the sequence is to use a larger stride for patch splitting; you can set tstride and fstride up to 16, which uses roughly 1/4 of the memory of the tstride=fstride=10 setting, at the cost of a slight performance drop. For reference, we use 4 x 12GB GPUs; with an input length of 1024 and a stride of 10, we are able to use a batch size of 12.
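For illustration, a minimal gradient-accumulation sketch (not from this repo; the model, optimizer, loss, and data below are dummy stand-ins for what traintest.py sets up):

import torch
import torch.nn as nn

# Dummy stand-ins so the sketch runs on its own.
audio_model = nn.Linear(128, 2)
optimizer = torch.optim.Adam(audio_model.parameters(), lr=1e-5)
loss_fn = nn.BCEWithLogitsLoss()
train_loader = [(torch.randn(12, 128), torch.rand(12, 2)) for _ in range(8)]

accum_steps = 4  # effective batch size = 12 * 4 = 48
optimizer.zero_grad()
for i, (audio_input, labels) in enumerate(train_loader):
    output = audio_model(audio_input)
    loss = loss_fn(output, labels) / accum_steps  # scale so accumulated grads average out
    loss.backward()                               # gradients accumulate across iterations
    if (i + 1) % accum_steps == 0:
        optimizer.step()                          # update once per accum_steps mini-batches
        optimizer.zero_grad()

For the stride change, fstride and tstride are arguments of the ASTModel constructor in ast_models.py.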

-Yuan

jvel07 commented 2 years ago

Thanks, I will try that out.

I've got two questions, though:

  1. I considered using a larger dataset. I now have about 9k utterances for train, dev, and test, respectively, with each utterance between 1 and 4 seconds long. The dataset is for binary classification and is unbalanced, so I need to evaluate the model with UAR (Unweighted Average Recall), which is simply: sklearn.metrics.recall_score(y_true, y_pred, labels=[1, 0], average='macro') (if needed, I could also wrap this as a sklearn 'scorer' function). The question is: where can I add this new metric to your code so that it is used instead of, e.g., 'acc' or 'mAP'?

  2. Can AST be used as a feature extractor? That is, how can one use your repo to extract embeddings from utterances?

PS: since this conversation has drifted a bit from the original topic, I have updated the issue title for future users.

YuanGongND commented 2 years ago

Hi there,

For the first question: our metric code is here. For your dataset, I think you can set target_length to 400; you will probably also need to upsample your minority class or use a class-balancing loss. For binary classification, you also need to reconsider the mixup strategy. In general, I think you will need to search the hyperparameters for your task.
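For illustration, a UAR helper in the style of your snippet could look like the following (a sketch, not part of the repo; it assumes targets and predictions arrive as (N, n_classes) arrays):

import numpy as np
from sklearn.metrics import recall_score

def uar(y_true_onehot, y_score):
    """Unweighted Average Recall: macro-averaged recall over classes."""
    # Convert per-class scores and one-hot targets to label indices.
    y_pred = np.argmax(y_score, axis=1)
    y_true = np.argmax(y_true_onehot, axis=1)
    # Macro averaging weights each class equally, so UAR is insensitive
    # to class imbalance.
    return recall_score(y_true, y_pred, average='macro')

You could compute this wherever the existing 'acc'/'mAP' statistics are computed and report it alongside them.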

For the second question: yes, it can be used as a feature extractor, but we have always seen a performance improvement from the fine-tuning setting (i.e., not freezing any layers and training the entire network).

-Yuan

jvel07 commented 2 years ago

Thanks for your answer. Regarding the feature extractor: how would I get embeddings for my dataset using your implementation? My aim is to later use these embeddings to train, e.g., an SVM, which could presumably be a better fit for tasks with smaller datasets.

YuanGongND commented 2 years ago

To get the embedding, you can change the forward function of the model in ast_models.py here: basically, just return x at this line and skip line 181.
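For illustration, a standalone extraction helper could look like the following (a sketch: it mirrors the forward function described above up to, but not including, the mlp_head on line 181, and will need adjusting if the repo's forward changes; the constructor settings and import path are hypothetical):

import torch
from models.ast_models import ASTModel  # adjust the import path to your setup

# Hypothetical settings: 2 classes, ~4 s inputs at target_length=400.
model = ASTModel(label_dim=2, input_tdim=400)
model.eval()

@torch.no_grad()
def extract_embedding(model, x):
    # x: (batch, time_frames, freq_bins), the same input format as forward().
    x = x.unsqueeze(1).transpose(2, 3)
    B = x.shape[0]
    x = model.v.patch_embed(x)
    cls_tokens = model.v.cls_token.expand(B, -1, -1)
    dist_token = model.v.dist_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, dist_token, x), dim=1)
    x = x + model.v.pos_embed
    x = model.v.pos_drop(x)
    for blk in model.v.blocks:
        x = blk(x)
    x = model.v.norm(x)
    # Average the cls and distillation tokens; this is the "return x here,
    # skip line 181 (the classification head)" step described above.
    return (x[:, 0] + x[:, 1]) / 2

emb = extract_embedding(model, torch.randn(1, 400, 128))  # -> (1, embed_dim)

You could then save these vectors for your utterances and fit an SVM on them (e.g., sklearn.svm.SVC).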