I think the wrapping into a tuple is due to https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L1290-L1291. On TPU we can't have more than 3200 HLO input parameters, so we wrap them into a tuple. I would imagine SPMD still works, though. Do you have the HLO for the wrapped case?
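For reference, a minimal sketch of how the HLO for the wrapped case could be dumped, assuming a standard PyTorch/XLA setup; `dump_hlo`, the output path, and the tensors passed in are placeholders, and `_get_xla_tensors_hlo` is a private debugging API:

```python
import torch_xla

def dump_hlo(tensors, path="/tmp/wrapped_case.hlo"):
    # Returns the HLO text for the pending graph that produces `tensors`,
    # which is where the tuple-wrapped parameters would show up.
    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo(tensors)
    with open(path, "w") as f:
        f.write(hlo_text)

# e.g. dump_hlo([loss]) right before xm.mark_step()
```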
@JackCaoG Thanks for your reply. I managed to run it successfully by turning up XLA_PARAMETER_WRAPPING_THREADSHOLD. I see the bug and the TODO are documented here: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L1307. What is the purpose of the tuple input?
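A minimal sketch of that workaround, assuming the variable is read at process start; the value 6400 is an arbitrary example:

```python
import os

# Raise the parameter-wrapping threshold above the default of 3200 so the
# graph inputs are not packed into a single tuple. Set it before torch_xla
# initializes, e.g. here or in the launching shell script (test.sh).
os.environ["XLA_PARAMETER_WRAPPING_THREADSHOLD"] = "6400"

import torch_xla  # import only after the env var is set
```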
That TODO is actually already fixed; SPMD graph inputs and outputs can be aliased correctly. The wrapping happens at https://github.com/pytorch/xla/blob/7938bb5da6c993609aff614ccfa5b722a339d158/torch_xla/csrc/helpers.cpp#L974-L1005
https://github.com/pytorch/xla/pull/7604 should fix this issue
🐛 Bug
To Reproduce
Steps to reproduce the behavior: test.sh
test_multi_param_layer.py
Expected behavior
Environment
In the preceding example, if you change the number of linear layers to 10 it works, but if you change it to 400 you get an error. I observed that on the second compilation all of the input tensors were wrapped into a tuple, and the sharding information was lost after the post-compilation optimization.
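Since the test script itself is not included above, here is a hedged sketch of what a repro along those lines might look like; the layer count, hidden size, mesh layout, and sharding spec are illustrative assumptions, not the original test_multi_param_layer.py:

```python
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

NUM_LAYERS = 400  # 10 reportedly works, 400 triggers the tuple wrapping and the error
HIDDEN = 128

class MultiParamLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(HIDDEN, HIDDEN) for _ in range(NUM_LAYERS)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("model",))

model = MultiParamLayer().to(device)
for layer in model.layers:
    # Shard each weight along the mesh; this is the sharding annotation that
    # is reportedly dropped once the inputs are wrapped into a tuple.
    xs.mark_sharding(layer.weight, mesh, ("model", None))

x = torch.randn(16, HIDDEN, device=device)
loss = model(x).sum()
loss.backward()
xm.mark_step()
```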
Additional context