Open fancy45daddy opened 1 day ago
It comes from:
File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
91 if not using_pjrt():
92 raise NotImplementedError('`{}` not implemented for XRT'.format(
93 fn.__name__))
---> 95 return fn(*args, **kwargs)
Which version of PyTorch/XLA are you using? I'm wondering how you still triggered the XRT runtime, which has been deprecated for a long time.
import torch_xla
torch_xla.__version__
2.4.0+libtpu
I just ran a simple test on a TPU v3-8 with torch_xla 2.4.0:
import torch, torch_xla
import torch_xla.distributed.xla_multiprocessing as xmp

def process(index):
    print(index)  # index is the replica's global ordinal

xmp.spawn(process, start_method='fork')
You should see it print 0, 1, ..., 7 (not necessarily in order), which means it uses all 8 cores.
@zpcore Thank you, I can see the code running on all TPU cores on the Kaggle v3-8 now. The new problem is that after torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork') only one process returns successfully; all the others break. A single process works. Is it possible to limit it to just two cores so I can debug faster?
On TPU, you can either use all the processes, by leaving the nprocs argument as None, or use a single process, by setting nprocs=1. I don't think you can run with 2 processes. I can't reproduce the problem with the code you mentioned. Did you see any error message from the broken processes?
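The branch described in the reply is visible in the pjrt.spawn frame of the traceback further down (elif nprocs is not None: logging.warning('Unsupported nprocs (%d), ignoring...')). Below is a simplified plain-Python model of that decision, not the torch_xla source: resolve_replicas is a hypothetical helper, and the nprocs == 1 branch is assumed from the reply above.

```python
import logging
from typing import Optional


def resolve_replicas(nprocs: Optional[int], all_devices: int) -> int:
    """Hypothetical model of how xmp.spawn treats `nprocs` under PJRT.

    Mirrors the `elif nprocs is not None` warning branch seen in the
    traceback; it is an illustration, not torch_xla code.
    """
    if nprocs == 1:
        return 1  # single process on a single device
    elif nprocs is not None:
        logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
    return all_devices  # otherwise every local device gets a replica


# A TPU v3-8 exposes 8 devices (cores).
print(resolve_replicas(None, 8))  # 8
print(resolve_replicas(1, 8))     # 1
print(resolve_replicas(2, 8))     # logs a warning, still 8
```

So asking for 2 replicas does not restrict the run to two cores; it falls back to all of them.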
@zpcore I figured out the problem: the StableDiffusionControlNetPipeline is very big and occupies a lot of memory. With 8 processes I create 8 StableDiffusionControlNetPipeline instances, so together they eat up all the memory.
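To put rough numbers on "8 pipelines", the sketch below multiplies parameter counts by bytes per parameter (torch_dtype=torch.bfloat16 means 2 bytes each). The counts are approximate public figures for Stable Diffusion 1.5 and its ControlNet, assumed here rather than measured from this pipeline.

```python
# Approximate parameter counts (assumed public figures, not measured here).
PARAMS = {
    'unet': 860e6,
    'controlnet': 361e6,
    'text_encoder': 123e6,
    'vae': 84e6,
}
BYTES_PER_PARAM = {'float32': 4, 'bfloat16': 2}
GIB = 1024 ** 3


def weights_gib(params, dtype):
    """Size of the weights alone in GiB (no activations or runtime overhead)."""
    return sum(params.values()) * BYTES_PER_PARAM[dtype] / GIB


per_copy = weights_gib(PARAMS, 'bfloat16')
print(f'one pipeline: {per_copy:.2f} GiB, 8 copies: {8 * per_copy:.2f} GiB')
# one pipeline: 2.66 GiB, 8 copies: 21.28 GiB
```

These are lower bounds: activations for 768x512 frames and the runtime itself come on top of the weights.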
I have no idea how to reduce the memory usage. Currently I use
import os
import PIL.Image, controlnet_aux, diffusers, imageio, torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

os.environ.pop('TPU_PROCESS_ADDRESSES', None)  # no KeyError if the variable is not set

reader = imageio.get_reader('/kaggle/input/controlnet/pose.mp4', 'ffmpeg')
openpose = controlnet_aux.DWposeDetector(det_config='yolox_l_8xb8-300e_coco.py', pose_config='dwpose-l_384x288.py')
# First 16 frames only; use reader.count_frames() for the whole video.
poses = [openpose(PIL.Image.fromarray(reader.get_data(i)).resize((512, 768))) for i in range(16)]
length = len(poses) // 8  # frames per replica (8 TPU cores)
fps = reader.get_meta_data().get('fps')

controlnet = diffusers.ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_openpose', torch_dtype=torch.bfloat16)
pipeline = diffusers.StableDiffusionControlNetPipeline.from_single_file('https://huggingface.co/chaowenguo/pal/blob/main/chilloutMix-Ni.safetensors', config='chaowenguo/stable-diffusion-v1-5', safety_checker=None, controlnet=controlnet, use_safetensors=True, torch_dtype=torch.bfloat16)
pipeline.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_attention_slicing()
CrossFrameAttnProcessor = diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
pipeline.unet.set_attn_processor(CrossFrameAttnProcessor())
pipeline.controlnet.set_attn_processor(CrossFrameAttnProcessor())

prompt = '...'  # positive prompt text
negative_prompt = 'monochrome, lowres, bad anatomy, worst quality, low quality'

def process(index):
    device = xm.xla_device()
    pipe = diffusers.StableDiffusionControlNetPipeline(**pipeline.components)
    pipe.to(device)
    pose = poses[index * length:(index + 1) * length]  # this replica's frames
    n = len(pose)
    # One random latent (768/8 x 512/8), repeated so every frame starts from the same noise.
    latents = torch.randn((1, 4, 96, 64), device=device, dtype=torch.bfloat16).repeat(n, 1, 1, 1)
    images = pipe(prompt=[prompt] * n, negative_prompt=[negative_prompt] * n, image=pose, num_inference_steps=20, latents=latents).images
    imageio.mimsave(f'{index}.mp4', images, fps=fps)

xmp.spawn(process, start_method='fork')
to test, but it does not work:
30%|███ | 6/20 [00:03<00:09, 1.51it/s]
0%| | 0/20 [00:00<?, ?it/s]
0%| | 0/20 [00:00<?, ?it/s]
0%| | 0/20 [00:00<?, ?it/s]
35%|███▌ | 7/20 [00:04<00:11, 1.17it/s]
5%|▌ | 1/20 [00:01<00:22, 1.20s/it]
5%|▌ | 1/20 [00:01<00:22, 1.16s/it]
5%|▌ | 1/20 [00:01<00:23, 1.22s/it]
40%|████ | 8/20 [00:06<00:12, 1.02s/it]
10%|█ | 2/20 [00:02<00:24, 1.36s/it]
10%|█ | 2/20 [00:02<00:23, 1.29s/it]
10%|█ | 2/20 [00:02<00:23, 1.33s/it]
45%|████▌ | 9/20 [00:07<00:12, 1.14s/it]
15%|█▌ | 3/20 [00:04<00:23, 1.39s/it]
15%|█▌ | 3/20 [00:03<00:22, 1.35s/it]
15%|█▌ | 3/20 [00:04<00:23, 1.38s/it]
50%|█████ | 10/20 [00:09<00:12, 1.25s/it]
20%|██ | 4/20 [00:05<00:22, 1.44s/it]
20%|██ | 4/20 [00:05<00:22, 1.40s/it]
20%|██ | 4/20 [00:05<00:23, 1.44s/it]
55%|█████▌ | 11/20 [00:10<00:11, 1.33s/it]
25%|██▌ | 5/20 [00:07<00:22, 1.47s/it]
25%|██▌ | 5/20 [00:06<00:21, 1.44s/it]
25%|██▌ | 5/20 [00:07<00:22, 1.49s/it]
55%|█████▌ | 11/20 [12:32<10:15, 68.41s/it]
55%|█████▌ | 11/20 [13:04<10:41, 71.28s/it]
25%|██▌ | 5/20 [13:57<41:53, 167.55s/it]
25%|██▌ | 5/20 [14:01<42:05, 168.35s/it]
55%|█████▌ | 11/20 [14:53<12:11, 81.24s/it]
55%|█████▌ | 11/20 [14:57<12:14, 81.60s/it]
25%|██▌ | 5/20 [16:05<48:17, 193.18s/it]
25%|██▌ | 5/20 [16:10<48:31, 194.09s/it]
---------------------------------------------------------------------------
_RemoteTraceback Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 78, in _run_thread_per_device
replica_results = list(
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
yield _result_or_cancel(fs.pop())
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
return fut.result(timeout)
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
return self.__get_result()
File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/usr/local/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 71, in _thread_fn
return fn()
File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 190, in __call__
self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
File "/tmp/ipykernel_13/28715152.py", line 20, in process
imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet.py", line 1282, in __call__
noise_pred = self.unet(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1281, in forward
sample = upsample_block(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 2551, in forward
hidden_states = attn(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 442, in forward
hidden_states = block(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/diffusers/models/attention.py", line 530, in forward
ff_output = self.ff(norm_hidden_states)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/diffusers/models/attention.py", line 1166, in forward
hidden_states = module(hidden_states)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:721 : Check failed: pjrt_device == pjrt_data->buffer->device()
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::runtime::PjRtComputationClient::ExecuteComputation(torch_xla::runtime::ComputationClient::Computation const&, absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>, std::string const&, torch_xla::runtime::ComputationClient::ExecuteComputationOptions const&)
torch::lazy::MultiWait::Complete(std::function<void ()> const&)
Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::WorkerLoop(int)
void absl::lts_20230802::internal_any_invocable::RemoteInvoker<false, void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(absl::lts_20230802::internal_any_invocable::TypeErasedState*)
__clone
*** End stack trace ***
TPU_1(process=0,(0,0,0,1)) vs TPU_0(process=0,(0,0,0,0))
"""
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Cell In[2], line 22
19 pose = sys.modules['__main__'].poses[index * sys.modules['__main__'].length:(index + 1) * sys.modules['__main__'].length]
20 imageio.mimsave(f'{index}.mp4', pipe(prompt=['gorgeous slim young cleavage robust boob japanese girl, wearing white deep V bandeau pantie, smile lying on white bed, best quality, extremely detailed'] * builtins.len(pose), negative_prompt=['monochrome, lowres, bad anatomy, worst quality, low quality'] * builtins.len(pose), image=pose, num_inference_steps=20, latents=torch.randn((1, 4, 96, 64), device=torch_xla.core.xla_model.xla_device(), dtype=torch.bfloat16).repeat(builtins.len(pose), 1, 1, 1)).images, fps=fps)
---> 22 torch_xla.distributed.xla_multiprocessing.spawn(process, start_method='fork')
File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
91 if not using_pjrt():
92 raise NotImplementedError('`{}` not implemented for XRT'.format(
93 fn.__name__))
---> 95 return fn(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py:38, in spawn(fn, args, nprocs, join, daemon, start_method)
6 @xr.requires_pjrt
7 def spawn(fn,
8 args=(),
(...)
11 daemon=False,
12 start_method='spawn'):
13 """Enables multi processing based replication.
14
15 Args:
(...)
36 return None.
37 """
---> 38 return pjrt.spawn(fn, nprocs, start_method, args)
File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:214, in spawn(fn, nprocs, start_method, args)
211 elif nprocs is not None:
212 logging.warning('Unsupported nprocs (%d), ignoring...' % nprocs)
--> 214 run_multiprocess(spawn_fn, start_method=start_method)
File /usr/local/lib/python3.10/site-packages/torch_xla/runtime.py:95, in requires_pjrt.<locals>.wrapper(*args, **kwargs)
91 if not using_pjrt():
92 raise NotImplementedError('`{}` not implemented for XRT'.format(
93 fn.__name__))
---> 95 return fn(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:174, in run_multiprocess(fn, start_method, *args, **kwargs)
168 mp_fn = functools.partial(
169 _run_thread_per_device,
170 local_world_size=num_processes,
171 fn=functools.partial(fn, *args, **kwargs),
172 initializer_fn=initialize_multiprocess)
173 process_results = executor.map(mp_fn, range(num_processes))
--> 174 replica_results = list(
175 itertools.chain.from_iterable(
176 result.items() for result in process_results))
178 return _merge_replica_results(replica_results)
File /usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py:175, in <genexpr>(.0)
168 mp_fn = functools.partial(
169 _run_thread_per_device,
170 local_world_size=num_processes,
171 fn=functools.partial(fn, *args, **kwargs),
172 initializer_fn=initialize_multiprocess)
173 process_results = executor.map(mp_fn, range(num_processes))
174 replica_results = list(
--> 175 itertools.chain.from_iterable(
176 result.items() for result in process_results))
178 return _merge_replica_results(replica_results)
File /usr/local/lib/python3.10/concurrent/futures/process.py:575, in _chain_from_iterable_of_lists(iterable)
569 def _chain_from_iterable_of_lists(iterable):
570 """
571 Specialized implementation of itertools.chain.from_iterable.
572 Each item in *iterable* should be a list. This function is
573 careful not to keep references to yielded objects.
574 """
--> 575 for element in iterable:
576 element.reverse()
577 while element:
File /usr/local/lib/python3.10/concurrent/futures/_base.py:621, in Executor.map.<locals>.result_iterator()
618 while fs:
619 # Careful not to keep a reference to the popped future
620 if timeout is None:
--> 621 yield _result_or_cancel(fs.pop())
622 else:
623 yield _result_or_cancel(fs.pop(), end_time - time.monotonic())
File /usr/local/lib/python3.10/concurrent/futures/_base.py:319, in _result_or_cancel(***failed resolving arguments***)
317 try:
318 try:
--> 319 return fut.result(timeout)
320 finally:
321 fut.cancel()
File /usr/local/lib/python3.10/concurrent/futures/_base.py:458, in Future.result(self, timeout)
456 raise CancelledError()
457 elif self._state == FINISHED:
--> 458 return self.__get_result()
459 else:
460 raise TimeoutError()
File /usr/local/lib/python3.10/concurrent/futures/_base.py:403, in Future.__get_result(self)
401 if self._exception:
402 try:
--> 403 raise self._exception
404 finally:
405 # Break a reference cycle with the exception in self._exception
406 self = None
RuntimeError: torch_xla/csrc/runtime/pjrt_computation_client.cc:721 : Check failed: pjrt_device == pjrt_data->buffer->device()
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::runtime::PjRtComputationClient::ExecuteComputation(torch_xla::runtime::ComputationClient::Computation const&, absl::lts_20230802::Span<std::shared_ptr<torch_xla::runtime::ComputationClient::Data> const>, std::string const&, torch_xla::runtime::ComputationClient::ExecuteComputationOptions const&)
torch::lazy::MultiWait::Complete(std::function<void ()> const&)
Eigen::ThreadPoolTempl<tsl::thread::EigenEnvironment>::WorkerLoop(int)
void absl::lts_20230802::internal_any_invocable::RemoteInvoker<false, void, tsl::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}&>(absl::lts_20230802::internal_any_invocable::TypeErasedState*)
__clone
*** End stack trace ***
TPU_1(process=0,(0,0,0,1)) vs TPU_0(process=0,(0,0,0,0))
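One possible reading of this check failure (an assumption, not confirmed in this thread): the failing frames are _run_thread_per_device and _thread_fn, and the message compares TPU_1 and TPU_0 that both belong to process=0, so one process drives two cores from two threads. Every StableDiffusionControlNetPipeline(**pipeline.components) reuses the same module objects, and nn.Module.to() moves a module in place, so the second thread's pipe.to(...) also moves the first thread's weights. The toy model below uses hypothetical plain-Python stand-ins (FakeModule, FakePipeline, no torch) to show that aliasing and how a per-replica copy.deepcopy avoids it.

```python
import copy


class FakeModule:
    """Stand-in for torch.nn.Module: .to() mutates in place and returns self."""

    def __init__(self, name):
        self.name = name
        self.device = 'cpu'

    def to(self, device):
        self.device = device
        return self


class FakePipeline:
    """Stand-in for a diffusers pipeline built from a components dict."""

    def __init__(self, **components):
        self.components = components

    def to(self, device):
        for module in self.components.values():
            module.to(device)
        return self


base = FakePipeline(unet=FakeModule('unet'), controlnet=FakeModule('controlnet'))

# Shared components: both replicas point at the same module objects.
pipe0 = FakePipeline(**base.components).to('TPU_0')
pipe1 = FakePipeline(**base.components).to('TPU_1')
print(pipe0.components['unet'].device)  # TPU_1: replica 0's weights were moved away

# Per-replica copies: each replica owns its modules.
pipe0 = FakePipeline(**copy.deepcopy(base.components)).to('TPU_0')
pipe1 = FakePipeline(**copy.deepcopy(base.components)).to('TPU_1')
print(pipe0.components['unet'].device)  # TPU_0
```

In the real code the analogous (untested) change would be to build each replica's pipeline from copy.deepcopy(pipeline.components), or to load the pipeline inside process, before calling pipe.to(device). Each copy costs host memory while it is being moved, so this trades memory for correctness.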
Do you have any idea how to reduce the memory usage?
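Separately from the memory question, length = len(poses) // 8 silently drops the last len(poses) % 8 frames whenever the frame count is not a multiple of 8 (for example once reader.count_frames() replaces the hard-coded 16). A sketch of an even split that spreads the remainder over the first replicas; shard_bounds is a hypothetical helper, not part of torch_xla.

```python
def shard_bounds(n_items: int, n_replicas: int, index: int) -> tuple:
    """Return (start, stop) for replica `index` so that all items are covered."""
    base, extra = divmod(n_items, n_replicas)
    # The first `extra` replicas each take one additional item.
    start = index * base + min(index, extra)
    stop = start + base + (1 if index < extra else 0)
    return start, stop


frames = list(range(20))  # e.g. 20 frames split over 8 replicas
slices = [frames[slice(*shard_bounds(len(frames), 8, i))] for i in range(8)]
print([len(s) for s in slices])  # [3, 3, 3, 3, 2, 2, 2, 2]
```

Inside process this would mean start, stop = shard_bounds(len(poses), 8, index) and pose = poses[start:stop]. Uneven slices give replicas different batch sizes, which on XLA can cost an extra graph compilation.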
❓ Questions and Help
I want to run PyTorch/XLA on a Kaggle TPU v3-8 and use all the TPU cores, but I always get: A process in the process pool was terminated abruptly while the future was running or pending.
Source code:
and I get:
A single process on a single TPU core works well, but running multiple processes to use all the TPU cores does not work.
Please help.
You can copy and test my code at https://www.kaggle.com/code/chaowenguoback/stablediffusion.