openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Implement type inference for AllGatherOp #865

Open burmako opened 1 year ago

burmako commented 1 year ago

To quote the spec, "type(result) = type(operand) except dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)". However, at the moment we cannot compute "dim(process_groups, 1)" because it depends on num_replicas and num_partitions. We'll need to resolve this.
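Here is a rough sketch of the shape half of that rule. The function name and the `group_size` parameter are hypothetical. `group_size` stands in for `dim(process_groups, 1)` and has to be passed in explicitly, because it cannot be derived from the op itself. The element type is assumed to pass through unchanged, so only the shape is computed.

```python
from typing import Sequence, Tuple


def infer_all_gather_result_shape(
    operand_shape: Sequence[int],
    all_gather_dim: int,
    group_size: int,
) -> Tuple[int, ...]:
    """Hypothetical helper applying the spec's shape rule for all_gather.

    group_size is an assumed stand-in for dim(process_groups, 1). It is an
    explicit input because it depends on num_replicas and num_partitions,
    which the op does not carry.
    """
    if not 0 <= all_gather_dim < len(operand_shape):
        raise ValueError(
            f"all_gather_dim {all_gather_dim} is out of range for rank {len(operand_shape)}"
        )
    if group_size <= 0:
        raise ValueError("group_size must be positive")
    result = list(operand_shape)
    # Only the gathered dimension changes: it is scaled by the group size.
    result[all_gather_dim] *= group_size
    return tuple(result)


# Example: a 2x2 operand gathered along dimension 1 across groups of 2 processes.
print(infer_all_gather_result_shape((2, 2), 1, 2))
```

Every piece of the rule is local except `group_size`. That parameter is the part the dialect cannot supply today.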

burmako commented 1 year ago

Sorry, I didn't explain why I unassigned this ticket.

At the moment we cannot compute dim(process_groups, 1). It depends on num_replicas and num_partitions, and neither is currently represented in the StableHLO dialect. #425 plans to fix this by adding that information to the StableHLO opset (and the dialect).

In the meantime, since this work is blocked on that ticket, I've unassigned this one.