Closed philgzl closed 1 month ago
Sorry, I don't remember the memory usage. But you can use `--trainer.accumulate_grad_batches`
to reduce the memory usage without reducing the effective batch size.
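For Lightning CLI users, the suggestion above would look something like the following (the script name `train.py` and the `--data.batch_size` key are placeholders; the exact CLI layout depends on this repo's config):

```shell
# Keep the effective batch size at 4 while only holding 1 utterance
# in GPU memory at a time, by accumulating gradients over 4 steps.
python train.py fit --data.batch_size=1 --trainer.accumulate_grad_batches=4
```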
I am trying to port the model to a different codebase that does not run Lightning so I cannot use that option.
Could you maybe share which GPU model you used to run the experiments, and more specifically how much memory it had? Was the batch size of 4 chosen because you were hitting the GPU memory limit, or could you have increased it?
Also, can you confirm that you are not using half-precision?
> I am trying to port the model to a different codebase that does not run Lightning so I cannot use that option.
You can set `--trainer.accumulate_grad_batches` directly in Lightning.
> Can you just maybe share what GPU model you used to run the experiments and more specifically how big it was?
Information can be found in our papers.
> Was a batch size of 4 chosen because you were hitting the GPU memory limit or could you increase it?
It was not. A batch size of 4 was chosen to balance training speed against convergence speed in terms of epochs.
> Also can you confirm you are not using half-precision?
Both half-precision and full-precision are used in our papers: for comparisons with other methods we generally use full-precision, while for ablation experiments we use half-precision. Full-precision generally produces better results.
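For reference, half-precision training in plain PyTorch is usually done with autocast plus a gradient scaler. A minimal, device-agnostic sketch (model, data, and hyperparameters are illustrative, not taken from this repo):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
# GradScaler is only needed for float16 on CUDA; when disabled it is a no-op,
# so the same code runs unchanged on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(4, 10, device=device)
y = torch.randn(4, 1, device=device)
w0 = model.weight.detach().clone()  # snapshot to check the update happened

optimizer.zero_grad()
# float16 is the usual autocast dtype on CUDA, bfloat16 on CPU.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16
with torch.autocast(device_type=device, dtype=amp_dtype):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

The forward pass and loss run in reduced precision, while the master weights and the optimizer step stay in float32, which is the standard mixed-precision recipe.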
> You can directly set `--trainer.accumulate_grad_batches` in Lightning.
As I tried to explain, I cannot use Lightning, so this is unfortunately not an option for me.
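For anyone else porting the model away from Lightning, gradient accumulation is easy to reproduce by hand. A minimal sketch in plain PyTorch (the model, optimizer, and dummy batches are placeholders, not this repo's code):

```python
import torch

# Equivalent in spirit to --trainer.accumulate_grad_batches=4:
# accumulate gradients over 4 micro-batches before each optimizer step.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulate = 4  # number of micro-batches per optimizer step

# Dummy micro-batches standing in for a real data loader.
batches = [(torch.randn(1, 10), torch.randn(1, 1)) for _ in range(8)]
w0 = model.weight.detach().clone()  # snapshot to check the update happened

optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Dividing by `accumulate` makes the summed gradient match the mean
    # gradient of one large batch of accumulate * micro_batch_size samples.
    (loss / accumulate).backward()
    if (i + 1) % accumulate == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Only one micro-batch of activations is held in memory at a time, which is why this trades memory for extra forward/backward passes.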
> Information can be found in our papers.
This information is not in the SpatialNet and online SpatialNet papers.
For future readers, I managed to train the model at 16 kHz with the following settings:
The reported GPU memory usage is ~30 GB. One epoch takes ~20 min on a dataset of 1500 batches (~6 h 40 min of audio).
Hi and thanks again for this cool project.
Could you provide some insight into the GPU memory requirements for training the different configurations of the online SpatialNet (MHSA vs. Retention vs. Mamba, and 4-s vs. 32-s utterances)? I am currently getting GPU out-of-memory errors on an A100 40 GB GPU when using Retention, 4-s utterances and a batch size of 4 utterances. My sampling rate is 16 kHz as opposed to 8 kHz in your paper, but I doubled the STFT window and hop lengths, so the dimensionality along the time axis should be the same.
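The frame-count argument above can be checked with simple arithmetic: doubling both the STFT window and hop when going from 8 kHz to 16 kHz leaves the number of frames unchanged. A quick sanity check (the 256/128-sample window/hop at 8 kHz is assumed purely for illustration, not taken from the paper):

```python
def n_stft_frames(n_samples: int, win: int, hop: int) -> int:
    # Number of full analysis frames, without center padding.
    return 1 + (n_samples - win) // hop

seconds = 4
# 8 kHz configuration (illustrative window/hop values).
frames_8k = n_stft_frames(seconds * 8000, win=256, hop=128)
# 16 kHz with both window and hop doubled.
frames_16k = n_stft_frames(seconds * 16000, win=512, hop=256)
print(frames_8k, frames_16k)  # prints 249 249
```

The time dimension is therefore identical in both configurations; only the frequency dimension (window length / 2 + 1 bins) doubles, which is where the extra memory goes.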