The issue is that I get MPMD detected. It means that at some point at least 2 workers try to execute different graphs. So I tried to check the diff between the two HLO graphs. I ran the script multiple times; I cannot say I always end up with the same diff, but multiple times I ended up with this:
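For reference, the per-worker graph comparison can be done with a short stdlib script. This is a minimal sketch; the file names are hypothetical placeholders for wherever the two workers' HLO text dumps land (e.g. the output of XLA's --xla_dump_to flag):

```python
# Hedged sketch: diff two per-worker HLO text dumps with stdlib difflib.
# The file names are placeholders; point them at the graphs your run dumps.
import difflib

def diff_hlo(path_a: str, path_b: str) -> str:
    with open(path_a) as fa, open(path_b) as fb:
        a, b = fa.readlines(), fb.readlines()
    return "".join(difflib.unified_diff(a, b, fromfile=path_a, tofile=path_b))

# Example: print(diff_hlo("hlo_worker0.txt", "hlo_worker1.txt"))
```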
Basically:
In one case, one parameter is the input to a select op.
In the other case, we have 2 parameters: the same one as in the first case, let's call it p, and a scalar, let's call it s. In this case the input to the select op is: p - broadcast(s).
After analyzing it a bit, I think this computation comes from the ParallelEmbedding layer. For some reason, what is considered a constant equal to 0 in one case is considered a parameter in the other case.
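For context, here is a minimal numpy sketch of the Megatron-style vocab-parallel masking that a parallel embedding layer typically performs (the function and all names are illustrative assumptions, not the actual ParallelEmbedding code): each rank subtracts its vocab start offset from the token ids, looks up its weight shard, and uses a select to zero out rows owned by other ranks. On the rank whose offset is 0 the subtraction is a no-op that a compiler can fold away, which is the kind of constant-vs-parameter asymmetry described above.

```python
# Illustrative sketch of vocab-parallel embedding masking (assumed names,
# not the actual ParallelEmbedding implementation).
import numpy as np

def vocab_parallel_embedding(indices, weight_shard, vocab_start, vocab_end):
    # Shift global token ids into this shard's local range; on the rank
    # where vocab_start == 0 this subtraction is a foldable no-op.
    local = indices - vocab_start                      # p - broadcast(s)
    mask = (indices >= vocab_start) & (indices < vocab_end)
    local = np.where(mask, local, 0)                   # clamp out-of-range ids
    out = weight_shard[local]
    return np.where(mask[..., None], out, 0.0)         # select: zero foreign rows

rng = np.random.default_rng(0)
full = rng.normal(size=(8, 4))                         # vocab 8, hidden 4
ids = np.array([1, 5, 3, 7])
# Two shards of 4 vocab entries each; summing the partials (normally an
# all-reduce) recovers the full-table lookup.
out0 = vocab_parallel_embedding(ids, full[:4], 0, 4)
out1 = vocab_parallel_embedding(ids, full[4:], 4, 8)
assert np.allclose(out0 + out1, full[ids])
```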
I thought it could be linked to scalar specialization by XLA, so I also ran the job with XLA_NO_SPECIAL_SCALARS=1, but ended up with an MPMD detected error as well.
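To illustrate the mechanism suspected here (a toy model only, not torch_xla's actual tracing code): when scalar specialization is allowed, a scalar operand equal to 0 can be baked into the graph as a constant and folded away, while otherwise it stays a graph parameter. If two workers end up on different sides of that decision, they compile structurally different graphs for the same math, which is exactly an MPMD situation.

```python
# Toy model of scalar specialization (assumption-level sketch, not the real
# torch_xla mechanism): the same computation yields two different graph
# texts depending on whether the scalar 0 is specialized into a constant.
def trace_select_input(p: str, s: float, specialize_scalars: bool) -> str:
    if specialize_scalars and s == 0:
        return f"select(mask, {p}, zeros)"                      # s folded away
    return f"select(mask, subtract({p}, broadcast(s)), zeros)"  # s is a parameter

g_specialized = trace_select_input("p", 0, specialize_scalars=True)
g_generic = trace_select_input("p", 0, specialize_scalars=False)
assert g_specialized != g_generic  # same math, different graphs
```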
So I tried not to use ParallelEmbedding. When sequence parallelism is enabled I end up with:
In one case it does [16, 64] -> [1, 16, 64]. So here it seems to be B x S x H. Then it adds a reshape at the end to become S / 2 x B x H.
In the other case it does [1, 64] -> [16, 1, 64]. Here it is S x B x H. Then we also end up with S / 2 x B x H.
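Shape-wise, the two paths can be mimicked in plain numpy (B = 1, S = 16, H = 64 and a parallel degree of 2 are inferred from the shapes above; the transpose and slice used to reach S / 2 x B x H are guesses at how the sequence split arises). Both workers end at the same final shape, but through different op sequences, which again makes the graphs diverge:

```python
# Sketch of the two shape paths (inferred dims; transpose/slice are guesses).
import numpy as np

S, B, H, sp = 16, 1, 64, 2   # sp = sequence-parallel degree

# Path 1: B x S x H layout, then rearranged and split along the sequence dim.
x1 = np.zeros((S, H)).reshape(B, S, H)              # [16, 64] -> [1, 16, 64]
y1 = x1.transpose(1, 0, 2)[: S // sp]               # -> [8, 1, 64]

# Path 2: S x B x H layout from the start, via a broadcast instead.
x2 = np.broadcast_to(np.zeros((B, H)), (S, B, H))   # [1, 64] -> [16, 1, 64]
y2 = x2[: S // sp]                                  # -> [8, 1, 64]

# Same final shape, reached through different op sequences.
assert y1.shape == y2.shape == (S // sp, B, H)
```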
Finally, I tried disabling sequence parallelism and ended up with:
Note: when I disable tensor parallelism it seems to be working properly.
For context, I am basically trying to train Llama / Mistral.
I run the following command:
Here is the link to train_mistral.sh