jbohnslav / deepethogram


Sequence training error: "RuntimeError: Expected floating point type for target with class probabilities, got Long" #91

Open mmh513 opened 2 years ago

mmh513 commented 2 years ago

Hello,

I ran inference with the feature extractor, then went to train the sequence model with no pre-trained weights, and received the following error:

Epoch 0:   0%|          | 0/1000 [00:25<?, ?it/s]
Traceback (most recent call last):
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\users\mhurl\deepethogram\deepethogram\sequence\train.py", line 265, in <module>
    sequence_train(cfg)
  File "c:\users\mhurl\deepethogram\deepethogram\sequence\train.py", line 75, in sequence_train
    trainer.fit(lightning_module)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\trainer.py", line 510, in fit
    results = self.accelerator_backend.train()
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\accelerators\accelerator.py", line 57, in train
    return self.train_or_test()
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\accelerators\accelerator.py", line 74, in train_or_test
    results = self.trainer.train()
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\trainer.py", line 561, in train
    self.train_loop.run_training_epoch()
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\training_loop.py", line 550, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\training_loop.py", line 718, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\training_loop.py", line 493, in optimizer_step
    using_lbfgs=is_lbfgs,
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\core\lightning.py", line 1298, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\core\optimizer.py", line 286, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\core\optimizer.py", line 144, in __optimizer_step
    optimizer.step(closure=closure, *args, **kwargs)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\torch\optim\optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\torch\autograd\grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\torch\optim\adam.py", line 92, in step
    loss = closure()
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\training_loop.py", line 713, in train_step_and_backward_closure
    self.trainer.hiddens
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\training_loop.py", line 806, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\trainer\training_loop.py", line 319, in training_step
    training_step_output = self.trainer.accelerator_backend.training_step(args)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\accelerators\gpu_accelerator.py", line 70, in training_step
    return self._step(self.trainer.model.training_step, args)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\pytorch_lightning-1.1.8-py3.7.egg\pytorch_lightning\accelerators\gpu_accelerator.py", line 65, in _step
    output = model_step(*args)
  File "c:\users\mhurl\deepethogram\deepethogram\sequence\train.py", line 138, in training_step
    return self.common_step(batch, batch_idx, 'train')
  File "c:\users\mhurl\deepethogram\deepethogram\sequence\train.py", line 118, in common_step
    loss, loss_dict = self.criterion(outputs, batch['labels'], self.model)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:\users\mhurl\deepethogram\deepethogram\feature_extractor\losses.py", line 207, in forward
    data_loss = self.data_criterion(outputs, label)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\torch\nn\modules\loss.py", line 1152, in forward
    label_smoothing=self.label_smoothing)
  File "C:\Users\mhurl\anaconda3\envs\deg\lib\site-packages\torch\nn\functional.py", line 2846, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Expected floating point type for target with class probabilities, got Long

[2022-01-31 12:00:04,278] INFO [deepethogram.gui.main.sequence_train:496] Training finished. If you see error messages above, training did not complete successfully.
[2022-01-31 12:00:04,278] INFO [deepethogram.gui.main.sequence_train:501] ~~~~~~

Thank you for the help!

jbohnslav commented 2 years ago

Are you using final_activation=='softmax'?
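
If so, that would explain it: with softmax, the labels go through CrossEntropyLoss (that's the torch\nn\modules\loss.py frame in your traceback), and PyTorch >= 1.10 interprets a Long target with the same shape as the logits as class probabilities, which have to be float. You can check what your project is using with something like this (the config path and key layout are from memory, so treat it as a sketch):

from omegaconf import OmegaConf

# Load your project's config file (path is hypothetical -- use your own):
cfg = OmegaConf.load('path/to/project_config.yaml')
print(cfg.feature_extractor.final_activation)  # 'sigmoid' or 'softmax'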

jbohnslav commented 2 years ago

Can you try to pip install --upgrade deepethogram and see if it's fixed?
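
After upgrading, a quick way to confirm which version you ended up with (works on Python 3.7):

import pkg_resources

# Prints the installed version of the deepethogram package:
print(pkg_resources.get_distribution('deepethogram').version)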