Closed by YichengDWu 10 months ago
Hi @YichengDWu
The `dataloader_iter` feature is undocumented and experimental :) We expose the iterator this way so that the user has full control over how the batch is fetched. The responsibility for moving the batch to the right device is on the user. You can achieve this by calling `batch.to(self.device)` (if the batch is a tensor).
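For concreteness, a minimal sketch of that pattern (the model, data shapes, and loss here are made up for illustration; in 2.1 `next(dataloader_iter)` returns a `(batch, batch_idx, dataloader_idx)` triplet, whereas earlier 2.x versions returned just the batch):

```python
import torch
import lightning.pytorch as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, dataloader_iter):
        # Fetch the batch ourselves; with this interface Lightning does
        # not move it to the device automatically.
        batch, batch_idx, dataloader_idx = next(dataloader_iter)
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)  # manual device move
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```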
Does that sound good @YichengDWu?
Thank you for your explanation.
I found it in the upgrade guide in the documentation. If I understand correctly, the new API seems to be the officially recommended drop-in replacement. However, in practice, as you explained, the old API automatically moves the batch to the device, while the new one does not, creating an inconsistency. If this is not a bug, should it be documented?
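For contrast, a sketch of the automatic-fetching hook that the upgrade guide migrates away from; here Lightning itself moves `batch` to `self.device` before calling the hook (same hypothetical `LitModel` as above):

```python
def training_step(self, batch, batch_idx):
    # `batch` has already been moved to self.device by Lightning.
    x, y = batch
    return torch.nn.functional.mse_loss(self.layer(x), y)
```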
So far, the only user we knew of using this feature was NeMo, and we've made changes in discussion with them. And afaik, in their use case it is undesirable for Lightning to decide how to move the batch. Therefore, we leave this up to the user; in the case of NeMo with Megatron, Megatron fetches the micro-batch from the `dataloader_iter` and moves it to the right device according to the pipeline parallelism.
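To illustrate why that control matters, here is a hypothetical accumulation-style step that pulls several micro-batches per call, in the spirit of what a Megatron-like strategy does (`num_microbatches` and `compute_loss` are made-up names, not Megatron's API; assumes the imports and module from the sketch above):

```python
def training_step(self, dataloader_iter):
    # Inside a LightningModule: the caller decides how many batches to
    # fetch and where each one should live (e.g. only certain pipeline
    # stages actually need the data on the GPU).
    losses = []
    for _ in range(self.num_microbatches):       # hypothetical attribute
        batch, batch_idx, dataloader_idx = next(dataloader_iter)
        batch = batch.to(self.device)            # placement chosen by the caller
        losses.append(self.compute_loss(batch))  # hypothetical helper
    return torch.stack(losses).mean()
```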
I'm sorry this has led to an uncomfortable change for you in 2.1, but we held back on documenting this niche feature precisely because we wanted to allow ourselves to make changes as we see fit. We also had plans to incorporate a Megatron-like strategy in Lightning, and for this we would need to further explore whether the `dataloader_iter` design is in an acceptable state before we can document it.
However, we could in theory start documenting this feature with a warning that it is experimental.
Your point is clear, and I actually agree. Personally, I don't have an issue with it now; feel free to close it if you see fit :).
Bug description
I'm following the tutorial, and the old interface runs smoothly with Lightning 2.1. However, when I try to upgrade to the new interface
`training_step(dataloader_iter)`
I get an error. It looks like the batch still lives on CPU. Am I doing anything wrong?
What version are you seeing the problem on?
master
How to reproduce the bug
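A hypothetical minimal reproduction, assuming a single-GPU run, a tensor batch, and the 2.1 triplet return of `next(dataloader_iter)` (all names here are made up):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl

class Repro(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, dataloader_iter):
        batch, batch_idx, dataloader_idx = next(dataloader_iter)
        x, y = batch
        # No manual batch.to(self.device): on a GPU run this raises a
        # device-mismatch RuntimeError because x and y are still on CPU.
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)
trainer = pl.Trainer(accelerator="gpu", devices=1, max_steps=2)
trainer.fit(Repro(), data)
```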
Error messages and logs
Environment
Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```
More info
No response
cc @borda @justusschock @awaelchli