Currently, to get the device of a module's parameters, we do:
next(self.parameters()).device
However, this fails when the model is FSDP-wrapped: FSDP removes all parameters from self._parameters, so self.parameters() yields nothing and next(self.parameters()).device raises StopIteration.
The goal is a way to get the correct device whether or not the model is FSDP-wrapped.
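One possible direction, as a rough sketch only: look at parameters first, fall back to buffers, and finally to a device supplied by the caller. The helper name get_module_device and its default argument are hypothetical and not part of any existing API. The sketch relies only on the module exposing parameters() and buffers(), as torch.nn.Module does; whether any buffers remain visible on an FSDP-wrapped module is an assumption that would need to be verified.

```python
from itertools import chain


def get_module_device(module, default=None):
    """Hypothetical helper: best-effort lookup of the device a module lives on.

    Checks parameters first, then buffers. If both are empty (the case
    described above for FSDP-wrapped models), returns `default` when given.
    """
    # next() with a fallback value avoids the StopIteration from the bare call.
    first = next(chain(module.parameters(), module.buffers()), None)
    if first is not None:
        return first.device
    if default is not None:
        # Assumption: the caller knows which device the wrapped model uses.
        return default
    raise ValueError("module has no parameters or buffers and no default device was given")
```

With this shape, a call site could replace next(self.parameters()).device with get_module_device(self, default=some_known_device), where some_known_device is whatever device the caller already tracks. The explicit ValueError is a design choice so that an empty module fails with a readable message rather than a bare StopIteration.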