slp-rl / aero

This repo contains the official PyTorch implementation of "Audio Super Resolution in the Spectral Domain" (ICASSP 2023)
MIT License

Suggestion: Allow bfloat16 use to improve speed/memory usage #25

Open pokepress opened 4 months ago

pokepress commented 4 months ago

Hi, as you may have noticed, I've been developing a fork of this project to train models for upscaling AM and FM radio recordings. While researching ways to improve speed and memory consumption, I started looking at 16-bit floating-point formats. Standard float16 doesn't appear to have enough range to be useful for this project (I ended up with NaN values relatively quickly), but bfloat16, which uses the same number of exponent bits as float32, seems to work quite well: it speeds up training a bit and significantly reduces memory usage. You can see an example of its implementation in a recent commit I made. Note that I did have to restructure the code so that the loss calculation is done using standard 32-bit float values (as recommended by PyTorch).
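To illustrate the range difference described above, here is a minimal standard-library sketch (not code from the fork or from this repo). The helper names `to_float16` and `to_bfloat16` are hypothetical, and the bfloat16 conversion simply truncates a float32 to its top 16 bits, whereas real hardware typically rounds to nearest, so treat the exact values as approximate.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to IEEE half precision (5 exponent bits, 10 mantissa bits)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the float16 range;
        # in a tensor library this would show up as inf instead.
        return float("inf") if x > 0 else float("-inf")


def to_bfloat16(x: float) -> float:
    """Approximate bfloat16 (8 exponent bits, 7 mantissa bits).

    Assumption: plain truncation of the float32 bit pattern, which is a
    simplification of the rounding done by real hardware.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


# Large, tiny, and near-one values show the trade-off:
# float16 loses range, bfloat16 loses precision.
for value in (1e5, 1e-8, 1.001):
    print(f"{value!r}: float16={to_float16(value)!r} bfloat16={to_bfloat16(value)!r}")
```

In this sketch, float16 overflows to infinity at 1e5 and flushes 1e-8 to zero, while bfloat16 keeps both finite and nonzero. The price is precision: 1.001 collapses to 1.0 in bfloat16, which is consistent with keeping the loss calculation in 32-bit floats.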

m-mandel commented 4 months ago

Thank you for your contribution!