vocalpy / vak

A neural network framework for researchers studying acoustic communication
https://vak.readthedocs.io
BSD 3-Clause "New" or "Revised" License

BUG: `Cannot convert a MPS Tensor to float64 dtype` on Apple M1 Max #700

Closed: NickleDave closed this issue 1 year ago

NickleDave commented 1 year ago

Before submitting a bug, please make sure the issue hasn't been already addressed by searching through the past issues

Describe the bug This is a bug reported by @VenetianRed in the vocalpy forum here: https://forum.vocalpy.org/t/vak-tweetynet-with-an-apple-m1-max/78

I’m trying to work through the introductory vak example, to learn to classify notes in the songs of Java sparrows, and am running into trouble on a laptop with an Apple M1 Max chip. I’m able to run `vak prep ...`, but when I go to train my neural net, I get

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn’t support float64. Please use float32 instead.

(TweetyNet) mark@ultramarine:Practice > vak train gy6or6_train.toml
2023-09-15 13:11:50,211 - vak.cli.train - INFO - vak version: 1.0.0a1
2023-09-15 13:11:50,211 - vak.cli.train - INFO - Logging results to gy6or6/vak/train/results/results_230915_131150
2023-09-15 13:11:50,212 - vak.core.train - INFO - Loading dataset from .csv path: gy6or6/vak/prep/train/032212_prep_230915_131124.csv
2023-09-15 13:11:50,214 - vak.core.train - INFO - Size of timebin in spectrograms from dataset, in seconds: 0.002
2023-09-15 13:11:50,214 - vak.core.train - INFO - using training dataset from gy6or6/vak/prep/train/032212_prep_230915_131124.csv
2023-09-15 13:11:50,214 - vak.core.train - INFO - Total duration of training split from dataset (in s): 57.17199999999999
2023-09-15 13:11:50,362 - vak.core.train - INFO - number of classes in labelmap: 12
2023-09-15 13:11:50,362 - vak.core.train - INFO - no spect_scaler_path provided, not loading
2023-09-15 13:11:50,362 - vak.core.train - INFO - will normalize spectrograms
2023-09-15 13:11:50,405 - vak.core.train - INFO - Duration of WindowDataset used for training, in seconds: 57.172000000000004
2023-09-15 13:11:50,419 - vak.core.train - INFO - Total duration of validation split from dataset (in s): 21.266
2023-09-15 13:11:50,419 - vak.core.train - INFO - will measure error on validation set every 400 steps of training
2023-09-15 13:11:50,426 - vak.core.train - INFO - training TweetyNet
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2023-09-15 13:11:50,473 - vak.core.train - INFO - Training start time: 2023-09-15T13:11:50.473669
Missing logger folder: /Users/mark/Current_projects/Anthony Kwong/TweetyNet/Practice/gy6or6/vak/train/results/results_230915_131150/TweetyNet/lightning_logs

  | Name    | Type             | Params
---------------------------------------------
0 | network | TweetyNet        | 1.1 M
1 | loss    | CrossEntropyLoss | 0
---------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.444     Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s]/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Sanity Checking DataLoader 0:   0%|                                                                                                                  | 0/2 [00:00<?, ?it/s]/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:212: UserWarning: You called `self.log('val_levenshtein', ...)` in your `validation_step` but the value needs to be floating point. Converting it to torch.float32.
  warning_cache.warn(
Traceback (most recent call last):
  File "/opt/anaconda3/envs/TweetyNet/bin/vak", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/vak/__main__.py", line 48, in main
    cli.cli(command=args.command, config_file=args.configfile)
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/vak/cli/cli.py", line 49, in cli
    COMMAND_FUNCTION_MAP[command](toml_path=config_file)
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/vak/cli/cli.py", line 8, in train
    train(toml_path=toml_path)
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/vak/cli/train.py", line 67, in train
    core.train(
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/vak/core/train.py", line 369, in train
    trainer.fit(
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1021, in _run_stage
    self._run_sanity_check()
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1050, in _run_sanity_check
    val_loop.run()
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 376, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 294, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 393, in validation_step
    return self.model.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/vak/models/windowed_frame_classification_model.py", line 208, in validation_step
    self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 447, in log
    value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 51, in apply_to_collection
    return function(data, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 619, in __to_tensor
    else torch.tensor(value, device=self.device)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Environment file attached: condaList.txt

NickleDave commented 1 year ago

Copying my reply from the forum: https://forum.vocalpy.org/t/vak-tweetynet-with-an-apple-m1-max/78/11?u=nicholdav

I think the problem isn't the way we save the arrays for the dataset. I confirmed this shouldn't be an issue, but I'll spare you the details: we transform inputs to float32 when we load them.

(I first thought the error came from loading spectrograms as float64, but in fact we apply a transform that makes them float32.)

My best guess is that the error happens when Lightning takes the computed metric value returned by our `validation_step` and puts it in a tensor in order to log it. Because the returned value is float64, which the MPS backend doesn't support, we get this error.
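The dtype chain behind that guess can be sketched in a few lines (run on CPU, where float64 is still allowed; on MPS the second conversion is what raises the `TypeError`). This is an illustration of `torch.tensor` dtype inference, not vak's actual code:

```python
import numpy as np
import torch

# a plain Python float is converted using torch's default dtype, float32 ...
assert torch.tensor(0.5).dtype == torch.float32

# ... but a numpy scalar keeps its own dtype, and numpy defaults to float64.
# A metric that returns a numpy float thus produces a float64 tensor,
# which is exactly what the MPS backend rejects.
metric_value = np.float64(0.25)  # stand-in for a metric's return value
assert torch.tensor(metric_value).dtype == torch.float64
```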

From the traceback you provided, we can see where we call `self.log`:

 File "/opt/anaconda3/envs/TweetyNet/lib/python3.11/site-packages/vak/models/windowed_frame_classification_model.py", line 208, in validation_step
    self.log(f'val_{metric_name}', metric_callable(y_pred_labels, y_labels), batch_size=1)

Here's line 208 in version 1.0.0a1 that you're using: https://github.com/vocalpy/vak/blob/3dcce70030ae9b1fd6d040e055def0d656a7512e/src/vak/models/windowed_frame_classification_model.py#L208

You can see that we compute an edit distance metric using string labels, so the problem can't be the tensor inputs to the model; it has to be the returned value. I put a `breakpoint()` before that line and ran an eval file to confirm that, yes, `segment_edit_distance` returns a numpy float with dtype float64. I'm guessing that's what causes the crash.

I think I have a fix here: https://github.com/vocalpy/vak/tree/make-distance-metrics-return-tensors

I will raise a separate issue stating that we need all metrics to return tensors, apply the fix in progress on that branch, and then release a new version.
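A minimal sketch of what such a fix could look like (the function names and the toy Levenshtein implementation here are illustrative, not the actual code on that branch): have each distance metric return a float32 tensor instead of a numpy float, so `self.log` never has to infer a dtype from a numpy scalar.

```python
import torch


def levenshtein(a: str, b: str) -> int:
    # classic dynamic-programming edit distance between two label strings
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,              # deletion
                curr[j - 1] + 1,          # insertion
                prev[j - 1] + (ca != cb)  # substitution (0 if chars match)
            ))
        prev = curr
    return prev[-1]


def segment_edit_distance(y_pred_labels: str, y_labels: str) -> torch.Tensor:
    # hypothetical fixed metric: wrap the integer distance in a float32
    # tensor so Lightning's self.log() never sees a float64 value
    return torch.tensor(levenshtein(y_pred_labels, y_labels),
                        dtype=torch.float32)
```

With this change, logging the metric on an MPS device no longer triggers a float64 conversion, since the value is already a float32 tensor.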

NickleDave commented 1 year ago

@all-contributors please add @VenetianRed for bug

allcontributors[bot] commented 1 year ago

@NickleDave

I've put up a pull request to add @VenetianRed! :tada: