pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Multiprocess inference warning: ignoring nprocs #7633

Open CutieQing opened 5 months ago

CutieQing commented 5 months ago

❓ Questions and Help

When I ran multiprocess inference with the Hugging Face transformers framework, I called xmp.spawn(perform_inference, args=(args,), nprocs=4) because I wanted to run 4 copies of the script at once. However, it reported the warning WARNING:root:Unsupported nprocs (4), ignoring... I wonder whether this is a bug or whether there is a mistake in my inference script.

My inference script is as follows:

import time

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import AutoModel, AutoTokenizer


def perform_inference(index, args):
    # `port` is set elsewhere in the full script (profiler server port),
    # omitted here for brevity.
    device = xm.xla_device()
    print(f"tpu name: {device}")

    sentences = ["Sample-1", "Sample-2"] * args.batch_size
    print(f"sentences length: {len(sentences)}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = AutoModel.from_pretrained(args.model_name_or_path).to(device)
    model.eval()

    for i in range(20):
        if i == 19:
            print(f"log port: {port}")
            xp.trace_detached(f'localhost:{port}', './profiles/', duration_ms=2000)
        with xp.StepTrace('bge_test'):
            with xp.Trace('build_graph'):
                encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
                with torch.no_grad():
                    start = time.perf_counter()
                    model_output = model(**encoded_input)
                    end = time.perf_counter()
                    sentence_embeddings = model_output[0][:, 0]
                    print("inference time:", (end - start))

    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    print("Sentence embeddings: ", sentence_embeddings)


if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    args = get_args()  # argument-parsing helper, defined elsewhere in the full script
    xmp.spawn(perform_inference, args=(args,), nprocs=4)
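
As a side note, each spawned worker can print its ordinal and the world size to confirm how many processes actually started; a minimal sketch (xm is torch_xla.core.xla_model as above; newer releases also expose torch_xla.runtime.world_size as a replacement for the deprecated xrt_world_size):

import torch_xla.core.xla_model as xm

def report_worker(index):
    # index is the ordinal that xmp.spawn passes as the first argument;
    # xm.get_ordinal() reports the same value from inside the worker
    print(f"worker {xm.get_ordinal()} of {xm.xrt_world_size()} on device {xm.xla_device()}")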

Detailed log:

WARNING:root:Unsupported nprocs (4), ignoring...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1720080892.528224 2908632 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/liqing002/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1720080892.528293 2908632 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1720080892.528300 2908632 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
[... the same PJRT initialization messages repeat for the remaining seven worker processes ...]
port: 52003
tpu name: xla:0
sentences length: 16384
port: 40841
tpu name: xla:0
sentences length: 16384
port: 40729
tpu name: xla:0
sentences length: 16384
port: 51387
tpu name: xla:0
sentences length: 16384
port: 53707
tpu name: xla:0
sentences length: 16384
port: 45223
tpu name: xla:0
sentences length: 16384
port: 37585
tpu name: xla:0
sentences length: 16384
port: 36559
tpu name: xla:0
sentences length: 16384
inference time: 0.034876358113251626
inference time: 0.03664895799010992
inference time: 0.026097089052200317
inference time: 0.02792046801187098
inference time: 0.02882425906136632
inference time: 0.029096698039211333
inference time: 0.02789105800911784
inference time: 0.027401939034461975
inference time: 0.014182109967805445
inference time: 0.013394199078902602
inference time: 0.013075169990770519
inference time: 0.012977780075743794
inference time: 0.01341874001082033
...

BitPhinix commented 5 months ago

This is expected. torch_xla's spawn supports running on either a single device (nprocs=1) or all available devices (nprocs=None). If you pass anything other than nprocs=1, it is ignored and spawn falls back to running on all available devices.

See https://github.com/pytorch/xla/blob/3bcb1fb17442e81cf15c956016908030937d5e89/torch_xla/_internal/pjrt.py#L209-L214
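
To make the two supported modes concrete, a minimal sketch (the worker function here is just a placeholder, not part of the script above):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _check(index):
    # placeholder worker, only used to illustrate the two launch modes
    print(f"worker {index} on {xm.xla_device()}")

if __name__ == "__main__":
    # nprocs=1: run the worker on a single device
    xmp.spawn(_check, nprocs=1)
    # nprocs omitted (None): one process per available device;
    # any other value triggers the "Unsupported nprocs" warning and is ignored
    # xmp.spawn(_check)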

BitPhinix commented 5 months ago

Although I guess ideally it shouldn't warn when nprocs equals the number of available devices. The documentation at https://pytorch.org/xla/release/1.6/index.html#torch_xla.distributed.xla_multiprocessing.spawn is a bit unclear on this.
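
Until that's clarified, one way to avoid the warning is to forward nprocs only when it is 1; a small hypothetical wrapper to illustrate (launch is not a torch_xla API, just a sketch):

import torch_xla.distributed.xla_multiprocessing as xmp

def launch(fn, fn_args=(), nprocs=None):
    # hypothetical helper: any nprocs other than 1 is ignored by xmp.spawn anyway,
    # so pass it through only for the single-device case
    if nprocs == 1:
        xmp.spawn(fn, args=fn_args, nprocs=1)
    else:
        xmp.spawn(fn, args=fn_args)  # all available devices, no warning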