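Some models (e.g. llama) create new tensors while running, for example with torch.ones(...), without passing a device. Those tensors are allocated on the CPU by default, so they can clash with the meta tensors used during scanning (e.g. device-mismatch errors when a CPU tensor meets a meta tensor). _scan needs to set the default device to meta, e.g. with torch.set_default_device, or patch the tensor-creation calls if that is not enough.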
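A minimal sketch of the default-device approach, assuming PyTorch 2.0 or later. MaskingBlock and scan_on_meta are hypothetical names for illustration only; scan_on_meta stands in for _scan, whose real signature is not shown here. Instead of calling torch.set_default_device globally, this uses torch.device("meta") as a context manager, which has the same effect on factory functions like torch.ones but restores the previous default when the block exits.

```python
import torch
import torch.nn as nn


class MaskingBlock(nn.Module):
    """Hypothetical toy module that, like some llama code paths, builds a tensor in forward()."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No device= argument: this lands on the *default* device (CPU unless changed).
        mask = torch.ones(x.shape[-1])
        return self.proj(x) * mask


def scan_on_meta(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for _scan: run forward with meta as the default device."""
    # Inside this block, factory functions such as torch.ones allocate on meta;
    # the previous default device is restored on exit.
    with torch.device("meta"):
        return module(x)


# Build the module's parameters on meta as well, so no real memory is allocated.
with torch.device("meta"):
    model = MaskingBlock(8)

x = torch.empty(2, 8, device="meta")
out = scan_on_meta(model, x)
print(out.device, tuple(out.shape))
```

Scoping the change with the context manager keeps the meta default from leaking into code that runs after the scan; a global torch.set_default_device call would need an explicit reset afterwards.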