Open vldbnc opened 1 year ago
I'm having a similar problem with TemporalFusionTransformer and QuantileLoss. I can see that QuantileLoss._device is set to cuda:0 even though the TFT model was mapped to CPU using TemporalFusionTransformer.load_from_checkpoint(best_model_path, map_location='cpu'), as in OP's post.
Does this work?
model = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
predictions = model.predict(test).cpu()
No. I'm not sure why, but in the meantime not even TemporalFusionTransformer.load_from_checkpoint(best_model_path) works. It might be because I upgraded to PyTorch Lightning 2.0.2. It now throws an error when calling _load_from_checkpoint:
AssertionError: Torch not compiled with CUDA enabled
I am able to hack around this by adding the following between lines 89 and 90:
storage.loss._device = "cpu"
for metric in storage.logging_metrics:
    metric._device = "cpu"
So manually patching the _device fields containing "cuda:0" to "cpu" in the model loaded from the checkpoint file helps to resolve this problem. But I think loading from a checkpoint should work regardless of whether the checkpoint was created on a GPU or not.
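For the case where loading succeeds but the loss still reports cuda:0, the same patch could in principle be applied to the loaded model instead of inside the library. The following is only a sketch: reset_metric_devices is a hypothetical helper (not part of pytorch-forecasting or Lightning), and it assumes the model exposes loss and logging_metrics as in the snippet above and that each metric keeps its device in a private _device field. It cannot help when load_from_checkpoint itself raises, since the model is never returned in that case.

```python
def reset_metric_devices(model, device="cpu"):
    """Overwrite stale CUDA `_device` fields on a model's loss and logging metrics.

    Hypothetical helper: assumes `model.loss` and `model.logging_metrics` exist
    and that each metric stores its device in a private `_device` attribute.
    Returns the class names of the metrics that were patched.
    """
    patched = []
    candidates = [model.loss, *model.logging_metrics]
    for metric in candidates:
        current = getattr(metric, "_device", None)
        # Only touch fields that still point at a CUDA device, e.g. "cuda:0".
        if current is not None and str(current).startswith("cuda"):
            metric._device = device
            patched.append(type(metric).__name__)
    return patched
```

It would be called right after loading, for example reset_metric_devices(model) on the result of TemporalFusionTransformer.load_from_checkpoint(best_model_path, map_location='cpu'). Because it relies on private attributes, it may stop working with other torchmetrics versions.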
It might be worth mentioning that these two warnings appear in the log when trying to load a model from a checkpoint:
UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
UserWarning: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
Would ignoring these two attributes resolve the error we are seeing or is this unrelated?
Are there any updates on this?
> I am able to hack around this by adding the following between lines 89 and 90:
This worked for me too. Thanks @jurgispods
lightning==2.0.1.post0
lightning-cloud==0.5.33
lightning-utilities==0.8.0
numpy==1.24.2
pandas==1.5.3
pyarrow==11.0.0
pytorch-forecasting==1.0.0
pytorch-lightning==2.0.1.post0
pytorch-optimizer==2.5.2
torch==2.0.0
torchmetrics==0.11.4
Expected behavior
changed correctly to 'cpu' but
device(type='cpu')
calculate mean absolute error on validation set