ManuelFay opened this issue 1 month ago
All right, I am managing to narrow down the problem. I'm not there completely, but I'm reporting now to avoid duplicated effort. The problem seems to be in the image preprocessing in multi-GPU mode.
Basically, images of various sizes get processed per device batch and then concatenated.
Typically, the pixel_values per image look like this and get concatenated automatically in the processor for each batch:
(Single GPU, bs=4)
```
# per image: flatten_patches.shape in the _preprocess function
(1968, 1176)
(2040, 1176)
(1900, 1176)
(1900, 1176)
# per batch: print(image_inputs["pixel_values"].shape) in the processor __call__ function
torch.Size([7808, 1176])
(2000, 1176)
(1976, 1176)
(1944, 1176)
(1900, 1176)
torch.Size([7820, 1176])
```
In a multi-GPU setup, the concatenation is done across GPUs:
(2 GPUs, bs=4)
```
(1968, 1176)
(2040, 1176)
(1900, 1176)
(1900, 1176)
(2000, 1176)
(1976, 1176)
(1944, 1176)
(1900, 1176)
torch.Size([15628, 1176])
```
Notice the images are the same; they are just grouped differently.
Where it gets interesting is that I assume that, once this "batched" preprocessing output gets re-split across GPUs, it gets split in half (and not along the original per-GPU boundaries!):
The size of tensor a (7814) must match the size of tensor b (7808) at non-singleton dimension 1
Notice 7814 = (7808 + 7820) / 2 = 15628 / 2
Now it's just a matter of finding where this split is done (in the DataParallel model) to correct it.
OK, it seems to be the scatter function in torch/nn/parallel/data_parallel.py, which is supposed to "Slice tensors into approximately equal chunks and distributes them across given GPUs."
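To make the arithmetic concrete, here is a minimal CPU-only sketch. It assumes the scatter splits dim 0 into near-equal chunks the same way `torch.chunk` does, and it uses a tiny feature dimension of 4 instead of 1176 (an illustrative choice to keep memory low). It rebuilds the gathered 2-GPU batch from the logged per-image row counts and shows that an even split hands each replica 7814 rows, while the images actually assigned to the first replica need 7808.

```python
import torch

# Per-image patch-row counts taken from the logs above (two per-GPU batches of 4 images).
gpu0_rows = [1968, 2040, 1900, 1900]
gpu1_rows = [2000, 1976, 1944, 1900]
feature_dim = 4  # illustrative stand-in for the real 1176-dim patch vectors

# After the gather, all images live in one flat tensor along dim 0.
gathered = torch.cat([torch.zeros(n, feature_dim) for n in gpu0_rows + gpu1_rows], dim=0)

# Even split along dim 0, approximating a scatter with no explicit chunk sizes.
chunks = torch.chunk(gathered, 2, dim=0)
even_sizes = [c.shape[0] for c in chunks]

# Rows each replica actually needs, given that image_grid_thw is split 4 + 4 by image.
expected_sizes = [sum(gpu0_rows), sum(gpu1_rows)]

print(gathered.shape[0], even_sizes, expected_sizes)  # 15628 [7814, 7814] [7808, 7820]
```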
Stay tuned for the next episode
cc @zucchini-nlp as well
Okay, got to the bottom of it.
Basically, as I said, when training with the HF Trainer (I am running contrastive training tasks, where it's useful not to have independent per-device batches), the pixel_values get concatenated during a gather operation, but since the images all have different shapes, the DataParallel scatter wrapper cannot scatter them back correctly.
It would be possible to specify chunk sizes in the torch scatter, but doing so would require going deep into torch code, because this parameter is set to None by default.
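For comparison, a minimal sketch of what an uneven split would need to do: derive each replica's row count from the images it receives and split along those boundaries. `torch.split` with explicit sizes stands in here for passing chunk sizes to the scatter (an assumption about the equivalent behavior, not a patch to torch), and `row_chunk_sizes` is a hypothetical helper.

```python
import torch


def row_chunk_sizes(rows_per_image, images_per_replica):
    """Sum the patch rows of the images each replica receives (hypothetical helper)."""
    sizes, start = [], 0
    for n_images in images_per_replica:
        sizes.append(sum(rows_per_image[start:start + n_images]))
        start += n_images
    return sizes


rows_per_image = [1968, 2040, 1900, 1900, 2000, 1976, 1944, 1900]
feature_dim = 4  # illustrative stand-in for 1176
gathered = torch.cat([torch.zeros(n, feature_dim) for n in rows_per_image], dim=0)

# image_grid_thw is split evenly by image (4 + 4), so pixel_values must follow the same image boundaries.
sizes = row_chunk_sizes(rows_per_image, [4, 4])
per_replica = torch.split(gathered, sizes, dim=0)
print(sizes, [p.shape[0] for p in per_replica])  # [7808, 7820] [7808, 7820]
```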
The fix I implemented on the Qwen side is not to concatenate the pixel values but instead to stack them with zero padding. At the very end, you can reverse the operation: remove the padding and concatenate everything.
It's way slower but it works.
This problem will arise only in very specific use cases, in which images of different dimensions are fed to a trainer that needs to gather and scatter for DDP.
However, it raises the question of whether concatenating values of different sizes and keeping an offset index is a "robust" way of doing things for VLMs. Here it fails to account for my use case, and it feels very anti-pattern/hacky compared to the classic sequence-length padding we usually do in NLP.
I'll leave the issue open because the problem is not fixed, but I would understand if it's a bit too niche to warrant attention. This could change, however, if contrastive VLMs used for embeddings become truly popular!
@ManuelFay WOW, very detailed investigation! I think the ability to support DDP training with VLMs is super important. I didn't think about it when adding the model, but having pixel values with shapes that can't be sliced into batches is quite inconvenient. Interesting that we didn't catch indexing errors in tests, but I'll take a closer look at that.
> is not to concat the pixel values but instead stack them with 0 padding
Yes, this is similar to what we do in llava-next or idefics, where images can end up with different shapes. In the case of llava-next we do unpadding, while idefics can simply mask out padding tokens. I'm in favor of reworking Qwen2-VL, and it seems like we should be able to mask out padding tokens so we don't perform unpadding ops. I'm not very sure about that, though; I'll need to look in more detail.
@ArthurZucker WDYT about padding images to max-shape in processing?
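One way to picture padding to max shape plus masking: pad each image's patch rows to the longest image and carry a boolean mask marking the real rows. A minimal sketch under those assumptions; `pad_with_mask` is a hypothetical helper, this is not how llava-next or idefics implement it, and it uses `torch.nn.utils.rnn.pad_sequence` for the zero padding. The mask could also feed an attention mask; here it is only used to check that the valid rows are recoverable.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_with_mask(flat_pixels, rows_per_image):
    """Split (total_rows, dim) pixels per image, zero-pad to (num_images, max_rows, dim), return a validity mask."""
    per_image = list(torch.split(flat_pixels, rows_per_image, dim=0))
    padded = pad_sequence(per_image, batch_first=True, padding_value=0.0)
    lengths = torch.tensor(rows_per_image)
    # mask[i, j] is True when row j of image i is a real patch rather than padding
    mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


rows_per_image = [5, 3, 4]  # small illustrative sizes
flat = torch.randn(sum(rows_per_image), 6)
padded, mask = pad_with_mask(flat, rows_per_image)

# Boolean indexing walks images in order, so it restores the original flat layout.
restored = padded[mask]
print(padded.shape, mask.sum().item(), torch.equal(restored, flat))  # torch.Size([3, 5, 6]) 12 True
```

Because the padded tensor has a regular leading image dimension, an even split over dim 0 would keep whole images together.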
If it's any help, my workaround looks like this in the preprocessor:
```python
import torch

# The following code is a hack to make sure the scatter in DDP is done correctly when training on multiple GPUs
# patch rows per image = h * w from image_grid_thw (grid_t is 1 for still images)
offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
# separate pixel_values for each image
pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist())
# pad pixel_values to the same length to be able to make it into a tensor
max_length = max(len(pv) for pv in pixel_values)
pixel_values = [
    torch.cat([pv, torch.zeros((max_length - len(pv), pv.shape[1]), dtype=pv.dtype, device=pv.device)])
    for pv in pixel_values
]
# shape: (num_images, max_length, patch_dim)
batch_doc["pixel_values"] = torch.stack(pixel_values)
```
And then this in the modeling code to reverse the operation:
```python
import torch

# The following code is a hack to make sure the scatter in DDP is done correctly when training on multiple GPUs
if "pixel_values" in kwargs:
    # compute pixel_values offsets (patch rows per image = h * w)
    offsets = kwargs["image_grid_thw"][:, 1] * kwargs["image_grid_thw"][:, 2]
    # drop each image's zero padding and concatenate back into the flat (total_patches, patch_dim) layout
    kwargs["pixel_values"] = torch.cat([pv[:o] for pv, o in zip(kwargs["pixel_values"], offsets)], dim=0)
```
The problem is the efficiency of this hacky approach, which essentially just works around the issue (but we could go deeper into how pixel_values get sliced later on in the model, as you suggest, which is probably best).
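The processor and modeling snippets in the comment above are meant to round-trip exactly. Here is a self-contained sketch that wraps them in two hypothetical functions, `pad_pixel_values` and `unpad_pixel_values` (illustrative names, not part of transformers), and checks the round trip on made-up grids, assuming grid_t == 1 as for still images.

```python
import torch


def pad_pixel_values(pixel_values, image_grid_thw):
    """Processor side: flat (total_patches, dim) -> zero-padded (num_images, max_patches, dim)."""
    offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # patches per image (grid_t == 1)
    per_image = torch.split(pixel_values, offsets.tolist())
    max_length = max(len(pv) for pv in per_image)
    return torch.stack([
        torch.cat([pv, pv.new_zeros((max_length - len(pv), pv.shape[1]))]) for pv in per_image
    ])


def unpad_pixel_values(padded, image_grid_thw):
    """Model side: drop each image's padding and restore the flat layout."""
    offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]
    return torch.cat([pv[:o] for pv, o in zip(padded, offsets)], dim=0)


# Hypothetical (t, h, w) grids for three images; 1176 matches the logged patch dimension.
grid = torch.tensor([[1, 4, 6], [1, 2, 4], [1, 4, 4]])
flat = torch.randn(int((grid[:, 1] * grid[:, 2]).sum()), 1176)
padded = pad_pixel_values(flat, grid)
print(padded.shape, torch.equal(unpad_pixel_values(padded, grid), flat))  # torch.Size([3, 24, 1176]) True
```

The round trip is lossless, but the padded tensor holds num_images × max_patches rows, which is likely part of why this approach is slower.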
IDK, there is a lot of context here, but the Qwen2 code supports "ragged" inputs AFAIK, so padding shouldn't be required overall, no?
In many use cases, mini-batches can be independent of each other, but when they are not (here because of contrastive training), the scatter operation that attempts to evenly distribute previously grouped inputs across GPUs fails, because the inputs should NOT be distributed evenly.
The workaround is either (1) to add an extra dimension and pad, (2) to have a way to scatter inputs unevenly, or (3) to modify the trainer so that inputs are kept independent of each other until the model forward pass, and deal with any sequence-length misalignment later.
Again, this occurs because I am training a visual embedding model with Qwen using the HF Trainer, so it's a bit of a niche use case (but one that does seem to be getting popular).
System Info
transformers version: 4.45.0.dev0
Who can help?
@muellerzr @ArthurZucker @gante
This issue is about both the Qwen-VL model and perhaps the Trainer, so I'm not sure who is best suited to answer :)
Reproduction
Replicating the setup is a bit tough, so this is more of a preliminary discussion issue to see if there is an obvious problem that surfaces.
We observe that, compared to single-GPU setups, the RoPE values are misaligned with the hidden_states size.
Typically, in line 1109 (Qwen2VisionTransformerPretrainedModel forward pass):
we can see that rotary_pos_emb and hidden_states have slightly different sizes along dimension 0, e.g. torch.Size([7820, 40]) vs. torch.Size([7736, 1280]).
Upon further inspection, we see that rotary_pos_emb has the same dimension as what we would get in single-GPU runs (which is expected, since it only depends on the grid_thw argument). However, hidden_states (which correspond to the pixel values) have a different size.
This makes training crash:
Expected behavior
[edited] See the comments for more details being investigated.
Thanks!