openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Channel ids for collectives are not unique when custom sharding is inlined in XLA #14600

Open Tixxx opened 4 months ago

Tixxx commented 4 months ago

We are seeing a compilation crash with a user-defined custom partitioning. I'm attaching the repro script and instructions. The error happens when running with TransformerEngine in our JAX Toolbox Docker container, although I'd expect it to happen for any inlined custom sharding. The repro needs an 8-GPU machine. Here's a small repro script (https://github.com/Tixxx/fileshare/blob/main/scan_unbound_nvbug_with_te_refactor_v2.py). To reproduce: docker run --gpus all -it ghcr.io/nvidia/jax:pax-2024-06-18 python scan_unbound_nvbug_with_te_refactor_v2.py

The error is "INTERNAL: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:2494) first->opcode() == instr->opcode() channel 1 is used for different types of channel instructions"
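To make the failing condition concrete, here is a simplified model of the check named in the error message (first->opcode() == instr->opcode()): every instruction that shares a channel id must have the same opcode. The `Instr` class and `verify_channels` function are hypothetical stand-ins, not XLA types; the real HLO verifier also has extra rules (for example for Send/Recv pairs) that this sketch ignores.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Instr:
    # Hypothetical stand-in for an HLO instruction; not an XLA type.
    name: str
    opcode: str
    channel_id: Optional[int] = None


def verify_channels(instrs: List[Instr]) -> None:
    """Simplified model of the failing check: every instruction that shares a
    channel id must have the same opcode."""
    by_channel = defaultdict(list)
    for instr in instrs:
        if instr.channel_id is not None:
            by_channel[instr.channel_id].append(instr)
    for channel, group in by_channel.items():
        first = group[0]
        for instr in group[1:]:
            if instr.opcode != first.opcode:
                raise ValueError(
                    f"channel {channel} is used for different types of channel instructions"
                )


# Two different collectives that ended up with the same channel id.
module = [
    Instr("all-reduce.1", "all-reduce", channel_id=1),  # e.g. from the inlined custom partitioning
    Instr("all-gather.2", "all-gather", channel_id=1),  # e.g. from the main computation
]
```

Calling `verify_channels(module)` raises `ValueError` for channel 1, mirroring the RET_CHECK failure above.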

Apparently, the SPMD partitioner assigns the same channel id to two different collectives. When the SPMD partitioner inlines the custom partitioning call, it assigns channel ids incrementally to the collectives in that computation. But when it later creates collectives for sharded instructions in the main computation, the state of the counter has been lost, so the channel ids start again from the original value. This state is passed around in numerous places where the partitioner creates collectives, and I haven't been able to pinpoint where it is lost. As a workaround, we could post-process the partitioned module: scan through its collectives, detect duplicate channel ids, and reassign them unique ones.
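One way the counter state can be lost is if it is passed by value into a helper, so the caller never sees the increments. The sketch below illustrates that failure mode and the usual fix of sharing a single mutable allocator. It is a hypothetical illustration only: the thread does not identify where the state is actually dropped, and `assign_ids_by_value` and `ChannelIdAllocator` are invented names.

```python
from typing import List, Tuple


def assign_ids_by_value(next_id: int, collectives: List[str]) -> List[int]:
    """Buggy pattern: the counter is a local copy, so increments are lost on return."""
    ids = []
    for _ in collectives:
        ids.append(next_id)
        next_id += 1
    return ids  # the updated next_id is dropped here


class ChannelIdAllocator:
    """Fixed pattern: one shared, mutable counter used for the whole module."""

    def __init__(self, start: int = 1) -> None:
        self._next = start

    def next(self) -> int:
        value = self._next
        self._next += 1
        return value


def demo_buggy() -> Tuple[List[int], List[int]]:
    next_id = 1
    inlined = assign_ids_by_value(next_id, ["all-reduce", "all-gather"])  # ids 1, 2
    main = assign_ids_by_value(next_id, ["all-to-all"])  # restarts at 1: duplicate
    return inlined, main


def demo_fixed() -> Tuple[List[int], List[int]]:
    alloc = ChannelIdAllocator()
    inlined = [alloc.next() for _ in ["all-reduce", "all-gather"]]  # ids 1, 2
    main = [alloc.next() for _ in ["all-to-all"]]  # continues at 3
    return inlined, main
```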

ptoulme-aws commented 4 months ago

I have also noticed this when unrolling while loops with custom-calls inside them. My solution was to add a channel id legalizer pass that enforces a unique channel id for each collective.
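A post-processing pass of the kind described here can be sketched as follows: keep the first use of each channel id and give every later duplicate a fresh id above the current maximum. This is a hypothetical sketch, not the code from the PR mentioned later in the thread; `Instr` and `uniquify_channel_ids` are invented names.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Instr:
    # Hypothetical stand-in for an HLO instruction; not an XLA type.
    name: str
    opcode: str
    channel_id: Optional[int] = None


def uniquify_channel_ids(instrs: List[Instr]) -> List[Instr]:
    """Keeps the first occurrence of each channel id and reassigns later duplicates
    to fresh ids above every id already in use. Mutates and returns instrs."""
    used = [i.channel_id for i in instrs if i.channel_id is not None]
    next_free = max(used, default=0) + 1  # fresh ids can never collide with existing ones
    seen = set()
    for instr in instrs:
        if instr.channel_id is None:
            continue
        if instr.channel_id in seen:
            instr.channel_id = next_free
            next_free += 1
        else:
            seen.add(instr.channel_id)
    return instrs
```

Note that this sketch treats every repeated id as a duplicate. Operations such as a Send and its Send-done legitimately share a channel id, so a real pass would need to handle such groups, and, as discussed below, rewriting ids after the fact can cause problems for MPMD compilations.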

Tixxx commented 4 months ago


Thanks. I was thinking along the same lines: a uniquifier pass that post-processes the graph.

nouiz commented 4 months ago

@ptoulme-aws Any hope of upstreaming or making publicly available your fix?

ptoulme-aws commented 4 months ago

I will PR the pass tonight. Best, Patrick Toulme

On Jul 16, 2024, at 5:38 PM, Frédéric Bastien wrote: "@ptoulme-aws Any hope of upstreaming or making publicly available your fix?"

ptoulme-aws commented 3 months ago

This PR should unblock you - https://github.com/openxla/xla/pull/15002

ptoulme-aws commented 3 months ago

@Tixxx My PR was merged. Can we close this issue now, or should we leave it open for SPMD debugging?

If we leave it open for debugging, I would like to note that I have also seen this when unrolling while loops with custom-calls inside them.

Tixxx commented 3 months ago

Let's leave it open for now. I think we need to revisit where the pass should run. In our case, the duplicate ids appear right after the SPMD partitioner, and because the HLO verifier runs at several points in the pipeline, compilation errors out very early. Also, it seems the PR was rolled back.

ptoulme-aws commented 3 months ago

@Tixxx I have a new PR. Where should we add it in the GPU compiler? This will fix the peer-to-peer failure.

Tixxx commented 3 months ago

OK, great, thanks. I think we will need to run it right after the SPMD partitioner, since the error reported in this bug is caused by the partitioner producing duplicate ids. @frgossen Do you think this should run as a sub-pass of the SPMD partitioner or as a stand-alone pass after the SPMD pipeline? I remember the HLO verifier runs after SPMD partitioning, so we might still hit the same error if it runs as a stand-alone pass.
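The placement concern can be illustrated with a toy pipeline that runs a verifier after every pass: a stand-alone fix-up pass after the partitioner never gets a chance to run, while the same fix-up done as a sub-step of the partitioner does. Everything here is hypothetical (a module is just a list of `(opcode, channel_id)` tuples, and the verifier is simplified to reject any repeated id); it is not how XLA's pass pipeline is actually wired.

```python
from collections import Counter
from typing import Callable, List, Optional, Tuple

Module = List[Tuple[str, Optional[int]]]  # (opcode, channel_id) pairs


def verify(module: Module) -> None:
    """Simplified verifier: rejects any channel id used more than once."""
    counts = Counter(cid for _, cid in module if cid is not None)
    dups = sorted(cid for cid, n in counts.items() if n > 1)
    if dups:
        raise RuntimeError(f"duplicate channel ids: {dups}")


def spmd_partition(module: Module) -> Module:
    """Hypothetical partitioner that reproduces the bug: two collectives get id 1."""
    return module + [("all-reduce", 1), ("all-gather", 1)]


def uniquify(module: Module) -> Module:
    """Hypothetical fix-up: renumbers channel ids in order of appearance."""
    out: Module = []
    next_id = 1
    for opcode, cid in module:
        if cid is None:
            out.append((opcode, None))
        else:
            out.append((opcode, next_id))
            next_id += 1
    return out


def partition_then_uniquify(module: Module) -> Module:
    """The fix-up run as a sub-step, before the verifier sees the output."""
    return uniquify(spmd_partition(module))


def run_pipeline(module: Module, passes: List[Callable[[Module], Module]]) -> Module:
    """Runs each pass and then the verifier, as in a pipeline with interleaved checks."""
    for p in passes:
        module = p(module)
        verify(module)
    return module
```

Here `run_pipeline([], [spmd_partition, uniquify])` fails in the verifier before `uniquify` runs, whereas `run_pipeline([], [partition_then_uniquify])` succeeds. Renumbering every id, as this sketch does, is exactly the kind of rewrite that the next comment warns against for MPMD programs.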

frgossen commented 3 months ago

We had a brief offline discussion, and the right place to fix this would be wherever the inlining happens. Channel ids should be unique at all times, and changing them will cause problems for MPMD compilations. Note that the channel id is completely irrelevant for SPMD programs. I
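Fixing this at the inlining site could look roughly like the sketch below: when a callee is inlined, each distinct channel id in the callee is remapped to a fresh id above everything the caller already uses. Ids that were shared inside the callee (for example a send and its send-done) still share one id afterwards, and the caller's existing ids are never touched. `Instr` and `inline_with_fresh_channel_ids` are hypothetical names; this is not the XLA inliner's API.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Optional


@dataclass(frozen=True)
class Instr:
    # Hypothetical stand-in for an HLO instruction; not an XLA type.
    name: str
    opcode: str
    channel_id: Optional[int] = None


def inline_with_fresh_channel_ids(caller: List[Instr], callee: List[Instr]) -> List[Instr]:
    """Appends the callee's instructions to the caller, remapping each distinct
    callee channel id to a fresh id above the caller's current maximum."""
    next_id = max((i.channel_id for i in caller if i.channel_id is not None), default=0) + 1
    remap: Dict[int, int] = {}
    inlined = []
    for instr in callee:
        if instr.channel_id is None:
            inlined.append(instr)
            continue
        if instr.channel_id not in remap:
            remap[instr.channel_id] = next_id
            next_id += 1
        # Instructions that shared an id in the callee map to the same new id.
        inlined.append(replace(instr, channel_id=remap[instr.channel_id]))
    return caller + inlined
```

For this to fully fix the issue, any collectives the partitioner creates after inlining would also have to continue numbering from the new maximum rather than from a stale counter.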

Tixxx commented 3 months ago

OK, so does that mean that, for the problem described in this issue, we won't rely on the channel id legalizer pass?

Tixxx commented 3 months ago


@frgossen Is anyone from Google already looking into fixing this where the inlining happens?

frgossen commented 3 months ago

I don't think anyone is looking into this at the moment, but I'm happy to review PRs that fix it.