Open CutieQing opened 5 months ago
This is expected. torch_xla supports running on either one device (nprocs=1) or all devices. If you pass anything other than nprocs=1, the value is ignored and spawn falls back to running on all available devices.
Although I guess ideally it shouldn't warn if nprocs == num devices. The documentation at https://pytorch.org/xla/release/1.6/index.html#torch_xla.distributed.xla_multiprocessing.spawn is a bit unclear on this.
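To make the rule above concrete, here is a minimal sketch of the described behaviour in plain Python. It is not the torch_xla source: `resolve_nprocs` is a hypothetical helper, and `num_devices=8` is only an illustrative value.

```python
import logging
from typing import Optional


def resolve_nprocs(requested: Optional[int], num_devices: int) -> int:
    """Hypothetical model of the rule described above (not torch_xla code).

    Returns the number of processes that would actually be started.
    """
    if requested == 1:
        # Single-device mode is the only explicit value that is honoured.
        return 1
    if requested is not None:
        # Any other explicit value is ignored with a warning, even when it
        # happens to equal the number of devices.
        logging.warning("Unsupported nprocs (%d), ignoring...", requested)
    # Fall back to one process per available device.
    return num_devices


if __name__ == "__main__":
    # Assumption: 8 devices, purely for illustration.
    for requested in (1, 4, 8, None):
        print(requested, "->", resolve_nprocs(requested, num_devices=8))
```

In practice this suggests either omitting nprocs (so all devices are used without the warning) or passing nprocs=1 for a single device.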
❓ Questions and Help
When I ran multi-process inference with the Hugging Face transformers framework, I used xmp.spawn(perform_inference, args=(args,), nprocs=4), intending to run 4 processes at once. However, it reported the warning: WARNING:root:Unsupported nprocs (4), ignoring... I wonder whether this is a bug or whether there is a mistake in my inference script.
My inference script is as follows:
if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    args = get_args()
Detailed log:
WARNING:root:Unsupported nprocs (4), ignoring...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.528224 2908632 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.528293 2908632 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.528300 2908632 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.544289 2908627 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.544426 2908627 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.544434 2908627 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.728254 2908631 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.728326 2908631 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.728332 2908631 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.916441 2908634 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.916616 2908634 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.916625 2908634 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.409535 2908636 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080893.409646 2908636 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.409654 2908636 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.658751 2908630 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080893.658883 2908630 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.658891 2908630 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.659256 2908635 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080893.659285 2908633 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080893.659431 2908635 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.659440 2908635 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
I0000 00:00:1720080893.659455 2908633 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080893.659465 2908633 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
port: 52003 tpu name: xla:0 sentences length: 16384
port: 40841 tpu name: xla:0 sentences length: 16384
port: 40729 tpu name: xla:0 sentences length: 16384
port: 51387 tpu name: xla:0 sentences length: 16384
port: 53707 tpu name: xla:0 sentences length: 16384
port: 45223 tpu name: xla:0 sentences length: 16384
port: 37585 tpu name: xla:0 sentences length: 16384
port: 36559 tpu name: xla:0 sentences length: 16384
inference time: 0.034876358113251626
inference time: 0.03664895799010992
inference time: 0.026097089052200317
inference time: 0.02792046801187098
inference time: 0.02882425906136632
inference time: 0.029096698039211333
inference time: 0.02789105800911784
inference time: 0.027401939034461975
inference time: 0.014182109967805445
inference time: 0.013394199078902602
inference time: 0.013075169990770519
inference time: 0.012977780075743794
inference time: 0.01341874001082033
...