csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

Upgrading examples to lightning 2.0 #60

Open · mikesol opened this issue 1 year ago

mikesol commented 1 year ago

Hi!

I tried to get the compressor example up and running and, in the process, migrated it as best I could to Lightning 2.0. There were several breaking changes, and since I've never used Lightning before and am not yet familiar with auraloss, I'm not entirely sure the migration is correct. But it's training and the loss is going down, so that's a good sign!

The branch is here: https://github.com/mikesol/auraloss/tree/compressor-test.

The command I used locally is:

python examples/compressor/train_comp.py \
   fit \
   --data.root_dir SignalTrain_LA2A_Dataset_1.1 \
   --trainer.max_epochs 20 \
   --model.kernel_size 15 \
   --model.channel_width 32 \
   --model.dilation_growth 2 \
   --data.preload False \
   --data.num_workers 8 \
   --data.shuffle True \
   --data.batch_size 32 \
   --model.nparams 2 \
   --data.length 32768
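For anyone new to this style of command: `train_comp.py` appears to use LightningCLI (the `fit` subcommand and the dotted flags suggest so). In that convention, `--model.*` options become keyword arguments of the LightningModule's `__init__`, `--data.*` options go to the LightningDataModule, and `--trainer.*` options go to `Trainer`. The stdlib-only sketch below mimics that grouping so it's easy to see which constructor receives which value. `group_cli_args` and `parse_value` are hypothetical helpers, not part of Lightning or jsonargparse; the real CLI also type-checks each value against the constructor signatures instead of the naive literal parsing used here.

```python
import ast
from collections import defaultdict


def parse_value(text: str):
    """Naively turn a CLI string into a Python literal (int, bool, ...)."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text  # leave paths and other bare words as plain strings


def group_cli_args(argv: list[str]) -> tuple[str, dict[str, dict]]:
    """Hypothetical helper: split `--group.key value` pairs by group."""
    subcommand, rest = argv[0], argv[1:]
    groups: dict[str, dict] = defaultdict(dict)
    # Flags and values alternate: --flag value --flag value ...
    for flag, value in zip(rest[0::2], rest[1::2]):
        group, key = flag.removeprefix("--").split(".", 1)
        groups[group][key] = parse_value(value)
    return subcommand, dict(groups)


argv = """fit --data.root_dir SignalTrain_LA2A_Dataset_1.1
--trainer.max_epochs 20 --model.kernel_size 15 --model.channel_width 32
--model.dilation_growth 2 --data.preload False --data.num_workers 8
--data.shuffle True --data.batch_size 32 --model.nparams 2
--data.length 32768""".split()

subcommand, groups = group_cli_args(argv)
print(subcommand)          # fit
print(groups["model"])     # would be kwargs for the LightningModule
print(groups["data"])      # would be kwargs for the LightningDataModule
print(groups["trainer"])   # would be kwargs for Trainer
```

Note that `--data.batch_size 32` is a per-process value: under DDP each GPU loads its own batches of 32.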

And the log so far shows:

(.venv) 21:42 meeshkan-abel@Abel:~/mike/auraloss$ python examples/compressor/train_comp.py    fit    --data.root_dir SignalTrain_LA2A_Dataset_1.1    --trainer.max_epochs 20    --model.kernel_size 15    --model.channel_width 32    --model.dilation_growth 2    --data.preload False    --data.num_workers 8    --data.shuffle True    --data.batch_size 32    --model.nparams 2    --data.length 32768
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/fabric/utilities/seed.py:39: UserWarning: No seed found, seed set to 3866398735
  rank_zero_warn(f"No seed found, seed set to {seed}")
Global seed set to 3866398735
Located 94285 examples totaling 19.5 hr in the train subset.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
[rank: 0] Global seed set to 3866398735
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Global seed set to 3866398735
Located 94285 examples totaling 19.5 hr in the train subset.
[rank: 1] Global seed set to 3866398735
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

Located 94285 examples totaling 19.5 hr in the train subset.
Located 94285 examples totaling 19.5 hr in the train subset.
Located 94285 examples totaling 19.5 hr in the train subset.
Located 94285 examples totaling 19.5 hr in the train subset.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name    | Type                     | Params
------------------------------------------------------
0  | l1      | L1Loss                   | 0     
1  | esr     | ESRLoss                  | 0     
2  | dc      | DCLoss                   | 0     
3  | logcosh | LogCoshLoss              | 0     
4  | sisdr   | SISDRLoss                | 0     
5  | stft    | STFTLoss                 | 0     
6  | mrstft  | MultiResolutionSTFTLoss  | 0     
7  | rrstft  | RandomResolutionSTFTLoss | 0     
8  | gen     | Sequential               | 10.5 K
9  | blocks  | ModuleList               | 221 K 
10 | output  | Conv1d                   | 33    
------------------------------------------------------
232 K     Trainable params
0         Non-trainable params
232 K     Total params
0.930     Total estimated model params size (MB)
Sanity Checking DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.30it/s]
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/L1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/ESR', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/DC', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/LogCosh', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/SI-SDR', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/STFT', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
/home/meeshkan-abel/mike/auraloss/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_loss/MRSTFT', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  warning_cache.warn(
Epoch 0:  24%|█████████████████████▎                                                                  | 357/1474 [09:08<28:34,  1.54s/it, v_num=5, train_loss_step=0.691]
Epoch 0:  31%|███████████████████████████                                                             | 453/1474 [11:21<25:37,  1.51s/it, v_num=5, train_loss_step=0.858]
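One quick sanity check on the migration is that the numbers in the log agree with each other. The sketch below reproduces the 1474 steps per epoch from the 94285 training examples, assuming the progress bar counts training batches only and that Lightning's default `DistributedSampler` is used under DDP (each of the 2 ranks gets `ceil(94285 / 2)` examples, padded rather than dropped) with the last partial batch kept. It also checks that 94285 windows of `--data.length 32768` samples add up to the reported 19.5 hr, assuming 44.1 kHz audio (the sample rate is an assumption, not something in the log), and extrapolates run time from the ~1.51 s/it in the bar.

```python
import math

num_examples = 94_285   # "Located 94285 examples ... in the train subset"
world_size = 2          # "Starting with 2 processes"
batch_size = 32         # --data.batch_size (per process under DDP)
window = 32_768         # --data.length, samples per example
sample_rate = 44_100    # ASSUMPTION: audio at 44.1 kHz
sec_per_it = 1.51       # from the progress bar at step 453
max_epochs = 20         # --trainer.max_epochs

# DistributedSampler (drop_last=False) pads so every rank gets the same count.
samples_per_rank = math.ceil(num_examples / world_size)
# The DataLoader keeps the final partial batch by default (drop_last=False).
steps_per_epoch = math.ceil(samples_per_rank / batch_size)

hours_of_audio = num_examples * window / sample_rate / 3600
epoch_minutes = steps_per_epoch * sec_per_it / 60
total_hours = epoch_minutes * max_epochs / 60  # ignores validation time

print(samples_per_rank, steps_per_epoch)              # 47143 1474
print(f"{hours_of_audio:.1f} hr of training audio")   # 19.5 hr of training audio
print(f"~{epoch_minutes:.0f} min/epoch, ~{total_hours:.1f} hr for {max_epochs} epochs")
```

If these assumptions hold, a full 20-epoch run at this speed would take on the order of 12 hours before counting validation.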

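The `PossibleUserWarning`s about `sync_dist=True` are probably worth addressing in the migrated script. With two DDP processes, each rank computes `val_loss` on its own shard of the validation set; without syncing, each rank keeps its own epoch-level value, and since loggers only write from rank 0, what ends up in the logs is rank 0's view rather than a value for the whole validation set. The likely fix is passing `sync_dist=True` to each epoch-level `self.log("val_loss/...", ...)` call. The stdlib-only sketch below simulates what that flag changes; the per-rank losses are made up for illustration, and `reduce_mean` is a hypothetical stand-in for the cross-rank mean reduction Lightning performs, not Lightning's actual code.

```python
from statistics import mean

# Made-up per-batch validation losses on each DDP rank (illustration only).
val_losses_by_rank = {
    0: [0.70, 0.66, 0.72],
    1: [0.58, 0.61, 0.55],
}


def epoch_value(losses: list[float]) -> float:
    """Epoch-level value a single rank computes from its own batches."""
    return mean(losses)


def reduce_mean(values: list[float]) -> float:
    """Hypothetical stand-in for the all-reduce that sync_dist=True triggers."""
    return mean(values)


per_rank = {rank: epoch_value(losses) for rank, losses in val_losses_by_rank.items()}

# Without sync_dist: each rank keeps its own number; the logger sees rank 0's.
logged_without_sync = per_rank[0]
# With sync_dist=True: every rank ends up with the same cross-rank mean.
logged_with_sync = reduce_mean(list(per_rank.values()))

print(f"rank 0 only: {logged_without_sync:.3f}")
print(f"synced:      {logged_with_sync:.3f}")
```

The gap between the two numbers matters most for anything that monitors `val_loss` (e.g. checkpointing on the best validation score), since unsynced ranks can disagree about which epoch was best.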
If you're interested in updating to Lightning 2.0, I'd be happy to help out. The branch definitely isn't in good enough shape for a PR yet, but if you take a look at the diff you'll see which parts needed tweaking, and we could take it from there. Thanks & lemme know!