csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0
748 stars 67 forks

auraloss.ipynb: KeyError: ["input"] not found in torchaudio.load() line #15

Closed: drscotthawley closed this issue 3 years ago

drscotthawley commented 3 years ago

Christian, this is all super cool. I look forward to being able to run this on my new GPU.

When I run the notebook, it dies at the line,

input, sr  = torchaudio.load(self.examples[idx]["input"], ...)

with a KeyError saying the key "input" was not found. Looking further up in the code, where self.examples.append() is called, it looks like you use the key "input_file" instead of "input". Could this be the source of the problem?
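A minimal plain-Python sketch of the mismatch described above, assuming (as the traceback suggests) that each entry in `self.examples` is a dict written by the indexing code and read back in `__getitem__`. The file names come from the log below; `make_example`, `INPUT_KEY`, and `TARGET_KEY` are hypothetical names for illustration, not code from the notebook.

```python
# Hypothetical sketch: each dataset example is a dict built at indexing time
# and read back in __getitem__. Sharing the key names as constants keeps the
# writer (the append call) and the reader (__getitem__) in sync.
INPUT_KEY = "input_file"
TARGET_KEY = "target_file"


def make_example(input_file, target_file, offset):
    """Build one example dict the way the indexing code is assumed to."""
    return {INPUT_KEY: input_file, TARGET_KEY: target_file, "offset": offset}


example = make_example("input_138_.wav", "target_138_LA2A_3c__0__0.wav", 0)

# Reading with a different key reproduces the error from the traceback.
try:
    example["input"]
except KeyError as err:
    print("KeyError:", err)  # prints: KeyError: 'input'

# Reading with the same keys that were written works.
input_path = example[INPUT_KEY]
target_path = example[TARGET_KEY]
print(input_path, target_path)
```

Either spelling of the keys works as long as the append call and `__getitem__` agree; renaming both sides to "input"/"target" is equivalent.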

Full log follows:

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Using native 16bit precision.

  | Name    | Type                     | Params
-----------------------------------------------------
0 | l1      | L1Loss                   | 0     
1 | esr     | ESRLoss                  | 0     
2 | dc      | DCLoss                   | 0     
3 | logcosh | LogCoshLoss              | 0     
4 | stft    | STFTLoss                 | 0     
5 | mrstft  | MultiResolutionSTFTLoss  | 0     
6 | rrstft  | RandomResolutionSTFTLoss | 0     
7 | gen     | Sequential               | 2.7 K 
8 | blocks  | ModuleList               | 354 K 
9 | output  | Conv1d                   | 65    
-----------------------------------------------------
357 K     Trainable params
0         Non-trainable params
357 K     Total params
target_138_LA2A_3c__0__0.wav input_138_.wav
target_139_LA2A_3c__0__5.wav input_139_.wav
target_141_LA2A_3c__0__15.wav input_141_.wav
target_142_LA2A_3c__0__20.wav input_142_.wav
target_143_LA2A_3c__0__25.wav input_143_.wav
target_144_LA2A_3c__0__30.wav input_144_.wav
target_145_LA2A_3c__0__35.wav input_145_.wav
target_146_LA2A_3c__0__40.wav input_146_.wav
target_147_LA2A_3c__0__45.wav input_147_.wav
target_149_LA2A_3c__0__55.wav input_149_.wav
target_150_LA2A_3c__0__60.wav input_150_.wav
target_151_LA2A_3c__0__65.wav input_151_.wav
target_152_LA2A_3c__0__70.wav input_152_.wav
target_153_LA2A_3c__0__75.wav input_153_.wav
target_154_LA2A_3c__0__80.wav input_154_.wav
target_155_LA2A_3c__0__85.wav input_155_.wav
target_156_LA2A_3c__0__90.wav input_156_.wav
target_157_LA2A_3c__0__95.wav input_157_.wav
target_158_LA2A_3c__0__100.wav input_158_.wav
target_159_LA2A_3c__1__0.wav input_159_.wav
target_160_LA2A_3c__1__5.wav input_160_.wav
target_162_LA2A_3c__1__15.wav input_162_.wav
target_163_LA2A_3c__1__20.wav input_163_.wav
target_164_LA2A_3c__1__25.wav input_164_.wav
target_165_LA2A_3c__1__30.wav input_165_.wav
target_166_LA2A_3c__1__35.wav input_166_.wav
target_167_LA2A_3c__1__40.wav input_167_.wav
target_168_LA2A_3c__1__45.wav input_168_.wav
target_169_LA2A_3c__1__50.wav input_169_.wav
target_170_LA2A_3c__1__55.wav input_170_.wav
target_171_LA2A_3c__1__60.wav input_171_.wav
target_172_LA2A_3c__1__65.wav input_172_.wav
target_174_LA2A_3c__1__75.wav input_174_.wav
target_175_LA2A_3c__1__80.wav input_175_.wav
target_176_LA2A_3c__1__85.wav input_176_.wav
target_177_LA2A_3c__1__90.wav input_177_.wav
target_178_LA2A_3c__1__95.wav input_178_.wav
target_179_LA2A_3c__1__100.wav input_179_.wav
target_221_LA2A_3c__1__100.wav input_221_.wav
target_222_LA2A_2c__0__0.wav input_222_.wav
target_225_LA2A_2c__0__15.wav input_225_.wav
target_226_LA2A_2c__0__20.wav input_226_.wav
target_228_LA2A_2c__0__30.wav input_228_.wav
target_229_LA2A_2c__0__35.wav input_229_.wav
target_230_LA2A_2c__0__40.wav input_230_.wav
target_233_LA2A_2c__0__55.wav input_233_.wav
target_234_LA2A_2c__0__60.wav input_234_.wav
target_237_LA2A_2c__0__75.wav input_237_.wav
target_238_LA2A_2c__0__80.wav input_238_.wav
target_240_LA2A_2c__0__90.wav input_240_.wav
target_241_LA2A_2c__0__95.wav input_241_.wav
target_242_LA2A_2c__0__100.wav input_242_.wav
target_243_LA2A_2c__1__0.wav input_243_.wav
target_244_LA2A_2c__1__5.wav input_244_.wav
target_246_LA2A_2c__1__15.wav input_246_.wav
target_247_LA2A_2c__1__20.wav input_247_.wav
target_249_LA2A_2c__1__30.wav input_249_.wav
target_250_LA2A_2c__1__35.wav input_250_.wav
target_251_LA2A_2c__1__40.wav input_251_.wav
target_253_LA2A_2c__1__50.wav input_253_.wav
target_254_LA2A_2c__1__55.wav input_254_.wav
target_255_LA2A_2c__1__60.wav input_255_.wav
target_258_LA2A_2c__1__75.wav input_258_.wav
target_261_LA2A_2c__1__90.wav input_261_.wav
target_262_LA2A_2c__1__95.wav input_262_.wav
target_263_LA2A_2c__1__100.wav input_263_.wav
Located 188675 examples totaling 19.5 hr in the train subset.
target_140_LA2A_3c__0__10.wav input_140_.wav
target_148_LA2A_3c__0__50.wav input_148_.wav
target_161_LA2A_3c__1__10.wav input_161_.wav
target_173_LA2A_3c__1__70.wav input_173_.wav
target_223_LA2A_2c__0__5.wav input_223_.wav
target_224_LA2A_2c__0__10.wav input_224_.wav
target_227_LA2A_2c__0__25.wav input_227_.wav
target_231_LA2A_2c__0__45.wav input_231_.wav
target_232_LA2A_2c__0__50.wav input_232_.wav
target_239_LA2A_2c__0__85.wav input_239_.wav
target_245_LA2A_2c__1__10.wav input_245_.wav
target_248_LA2A_2c__1__25.wav input_248_.wav
target_252_LA2A_2c__1__45.wav input_252_.wav
target_257_LA2A_2c__1__70.wav input_257_.wav
target_260_LA2A_2c__1__85.wav input_260_.wav
Located 2450 examples totaling 4.1 hr in the val subset.
Validation sanity check: 0%
0/2 [00:00<?, ?it/s]
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-15-3454d6e526e4> in <module>
     63 
     64 # train!
---> 65 trainer.fit(model, train_dataloader, val_dataloader)

~/env_pymix/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    508         self.call_hook('on_fit_start')
    509 
--> 510         results = self.accelerator_backend.train()
    511         self.accelerator_backend.teardown()
    512 

~/env_pymix/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in train(self)
     55     def train(self):
     56         self.trainer.setup_trainer(self.trainer.model)
---> 57         return self.train_or_test()
     58 
     59     def teardown(self):

~/env_pymix/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in train_or_test(self)
     72         else:
     73             self.trainer.train_loop.setup_training()
---> 74             results = self.trainer.train()
     75         return results
     76 

~/env_pymix/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in train(self)
    530 
    531     def train(self):
--> 532         self.run_sanity_check(self.get_model())
    533 
    534         # set stage for logging

~/env_pymix/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
    729 
    730             # run eval step
--> 731             _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)
    732 
    733             # allow no returns from eval

~/env_pymix/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, max_batches, on_epoch)
    628             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
    629 
--> 630             for batch_idx, batch in enumerate(dataloader):
    631                 if batch is None:
    632                     continue

~/env_pymix/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/env_pymix/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    473     def _next_data(self):
    474         index = self._next_index()  # may raise StopIteration
--> 475         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    476         if self._pin_memory:
    477             data = _utils.pin_memory.pin_memory(data)

~/env_pymix/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/env_pymix/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-13-21c59ef52170> in __getitem__(self, idx)
     76         else:
     77           offset = self.examples[idx]["offset"]
---> 78           input, sr  = torchaudio.load(self.examples[idx]["input"], 
     79                                       num_frames=self.length,
     80                                        frame_offset=offset,

KeyError: 'input'
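The log lines before the traceback print each target file next to its input file, which suggests the indexing code pairs them by the numeric ID in the file name. Below is a hedged sketch of that pairing, assuming the pattern `target_<id>_LA2A_<tag>__<switch>__<peak_reduction>.wav` goes with `input_<id>_.wav`. Reading the last two fields as the LA-2A compress/limit switch and peak-reduction setting is an assumption, and `pair_files` is a hypothetical helper. It writes the same "input_file"/"target_file" keys as the append call, so a `__getitem__` that reads those keys would not hit this KeyError.

```python
import re

# Assumed naming pattern, inferred from the log output (not from the notebook):
#   target_<id>_LA2A_<tag>__<switch>__<peak_reduction>.wav  ->  input_<id>_.wav
TARGET_RE = re.compile(r"^target_(\d+)_LA2A_([0-9a-z]+)__(\d+)__(\d+)\.wav$")


def pair_files(target_names):
    """Return one example dict per target file name that matches TARGET_RE."""
    examples = []
    for name in target_names:
        match = TARGET_RE.match(name)
        if match is None:
            continue  # skip names that don't follow the assumed pattern
        file_id, tag, switch, peak_reduction = match.groups()
        examples.append({
            "input_file": f"input_{file_id}_.wav",
            "target_file": name,
            "tag": tag,
            "switch": int(switch),                  # assumed: 0/1 switch
            "peak_reduction": int(peak_reduction),  # assumed: 0-100 knob
        })
    return examples


val = pair_files([
    "target_140_LA2A_3c__0__10.wav",
    "target_223_LA2A_2c__0__5.wav",
])
for ex in val:
    print(ex["target_file"], ex["input_file"])
```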
drscotthawley commented 3 years ago

If I take out the "_file" in the key names where self.examples.append() is called, I no longer get the error, and the training run instead proceeds to

Validation sanity check: 0%
0/2 [00:00<?, ?it/s]

...and then it runs out of CUDA memory. I'm only running on an RTX 3080 with 10 GB of VRAM. Maybe I can decrease the batch size, or else switch to a different GPU.

Anyway, I suspect that changing the keys from "input_file" to "input" and "target_file" to "target" was the right move.

csteinmetz1 commented 3 years ago

Hey Scott, thanks again for checking this out!

Unfortunately, this notebook is quite out of date relative to the main codebase, which is likely why you are running into these errors. Sorry to send you on a wild goose chase. I would recommend running the examples/compressor/train.sh script, which has all the hyperparameters we used in our experiments. We used a batch size of 128 there, which requires ~14 GB of VRAM, so you will likely need to lower it in the script.
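A back-of-the-envelope way to pick a smaller batch size from the numbers above (batch size 128 at ~14 GB), assuming memory use scales roughly linearly with batch size. That is only an approximation: model weights, optimizer state, and the CUDA context are a fixed cost, so treat the result as a starting point rather than a guarantee. `suggest_batch_size` and its `headroom` parameter are hypothetical.

```python
def suggest_batch_size(ref_batch, ref_gb, available_gb, headroom=0.9):
    """Largest power-of-two batch size expected to fit, assuming memory
    grows linearly with batch size (a rough approximation). Returns at
    least 1, even if that may still not fit."""
    per_example_gb = ref_gb / ref_batch
    max_batch = int(available_gb * headroom / per_example_gb)
    batch = 1
    while batch * 2 <= max_batch:
        batch *= 2
    return batch


# ~14 GB at batch size 128, on a 10 GB card:
print(suggest_batch_size(128, 14.0, 10.0))  # prints: 64
```

If the smaller batch changes training behavior, PyTorch Lightning's `Trainer(accumulate_grad_batches=...)` option can accumulate gradients over several batches (for example 2 batches of 64) to approximate the original effective batch of 128; whether that reproduces the reported results has not been checked here.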

If you run into any more issues, please let me know. Also, we are about to release another repo focused just on modeling the LA-2A, with updated models and training code, so keep a lookout for that in the next few days.