NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

haveDifferentSharding #1656

Open naoyam opened 5 months ago

naoyam commented 5 months ago

https://github.com/NVIDIA/Fuser/blob/da7c4e97e3b90267a5465de856afc5e8d64fc56a/csrc/multidevice/utils.cpp#L47

It seems to me that supporting all potential patterns of parallelization would effectively require the same analysis we do for TID/BID (sync_information.cpp). It would be fine to have a simplified analysis for now, but I think we should at least make clear what is assumed and have proper assertions.

samnordmann commented 5 months ago

> https://github.com/NVIDIA/Fuser/blob/da7c4e97e3b90267a5465de856afc5e8d64fc56a/csrc/multidevice/utils.cpp#L47

>   • Does it need to look at all tensor inputs together?

It doesn't strictly have to, but by doing so we avoid recomputing the ComputeAtMap and the IDs for every tensor. I also thought that accepting a set of tensors instead of a single one is strictly more general, and it probably makes the code a bit nicer.

Do you want me to change that?

>   • What happens if there's no mapped ID, i.e., line 70 never returns true? There can still be cases requiring resharding, but it looks like we would miss them.

I didn't think about this case. Do you have an example in mind?

> It seems to me that supporting all potential patterns of parallelization would effectively require the same analysis we do for TID/BID (sync_information.cpp). It would be fine to have a simplified analysis for now, but I think we should at least make clear what is assumed and have proper assertions.

I did the best I could with my understanding of the analysis, and even that was not so easy for me... :) The function works at least on our test cases. But I'd be very happy if we can improve it! I would need some pointers, though.

Same for the assumptions: can you help me figure out what the implicit assumptions are here, and what could make the function break? One thing I can say is that, for now, we don't allow any split/merge and only allow the first axis to be parallelized across devices. Assertions could be added to ensure that, like the ones inside isSharded.
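For illustration, a minimal sketch of such assertions, written against hypothetical stand-in types (the IterDomain, TensorView and ParallelType below are simplified mocks, not nvFuser's real classes). It checks the two assumptions stated above: the leaf domain is identical to the root domain (no split/merge), and only axis 0 may be DIDx-parallelized.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-ins for nvFuser types; not the real API.
enum class ParallelType { Serial, TIDx, DIDx };

struct IterDomain {
  ParallelType ptype = ParallelType::Serial;
};

struct TensorView {
  std::vector<IterDomain*> root_domain;
  std::vector<IterDomain*> leaf_domain;
};

// Returns true if `tv` satisfies the current sharding assumptions:
//  1. no split/merge, i.e. the leaf domain is identical to the root domain;
//  2. only the first axis may be parallelized across devices (DIDx).
bool satisfiesShardingAssumptions(const TensorView& tv) {
  if (tv.leaf_domain != tv.root_domain) {
    return false;  // some split/merge (or reorder) was applied
  }
  for (std::size_t i = 1; i < tv.leaf_domain.size(); ++i) {
    if (tv.leaf_domain[i]->ptype == ParallelType::DIDx) {
      return false;  // DIDx on a non-leading axis
    }
  }
  return true;
}
```

Wrapping a call to this predicate in an assertion at the top of the function would turn a silent wrong answer into an explicit failure when an unsupported pattern shows up.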

naoyam commented 5 months ago

> https://github.com/NVIDIA/Fuser/blob/da7c4e97e3b90267a5465de856afc5e8d64fc56a/csrc/multidevice/utils.cpp#L47

>   • Does it need to look at all tensor inputs together?

> It doesn't strictly have to, but by doing so we avoid recomputing the ComputeAtMap and the IDs for every tensor. I also thought that accepting a set of tensors instead of a single one is strictly more general, and it probably makes the code a bit nicer.

As for ComputeAtMap, it's expensive to build, so we should build it once for the whole fusion and reuse it for all expressions.
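A minimal sketch of the build-once pattern, assuming a hypothetical ExpensiveAnalysis that stands in for ComputeAtMap and a hypothetical FusionAnalysisCache owned per fusion: the analysis is constructed lazily on first use, and every later expression receives the same instance.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for an expensive whole-fusion analysis such as
// ComputeAtMap; `build_count` only exists to show how often it is built.
struct ExpensiveAnalysis {
  static inline int build_count = 0;
  ExpensiveAnalysis() { ++build_count; }
};

// Owns one analysis per fusion and hands out the same instance every time.
class FusionAnalysisCache {
 public:
  const ExpensiveAnalysis& get() {
    if (!analysis_) {
      analysis_.emplace();  // built on the first request only
    }
    return *analysis_;
  }

 private:
  std::optional<ExpensiveAnalysis> analysis_;
};
```

The per-expression checks then take the cached analysis by reference instead of constructing their own.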

> Do you want me to change that?

I don't have a strong opinion here. I just don't see why we would need to see all inputs together.

>   • What happens if there's no mapped ID, i.e., line 70 never returns true? There can still be cases requiring resharding, but it looks like we would miss them.

> I didn't think about this case. Do you have an example in mind?

> It seems to me that supporting all potential patterns of parallelization would effectively require the same analysis we do for TID/BID (sync_information.cpp). It would be fine to have a simplified analysis for now, but I think we should at least make clear what is assumed and have proper assertions.

> I did the best I could with my understanding of the analysis, and even that was not so easy for me... :) The function works at least on our test cases. But I'd be very happy if we can improve it! I would need some pointers, though.

> Same for the assumptions: can you help me figure out what the implicit assumptions are here, and what could make the function break? One thing I can say is that, for now, we don't allow any split/merge and only allow the first axis to be parallelized across devices. Assertions could be added to ensure that, like the ones inside isSharded.

I think the main assumption here is exactly what you mentioned.

> we don't allow any split/merge

If that's the case, let's have an assertion. If we want to make sure the whole tensor has no split/merge, we can just do: tv->getLeafDomain() == tv->getRootDomain(). Also, a simpler and sufficient approach here may be just using PairwiseRootDomainMap instead of ComputeAtMap. The former can only analyze root domains of a producer-consumer pair, but that may be enough.

Can you create a list of all supported patterns? Then we could think about what checks would be necessary.

naoyam commented 5 months ago

@samnordmann Please fix the function comment. It says it returns tvs with the same sharding, which doesn't seem to be the case.

naoyam commented 5 months ago

Let's change the function to return true by default and only return false when we find they have the same sharding. Currently, it returns true if it's found to have different sharding. For correctness, we should generally make conservative decisions.

naoyam commented 5 months ago

Since we assume no merge/split, let's use PairwiseRootDomainMap instead of ComputeAtMap. Something like this should work:

const auto p2c_map = PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer();
for (auto p_id : producer->getMaybeRfactorDomain()) {
  if (p_id->getParallelType() != ParallelType::DIDx) {
    continue;
  }
  if (auto p2c_map_it = p2c_map.find(p_id); p2c_map_it != p2c_map.end()) {
    auto c_id = p2c_map_it->second;
    if (p_id->getParallelType() != c_id->getParallelType()) {
      // Mismatch found
      return true;
    }
  } else {
    // No matching ID found
    return true;
  }
}

I think we should also make sure all DID-parallelized consumer IDs have corresponding producer IDs. The above code only loops over the producer IDs, so it doesn't guarantee anything about the consumer IDs.
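A self-contained sketch of the snippet above extended with that consumer-side check. It uses hypothetical mock types, and the producer-to-consumer map is passed in directly as a stand-in for PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer(); the function name mayHaveDifferentSharding is illustrative only. In the conservative spirit discussed in this thread, any DIDx ID without a matching DIDx counterpart on the other side is reported as a difference.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified stand-ins for nvFuser types; not the real API.
enum class ParallelType { Serial, DIDx };

struct IterDomain {
  ParallelType ptype = ParallelType::Serial;
};

struct TensorView {
  std::vector<IterDomain*> domain;  // assumes no split/merge: root == leaf
};

using IdMap = std::unordered_map<IterDomain*, IterDomain*>;

// Returns true when producer and consumer may be sharded differently.
// `p2c` is assumed to map producer IDs to consumer IDs.
bool mayHaveDifferentSharding(
    const TensorView& producer,
    const TensorView& consumer,
    const IdMap& p2c) {
  // Producer side: each DIDx ID must map to a DIDx consumer ID.
  for (IterDomain* p_id : producer.domain) {
    if (p_id->ptype != ParallelType::DIDx) {
      continue;
    }
    auto it = p2c.find(p_id);
    if (it == p2c.end() || it->second->ptype != p_id->ptype) {
      return true;  // no matching ID, or parallel-type mismatch
    }
  }
  // Consumer side: each DIDx ID must come from a DIDx producer ID.
  IdMap c2p;
  for (const auto& [p_id, c_id] : p2c) {
    c2p.emplace(c_id, p_id);
  }
  for (IterDomain* c_id : consumer.domain) {
    if (c_id->ptype != ParallelType::DIDx) {
      continue;
    }
    auto it = c2p.find(c_id);
    if (it == c2p.end() || it->second->ptype != c_id->ptype) {
      return true;  // DIDx consumer ID without a DIDx producer counterpart
    }
  }
  return false;
}
```

The second loop catches the case the producer-only loop misses: an unsharded producer feeding a consumer that is sharded along a mapped axis. Broadcast IDs, which the thread mentions wanting to allow, are not modeled here.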

samnordmann commented 5 months ago

> Since we assume no merge/split, let's use PairwiseRootDomainMap instead of ComputeAtMap. Something like this should work:

Why is it better to use PairwiseRootDomainMap? The problem I have with this method is that it assumes we are comparing a producer and a consumer, which sounds like a hard and unnecessary restriction.

> Since we assume no merge/split

It is true that we assume this for now, but the plan is to support it eventually, and we might even need it pretty soon. So I would like to understand what would break in that case and how to fix it.

> I think we should also make sure all DID-parallelized consumer IDs have corresponding producer IDs. The above code only loops over the producer IDs, so it doesn't guarantee anything about the consumer IDs.

Correct me if I'm wrong, but I think this is checked in the current implementation (in a "permissive" way, because we want to allow broadcasts).

samnordmann commented 5 months ago

> @samnordmann Please fix the function comment. It says it returns tvs with the same sharding, which doesn't seem to be the case. Let's change the function to return true by default and only return false when we find they have the same sharding. Currently, it returns true if it's found to have different sharding. For correctness, we should generally make conservative decisions.

The function doesn't return a boolean; it returns a set of TensorViews, as indicated in the comment. In https://github.com/NVIDIA/Fuser/pull/1632 I renamed haveDifferentSharding to getTvsWithDifferentSharding; I hope that makes it clearer.

naoyam commented 5 months ago

> @samnordmann Please fix the function comment. It says it returns tvs with the same sharding, which doesn't seem to be the case. Let's change the function to return true by default and only return false when we find they have the same sharding. Currently, it returns true if it's found to have different sharding. For correctness, we should generally make conservative decisions.

> The function doesn't return a boolean; it returns a set of TensorViews, as indicated in the comment. In #1632 I renamed haveDifferentSharding to getTvsWithDifferentSharding; I hope that makes it clearer.

My point was that we should make conservative decisions. By default, a tensor should be considered to have different sharding. Only when we can prove that's not the case should we exclude it from the returned set.
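A minimal sketch of that conservative policy for the set-returning version. Tv is a hypothetical stand-in for TensorView, and the caller-supplied predicate is a placeholder for whatever analysis can prove that two tensors share a sharding: every tensor starts in the result and is removed only when the predicate proves it matches the reference.

```cpp
#include <cassert>
#include <functional>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for TensorView; `id` only lets callers tell tensors
// apart.
struct Tv {
  int id;
};

// Must return true only when it can *prove* both tensors are sharded
// identically; any uncertainty has to yield false.
using ProvesSameSharding = std::function<bool(const Tv&, const Tv&)>;

// Conservative: a tensor is reported as differently sharded unless proven
// otherwise.
std::unordered_set<const Tv*> getTvsWithDifferentSharding(
    const Tv& ref,
    const std::vector<const Tv*>& tvs,
    const ProvesSameSharding& proves_same) {
  std::unordered_set<const Tv*> result(tvs.begin(), tvs.end());
  for (const Tv* tv : tvs) {
    if (proves_same(ref, *tv)) {
      result.erase(tv);
    }
  }
  return result;
}
```

With this structure, an incomplete analysis can only cause extra resharding, never a missed one.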

naoyam commented 5 months ago

> Since we assume no merge/split, let's use PairwiseRootDomainMap instead of ComputeAtMap. Something like this should work:

> Why is it better to use PairwiseRootDomainMap? The problem I have with this method is that it assumes we are comparing a producer and a consumer, which sounds like a hard and unnecessary restriction.

Because it doesn't need to do the analysis of ComputeAtMap, which scans the whole fusion, whereas PairwiseRootDomainMap is just a quick analysis of a producer and consumer pair of tensors. The latter is a much cheaper operation, and that's all we need here.

> Since we assume no merge/split

> It is true that we assume this for now, but the plan is to support it eventually, and we might even need it pretty soon. So I would like to understand what would break in that case and how to fix it.

Most of what we need to consider can be found in device_lower/analysis/sync_information. When we support merge and split, we should extend that rather than create a new analysis here. Extending it should be relatively straightforward, since all the necessary logic is already implemented and we just need to add DIDx.

Also, sync_information will be simplified in the future with IdModel. I can't say when it happens, but I'd suggest delaying touching that file as late as possible.

> I think we should also make sure all DID-parallelized consumer IDs have corresponding producer IDs. The above code only loops over the producer IDs, so it doesn't guarantee anything about the consumer IDs.

> Correct me if I'm wrong, but I think this is checked in the current implementation (in a "permissive" way, because we want to allow broadcasts).

I was referring to the code snippet I wrote above.

samnordmann commented 5 months ago

@naoyam In the function we do not make any assumption that ref is the producer / consumer of the tvs. Therefore I don't know how to use PairwiseRootDomainMap. Do you have a suggestion?

naoyam commented 5 months ago

Isn't the function always used with producers and consumers of an expr? That seems to be the case with the two use cases in multidevice/utils.cpp.
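If the function is indeed only ever called on the producers and consumers of a single expression, the set-based interface reduces to a loop over (input, output) pairs, each of which a pairwise analysis such as PairwiseRootDomainMap can handle. A minimal sketch, assuming a hypothetical Expr with inputs/outputs lists and a caller-supplied pairwise check (exprNeedsResharding is an illustrative name, not an existing function):

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-ins; not nvFuser's real Expr/TensorView classes.
struct Tv {
  int id;
};

struct Expr {
  std::vector<const Tv*> inputs;   // producers
  std::vector<const Tv*> outputs;  // consumers
};

// Pairwise check, e.g. one built on PairwiseRootDomainMap(producer, consumer).
// Returns true when the pair may be sharded differently.
using PairCheck = std::function<bool(const Tv& producer, const Tv& consumer)>;

// An expression needs resharding if any producer/consumer pair may be sharded
// differently.
bool exprNeedsResharding(const Expr& expr, const PairCheck& may_differ) {
  for (const Tv* producer : expr.inputs) {
    for (const Tv* consumer : expr.outputs) {
      if (may_differ(*producer, *consumer)) {
        return true;
      }
    }
  }
  return false;
}
```

Each pair is analyzed independently, so no whole-fusion analysis is needed for this use.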