… Also the model might be a 'meta' pre-dispatched version of the model.
# That means the tensors passed as args and the model are on different devices, but we don't want to make users move their tensors to 'meta'.
# So only when there's a FakeTensor on device 'meta' do we also move the other tensors to 'meta'.
The error is that you're using tensors on device cuda:0, but the scanned proxy values underlying the first trace are on device 'meta', probably because you don't have dispatch=True. That means the model is only loaded after the first trace; from then on its parameters are on cuda:0, so the error no longer happens.
So when I prepare the values before "validating" each operation you do in the tracing context, I'll check for FakeTensors. If there's one on device 'meta', I'll also move a copy of any tensors that interact with it to 'meta'.
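
Roughly, the check described above could look like the sketch below. This is only an illustration, not the actual patch: `on_meta` and `prepare_values` are hypothetical names, a plain tensor created on 'meta' stands in for the scanned FakeTensor, and a CPU tensor stands in for the user's cuda:0 tensor so the snippet runs without a GPU.

```python
import torch


def on_meta(value):
    # True only for tensors that live on the 'meta' device.
    return isinstance(value, torch.Tensor) and value.device.type == "meta"


def prepare_values(values):
    """Hypothetical helper: if any value is on 'meta', return the values
    with every other tensor replaced by a 'meta' copy; otherwise return
    them unchanged. The caller's original tensors are never modified."""
    if not any(on_meta(v) for v in values):
        return list(values)
    return [
        v.to("meta") if isinstance(v, torch.Tensor) and not on_meta(v) else v
        for v in values
    ]


# Stand-in for the proxy value scanned from the not-yet-dispatched model.
proxy_value = torch.empty(2, 3, device="meta")
# Stand-in for the user's real tensor (cuda:0 in the report, CPU here).
user_tensor = torch.ones(3)

prepared = prepare_values([proxy_value, user_tensor])
# Shape validation now runs entirely on 'meta'.
result = torch.add(*prepared)
```

Because `.to("meta")` returns a new tensor, `user_tensor` keeps its original device and data; only the copy used for validation ends up on 'meta'.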