Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

RuntimeError: index_add(): Expected non int64 dtype for source. #17683

Closed · erl61 closed this issue 1 year ago

erl61 commented 1 year ago

Bug description

I am trying to run the "Demand forecasting with the Temporal Fusion Transformer" tutorial from PyTorch-Forecasting. It works perfectly with accelerator="cpu", but when I change it to accelerator="mps" it fails with "RuntimeError: index_add(): Expected non int64 dtype for source." The error happens once the lightning.pytorch.Trainer starts fitting (here via Tuner.lr_find). I have a MacBook M1 Pro and have tried different versions of Python, PyTorch, PyTorch-Forecasting and PyTorch-Lightning. torch.backends.mps.is_available() returns True.

What version are you seeing the problem on?

v2.0, master

How to reproduce the bug

%env PYTORCH_ENABLE_MPS_FALLBACK=1

import os
import warnings
warnings.filterwarnings("ignore")  # avoid printing out absolute paths
import copy
from pathlib import Path

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.data.examples import get_stallion_data

data = get_stallion_data()

# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()

# add additional features
data["month"] = data.date.dt.month.astype(str).astype("category")  # categories have to be strings
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")

# we want to encode special days as one variable and thus need to first reverse one-hot encoding
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]
data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")

max_prediction_length = 6
max_encoder_length = 24
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="volume",
    group_ids=["agency", "sku"],
    min_encoder_length=max_encoder_length // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    time_varying_known_categoricals=["special_days", "month"],
    variable_groups={"special_days": special_days},  # a group of categorical variables can be treated as one variable
    time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "volume",
        "log_volume",
        "industry_volume",
        "soda_volume",
        "avg_max_temp",
        "avg_volume_by_agency",
        "avg_volume_by_sku",
    ],
    target_normalizer=GroupNormalizer(
        groups=["agency", "sku"], transformation="softplus"
    ),  # use softplus and normalize by group
    add_relative_time_idx=False,
    add_target_scales=False,
    add_encoder_length=False,
)

# create validation set (predict=True) which means to predict the last max_prediction_length points in time
# for each series
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 16  # set this between 32 and 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# configure network and trainer
pl.seed_everything(42)
trainer = pl.Trainer(
    accelerator="mps",
    # clipping gradients is a hyperparameter and important to prevent divergence
    # of the gradient for recurrent neural networks
    gradient_clip_val=0.1,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.03,
    hidden_size=8,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    loss=QuantileLoss(),
    optimizer="Ranger"
    # reduce learning rate if no improvement in validation loss after x epochs
    # reduce_on_plateau_patience=1000,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find optimal learning rate
from lightning.pytorch.tuner import Tuner

res = Tuner(trainer).lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
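
For comparison, the CPU run mentioned in the description completes. A minimal sketch of that working baseline, assuming everything else above stays unchanged (trainer_cpu and res_cpu are just illustrative names):

# working baseline from the description: same model and dataloaders, CPU accelerator
trainer_cpu = pl.Trainer(
    accelerator="cpu",
    gradient_clip_val=0.1,
)
res_cpu = Tuner(trainer_cpu).lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    max_lr=10.0,
    min_lr=1e-6,
)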

Error messages and logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 125
    122 # find optimal learning rate
    123 from lightning.pytorch.tuner import Tuner
--> 125 res = Tuner(trainer).lr_find(
    126     tft,
    127     train_dataloaders=train_dataloader,
    128     val_dataloaders=val_dataloader,
    129     max_lr=10.0,
    130     min_lr=1e-6,
    131 )
    133 print(f"suggested learning rate: {res.suggestion()}")
    134 fig = res.plot(show=True, suggest=True)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/tuner/tuning.py:175](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/tuner/tuning.py:175), in Tuner.lr_find(self, model, train_dataloaders, val_dataloaders, dataloaders, datamodule, method, min_lr, max_lr, num_training, mode, early_stop_threshold, update_attr, attr_name)
    172 lr_finder_callback._early_exit = True
    173 self._trainer.callbacks = [lr_finder_callback] + self._trainer.callbacks
--> 175 self._trainer.fit(model, train_dataloaders, val_dataloaders, datamodule)
    177 self._trainer.callbacks = [cb for cb in self._trainer.callbacks if cb is not lr_finder_callback]
    179 return lr_finder_callback.optimal_lr

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/trainer.py:520](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/trainer.py:520), in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    518 model = _maybe_unwrap_optimized(model)
    519 self.strategy._lightning_module = model
--> 520 call._call_and_handle_interrupt(
    521     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    522 )

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:44](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:44), in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     43     else:
---> 44         return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/trainer.py:559](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/trainer.py:559), in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    549 self._data_connector.attach_data(
    550     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    551 )
    553 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    554     self.state.fn,
    555     ckpt_path,
    556     model_provided=True,
    557     model_connected=self.lightning_module is not None,
    558 )
--> 559 self._run(model, ckpt_path=ckpt_path)
    561 assert self.state.stopped
    562 self.training = False

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/trainer.py:915](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/trainer.py:915), in Trainer._run(self, model, ckpt_path)
    913 # hook
    914 if self.state.fn == TrainerFn.FITTING:
--> 915     call._call_callback_hooks(self, "on_fit_start")
    916     call._call_lightning_module_hook(self, "on_fit_start")
    918 _log_hyperparams(self)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:190](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:190), in _call_callback_hooks(trainer, hook_name, monitoring_callbacks, *args, **kwargs)
    188     if callable(fn):
    189         with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
--> 190             fn(trainer, trainer.lightning_module, *args, **kwargs)
    192 if pl_module:
    193     # restore current_fx when nested context
    194     pl_module._current_fx_name = prev_fx_name

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/callbacks/lr_finder.py:125](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/callbacks/lr_finder.py:125), in LearningRateFinder.on_fit_start(self, trainer, pl_module)
    124 def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
--> 125     self.lr_find(trainer, pl_module)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/callbacks/lr_finder.py:109](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/callbacks/lr_finder.py:109), in LearningRateFinder.lr_find(self, trainer, pl_module)
    107 def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    108     with isolate_rng():
--> 109         self.optimal_lr = _lr_find(
    110             trainer,
    111             pl_module,
    112             min_lr=self._min_lr,
    113             max_lr=self._max_lr,
    114             num_training=self._num_training_steps,
    115             mode=self._mode,
    116             early_stop_threshold=self._early_stop_threshold,
    117             update_attr=self._update_attr,
    118             attr_name=self._attr_name,
    119         )
    121     if self._early_exit:
    122         raise _TunerExitException()

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/tuner/lr_finder.py:269](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/tuner/lr_finder.py:269), in _lr_find(trainer, model, min_lr, max_lr, num_training, mode, early_stop_threshold, update_attr, attr_name)
    266 lr_finder._exchange_scheduler(trainer)
    268 # Fit, lr & loss logged in callback
--> 269 _try_loop_run(trainer, params)
    271 # Prompt if we stopped early
    272 if trainer.global_step != num_training + start_steps:

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/tuner/lr_finder.py:495](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/tuner/lr_finder.py:495), in _try_loop_run(trainer, params)
    493 loop.load_state_dict(deepcopy(params["loop_state_dict"]))
    494 loop.restarting = False
--> 495 loop.run()

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/fit_loop.py:201](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/fit_loop.py:201), in _FitLoop.run(self)
    199 try:
    200     self.on_advance_start()
--> 201     self.advance()
    202     self.on_advance_end()
    203     self._restarting = False

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/fit_loop.py:354](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/fit_loop.py:354), in _FitLoop.advance(self)
    352 self._data_fetcher.setup(combined_loader)
    353 with self.trainer.profiler.profile("run_training_epoch"):
--> 354     self.epoch_loop.run(self._data_fetcher)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/training_epoch_loop.py:133](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/training_epoch_loop.py:133), in _TrainingEpochLoop.run(self, data_fetcher)
    131 while not self.done:
    132     try:
--> 133         self.advance(data_fetcher)
    134         self.on_advance_end()
    135         self._restarting = False

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/training_epoch_loop.py:218](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/training_epoch_loop.py:218), in _TrainingEpochLoop.advance(self, data_fetcher)
    215 with trainer.profiler.profile("run_training_batch"):
    216     if trainer.lightning_module.automatic_optimization:
    217         # in automatic optimization, there can only be one optimizer
--> 218         batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
    219     else:
    220         batch_output = self.manual_optimization.run(kwargs)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:185](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:185), in _AutomaticOptimization.run(self, optimizer, kwargs)
    178         closure()
    180 # ------------------------------
    181 # BACKWARD PASS
    182 # ------------------------------
    183 # gradient update with accumulated gradients
    184 else:
--> 185     self._optimizer_step(kwargs.get("batch_idx", 0), closure)
    187 result = closure.consume_result()
    188 if result.loss is None:

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:261](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:261), in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    258     self.optim_progress.optimizer.step.increment_ready()
    260 # model hook
--> 261 call._call_lightning_module_hook(
    262     trainer,
    263     "optimizer_step",
    264     trainer.current_epoch,
    265     batch_idx,
    266     optimizer,
    267     train_step_and_backward_closure,
    268 )
    270 if not should_accumulate:
    271     self.optim_progress.optimizer.step.increment_completed()

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:142](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:142), in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    139 pl_module._current_fx_name = hook_name
    141 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 142     output = fn(*args, **kwargs)
    144 # restore current_fx when nested context
    145 pl_module._current_fx_name = prev_fx_name

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/core/module.py:1265](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/core/module.py:1265), in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1226 def optimizer_step(
   1227     self,
   1228     epoch: int,
   (...)
   1231     optimizer_closure: Optional[Callable[[], Any]] = None,
   1232 ) -> None:
   1233     r"""
   1234     Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1235     the optimizer.
   (...)
   1263                     pg["lr"] = lr_scale * self.learning_rate
   1264     """
-> 1265     optimizer.step(closure=optimizer_closure)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/core/optimizer.py:158](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/core/optimizer.py:158), in LightningOptimizer.step(self, closure, **kwargs)
    155     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    157 assert self._strategy is not None
--> 158 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    160 self._on_after_step()
    162 return step_output

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/strategies/strategy.py:224](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/strategies/strategy.py:224), in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    222 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    223 assert isinstance(model, pl.LightningModule)
--> 224 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:114](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:114), in PrecisionPlugin.optimizer_step(self, optimizer, model, closure, **kwargs)
    112 """Hook to run the optimizer step."""
    113 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 114 return optimizer.step(closure=closure, **kwargs)

File [~/Library/Python/3.9/lib/python/site-packages/torch/optim/lr_scheduler.py:69](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/torch/optim/lr_scheduler.py:69), in LRScheduler.__init__..with_counter..wrapper(*args, **kwargs)
     67 instance._step_count += 1
     68 wrapped = func.__get__(instance, cls)
---> 69 return wrapped(*args, **kwargs)

File [~/Library/Python/3.9/lib/python/site-packages/torch/optim/optimizer.py:280](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/torch/optim/optimizer.py:280), in Optimizer.profile_hook_step..wrapper(*args, **kwargs)
    276         else:
    277             raise RuntimeError(f"{func} must return None or a tuple of (new_args, new_kwargs),"
    278                                f"but got {result}.")
--> 280 out = func(*args, **kwargs)
    281 self._optimizer_step_code()
    283 # call optimizer step post hooks

File [~/Library/Python/3.9/lib/python/site-packages/torch/utils/_contextlib.py:115](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/torch/utils/_contextlib.py:115), in context_decorator..decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File [~/Library/Python/3.9/lib/python/site-packages/pytorch_optimizer/optimizer/ranger.py:99](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/pytorch_optimizer/optimizer/ranger.py:99), in Ranger.step(self, closure)
     97 if closure is not None:
     98     with torch.enable_grad():
---> 99         loss = closure()
    101 for group in self.param_groups:
    102     if 'step' in group:

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:101](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:101), in PrecisionPlugin._wrap_closure(self, model, optimizer, closure)
     89 def _wrap_closure(
     90     self,
     91     model: "pl.LightningModule",
     92     optimizer: Optimizer,
     93     closure: Callable[[], Any],
     94 ) -> Any:
     95     """This double-closure allows makes sure the ``closure`` is executed before the
     96     ``on_before_optimizer_step`` hook is called.
     97 
     98     The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
     99     consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    100     """
--> 101     closure_result = closure()
    102     self._after_closure(model, optimizer)
    103     return closure_result

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:140](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:140), in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:135](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:135), in Closure.closure(self, *args, **kwargs)
    132     self._zero_grad_fn()
    134 if self._backward_fn is not None and step_output.closure_loss is not None:
--> 135     self._backward_fn(step_output.closure_loss)
    137 return step_output

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:233](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/loops/optimization/automatic.py:233), in _AutomaticOptimization._make_backward_fn..backward_fn(loss)
    232 def backward_fn(loss: Tensor) -> None:
--> 233     call._call_strategy_hook(self.trainer, "backward", loss, optimizer)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:288](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/trainer/call.py:288), in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    285     return
    287 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 288     output = fn(*args, **kwargs)
    290 # restore current_fx when nested context
    291 pl_module._current_fx_name = prev_fx_name

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/strategies/strategy.py:199](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/strategies/strategy.py:199), in Strategy.backward(self, closure_loss, optimizer, *args, **kwargs)
    196 assert self.lightning_module is not None
    197 closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module)
--> 199 self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
    201 closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module)
    202 self.post_backward(closure_loss)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:67](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py:67), in PrecisionPlugin.backward(self, tensor, model, optimizer, *args, **kwargs)
     49 def backward(  # type: ignore[override]
     50     self,
     51     tensor: Tensor,
   (...)
     55     **kwargs: Any,
     56 ) -> None:
     57     r"""Performs the actual backpropagation.
     58 
     59     Args:
   (...)
     65         \**kwargs: Keyword arguments for the same purpose as ``*args``.
     66     """
---> 67     model.backward(tensor, *args, **kwargs)

File [~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/core/module.py:1054](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/lightning/pytorch/core/module.py:1054), in LightningModule.backward(self, loss, *args, **kwargs)
   1052     self._fabric.backward(loss, *args, **kwargs)
   1053 else:
-> 1054     loss.backward(*args, **kwargs)

File [~/Library/Python/3.9/lib/python/site-packages/torch/_tensor.py:487](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/torch/_tensor.py:487), in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File [~/Library/Python/3.9/lib/python/site-packages/torch/autograd/__init__.py:200](https://file+.vscode-resource.vscode-cdn.net/Users/alekseivv/~/Library/Python/3.9/lib/python/site-packages/torch/autograd/__init__.py:200), in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

RuntimeError: index_add(): Expected non int64 dtype for source.

Environment

Current environment

* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning: 2.0.2
  - lightning-cloud: 0.5.36
  - lightning-utilities: 0.8.0
  - pytorch-forecasting: 1.0.0
  - pytorch-lightning: 2.0.2
  - pytorch-optimizer: 2.9.1
  - torch: 2.0.1
  - torchmetrics: 0.11.4
* Packages: aiohttp: 3.8.4 - aiosignal: 1.3.1 - alembic: 1.11.1 - altgraph: 0.17.2 - anyio: 3.6.2 - appnope: 0.1.3 - arrow: 1.2.3 - asttokens: 2.2.1 - async-timeout: 4.0.2 - attrs: 23.1.0 - backcall: 0.2.0 - beautifulsoup4: 4.12.2 - blessed: 1.20.0 - certifi: 2023.5.7 - charset-normalizer: 3.1.0 - click: 8.1.3 - cmaes: 0.9.1 - colorlog: 6.7.0 - comm: 0.1.3 - contourpy: 1.0.7 - cramjam: 2.6.2 - croniter: 1.3.14 - cycler: 0.11.0 - datasets: 2.12.0 - dateutils: 0.6.12 - debugpy: 1.6.7 - decorator: 5.1.1 - deepdiff: 6.3.0 - dill: 0.3.6 - executing: 1.2.0 - fastapi: 0.88.0 - fastparquet: 2023.4.0 - filelock: 3.12.0 - fonttools: 4.39.4 - frozenlist: 1.3.3 - fsspec: 2023.5.0 - future: 0.18.2 - h11: 0.14.0 - huggingface-hub: 0.14.1 - idna: 3.4 - importlib-metadata: 6.6.0 - importlib-resources: 5.12.0 - inquirer: 3.1.3 - ipykernel: 6.23.1 - ipython: 8.13.2 - itsdangerous: 2.1.2 - jedi: 0.18.2 - jinja2: 3.1.2 - joblib: 1.2.0 - jupyter-client: 8.2.0 - jupyter-core: 5.3.0 - kiwisolver: 1.4.4 - lightning: 2.0.2 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 - macholib: 1.15.2 - mako: 1.2.4 - markdown-it-py: 2.2.0 - markupsafe: 2.1.2 - matplotlib: 3.7.1 - matplotlib-inline: 0.1.6 - mdurl: 0.1.2 - mpmath: 1.3.0 - multidict: 6.0.4 - multiprocess: 0.70.14 - nest-asyncio: 1.5.6 - networkx: 3.1 - numpy: 1.24.3 - optuna: 3.1.1 - ordered-set: 4.1.0 - packaging: 23.1 - pandas: 2.0.1 - parso: 0.8.3 - patsy: 0.5.3 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.5.0 - pip: 23.1.2 - platformdirs: 3.5.1 - prompt-toolkit: 3.0.38 - psutil: 5.9.5 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyarrow: 12.0.0 - pydantic: 1.10.7 - pygments: 2.15.1 - pyjwt: 2.7.0 - pyparsing: 3.0.9 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-multipart: 0.0.6 - pytorch-forecasting: 1.0.0 - pytorch-lightning: 2.0.2 - pytorch-optimizer: 2.9.1 - pytz: 2023.3 - pyyaml: 6.0 - pyzmq: 25.0.2 - readchar: 4.0.5 - regex: 2023.5.5 - requests: 2.31.0 - responses: 0.18.0 - rich: 13.3.5 - scikit-learn: 1.2.2 - scipy: 1.10.1 - setuptools: 58.0.4 - six: 1.15.0 - sniffio: 1.3.0 - soupsieve: 2.4.1 - sqlalchemy: 2.0.15 - stack-data: 0.6.2 - starlette: 0.22.0 - starsessions: 1.3.0 - statsmodels: 0.14.0 - sympy: 1.12 - threadpoolctl: 3.1.0 - tokenizers: 0.13.3 - torch: 2.0.1 - torchmetrics: 0.11.4 - tornado: 6.3.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - transformers: 4.29.2 - typing-extensions: 4.5.0 - tzdata: 2023.3 - urllib3: 1.26.6 - uvicorn: 0.22.0 - wcwidth: 0.2.6 - websocket-client: 1.5.2 - websockets: 11.0.3 - wheel: 0.37.0 - xxhash: 3.2.0 - yarl: 1.9.2 - zipp: 3.15.0
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: arm
  - python: 3.9.6
  - release: 22.4.0
  - version: Darwin Kernel Version 22.4.0: Mon Mar 6 20:59:28 PST 2023; root:xnu-8796.101.5~3/RELEASE_ARM64_T6000

More info

No response

cc @justusschock

awaelchli commented 1 year ago

@erl61 Thanks for reporting. The MPS backend in PyTorch doesn't support an int64 source tensor as input to index_add(). This happens during the backward pass, so a layer in your model probably uses int64 as the data type somewhere. I suggest reporting this to PyTorch directly; I don't see how Lightning could do anything about it. The MPS backend does not support all operations and data types (yet).
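
For anyone who wants to check the underlying operation outside of Lightning and pytorch-forecasting, here is a minimal sketch, not a confirmed reduction of the TFT model; it assumes torch 2.0.x with an MPS build on Apple Silicon and simply calls index_add_ with a float32 and then an int64 source:

import torch

# assumes a PyTorch build with MPS support on Apple Silicon
device = torch.device("mps")
index = torch.tensor([0, 1], device=device)

# float32 source: expected to work on MPS
out_f = torch.zeros(4, device=device)
out_f.index_add_(0, index, torch.ones(2, device=device))

# int64 source: expected to fail with
# "RuntimeError: index_add(): Expected non int64 dtype for source."
out_i = torch.zeros(4, dtype=torch.int64, device=device)
out_i.index_add_(0, index, torch.ones(2, dtype=torch.int64, device=device))

If the second call reproduces the same message, that snippet (rather than the full forecasting example) is what to attach to an upstream PyTorch issue.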