csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

Make contiguity ignore broadcasts #2517

Closed zasdfgbnm closed 1 year ago

zasdfgbnm commented 1 year ago

Currently, contiguity has the same size as the rfactor domain of a tensor. In this PR, I am changing it to the size of TensorDomain::noBroadcasts(rfactor_dom). With this change, "the contiguity of a broadcast/expand dimension" is no longer meaningful by definition, and contiguity[i] is true if and only if dimension i is memory dense with the next non-broadcast dimension.

For example, if I have a tensor torch.zeros(4, 1, 3).expand(-1, 10, -1), before this change, the contiguity of this tensor will be (false, false, true), and after the change, it will be (true, true).

The reason for this change is that we are more interested in whether a non-broadcast dimension is memory dense with its next non-broadcast dimension. In the example above, we are interested in whether the dimensions of extent 4 and 3 are memory dense. We are not interested in whether 10 and 3 are memory dense, because by definition they trivially are not. In this example we want to vectorize the dimension of extent 4, but the current contiguity design blocks us from doing so.
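The two definitions above can be sketched in plain Python. This is a simplified illustration, not nvFuser's actual code: the function names are hypothetical, and it uses stride 0 as a stand-in for "broadcast dimension" (nvFuser tracks broadcast-ness on the IterDomain itself, not via strides).

```python
def contiguity_old(sizes, strides):
    """Old definition: one flag per dimension, each checked against the
    immediately next (inner) dimension, broadcasts included."""
    n = len(sizes)
    return [
        strides[i] == (strides[i + 1] * sizes[i + 1] if i + 1 < n else 1)
        for i in range(n)
    ]

def contiguity_no_broadcasts(sizes, strides):
    """New definition: flags only for non-broadcast dimensions, each checked
    against the next non-broadcast dimension."""
    # Simplifying assumption: a stride of 0 marks a broadcast/expand dimension.
    dims = [i for i in range(len(sizes)) if strides[i] != 0]
    flags = []
    for j, i in enumerate(dims):
        if j + 1 < len(dims):
            k = dims[j + 1]
            flags.append(strides[i] == strides[k] * sizes[k])
        else:
            flags.append(strides[i] == 1)
    return flags

# torch.zeros(4, 1, 3).expand(-1, 10, -1) has sizes (4, 10, 3), strides (3, 0, 1)
print(contiguity_old((4, 10, 3), (3, 0, 1)))            # [False, False, True]
print(contiguity_no_broadcasts((4, 10, 3), (3, 0, 1)))  # [True, True]
```

Under the old definition the broadcast stride of 0 poisons both neighboring flags, hiding the fact that the extent-4 and extent-3 dimensions are memory dense with each other; the new definition recovers that.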

Currently, our definition of the contiguity of broadcast dimensions, and of the dimension before a broadcast dimension, is vague and not well formalized. For example, given shape (4, 1, 6) and stride (4*999999, 999999, 1), on the one hand our system will calculate the contiguity as (true, false, true); on the other hand, our indexing will collapse the index of dim 0 with dim 2 because it ignores broadcasts (this is the root cause of #2169). I do not consider this an indexing bug. Instead, I consider it an ambiguity in the definition of contiguity, and this design change is an effort to remove that ambiguity.
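The ambiguity can be shown numerically. A hypothetical indexer that trusts the reported flags and collapses dim 0 with dim 2 (skipping the broadcast dim 1) computes a different address than the strides actually dictate:

```python
# Shape and strides from the example above (root cause of #2169).
sizes   = (4, 1, 6)
strides = (4 * 999999, 999999, 1)

# Pick an element at dim-0 index 1, dim-2 index 2 (broadcast dim index is free).
i0, i2 = 1, 2

# Address honoring the real strides.
correct = i0 * strides[0] + i2 * strides[2]

# Address if dim 0 is collapsed into dim 2, as the broadcast-ignoring
# indexing would do when the flags claim the pair is memory dense.
collapsed = (i0 * sizes[2] + i2) * strides[2]

print(correct, collapsed)  # 3999998 8 -- the collapse is incorrect here
```

Since 4*999999 != 6*1, dims 0 and 2 are not actually memory dense, so collapsing them silently reads the wrong element; the flags and the indexing disagree about what "contiguous" means.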

See also: #2169, https://github.com/csarofeen/pytorch/pull/2049

Fixes #2169

naoyam commented 1 year ago

It seems that many of the changes are because the contiguity vector now holds flags only for non-broadcast domains. I wonder if it could be simpler to keep the contiguity vector with flags for all domains and just change the definition of the flag. IIRC, if the flag is true, it means the stride of the domain can be calculated as the stride of the next inner domain multiplied by the extent of that inner domain. If we change the definition of the next inner domain to the next non-broadcast inner domain, I think we could get the same benefit as this PR without doing noBroadcastDomains.

Just my two cents.
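This alternative, one flag per domain but with "next inner" redefined as the next non-broadcast inner domain, could be sketched as follows. Again a hypothetical illustration using stride 0 as a proxy for broadcast; the comment leaves the flag value at a broadcast dimension unspecified, so this sketch arbitrarily stores False there:

```python
def contiguity_all_dims(sizes, strides):
    """One flag per dimension; non-broadcast flags are checked against the
    next non-broadcast inner dimension. Broadcast dims (stride 0 here, a
    simplifying assumption) get a meaningless placeholder flag."""
    n = len(sizes)
    flags = []
    for i in range(n):
        if strides[i] == 0:
            flags.append(False)  # placeholder: the flag carries no meaning
            continue
        # Find the next non-broadcast inner dimension, if any.
        nxt = next((j for j in range(i + 1, n) if strides[j] != 0), None)
        expected = strides[nxt] * sizes[nxt] if nxt is not None else 1
        flags.append(strides[i] == expected)
    return flags

# Same expand example: sizes (4, 10, 3), strides (3, 0, 1)
print(contiguity_all_dims((4, 10, 3), (3, 0, 1)))  # [True, False, True]
```

The flag vector keeps the rfactor-domain length, at the cost of the placeholder entries that the reply below objects to.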

zasdfgbnm commented 1 year ago

> It seems that many of the changes are because the contiguity vector now only holds flags only for non-broadcast domains. I wonder if it could be simpler if we kept the contiguity vector to have flags of all domains and just change the definition of the flag. IIRC, if the flag is true, it means the stride of the domain can be calculated as the stride of the next inner domain multiplied by the extent of the inner domain. If we change the definition of the next inner domain to the next non-broadcast inner domain, I think we should be able to have the same benefit of this PR without doing noBroadcastDomains.
>
> Just my two cents.

I agree that keeping the flag for broadcast dimensions could still give the benefit of ignoring broadcasts in its definition. But I don't think it would make this diff easier. The only saving is a few noBroadcastDomains calls, and we would need to change our caching system to ignore the flag values at broadcast dimensions. Also, I don't like making contiguity store redundant unused values: the definition of contiguity is not trivial, and we have already made mistakes by writing wrong indexing code. Storing these extra values would make it easy to make similar mistakes in the future. Making the size of contiguity deviate from the size of the rfactor domain leads to much louder errors when we make a mistake, thereby avoiding hard-to-catch bugs.

naoyam commented 1 year ago

Hahaha, I expected this answer:

> And I don't like making contiguity storing redundant unused value

jjsjann123 commented 1 year ago

Looks like my issues are all resolved. I'm leaving it to @naoyam to stamp on this one.

zasdfgbnm commented 1 year ago

> Hahaha, I expected this answer:
>
> > And I don't like making contiguity storing redundant unused value

@naoyam Indeed, instead of making contiguity store redundant unused booleans, we can still make contiguity the same size as the rfactor domain by making it a std::vector<c10::optional<bool>> instead of a std::vector<bool>. So for the tensor torch.zeros(8, 1, 3).expand(8, 9, 3), the contiguity will be (true, None, true). I disliked (true, true, true) and (true, false, true) in favor of (true, true) because I wanted a hard error when trying to read the contiguity of a broadcast domain (since that would indicate a bug), and because I don't want to recompile if the contiguity changes from (true, true, true) to (true, false, true). I think (true, None, true) has all the benefits of (true, true) while being more straightforward and more convenient. I was discussing frontend design with @jjsjann123 this morning, and we both like this approach. Does this make sense? If so, I will write a PR to change it.
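In Python terms the proposed std::vector<c10::optional<bool>> corresponds to a list of Optional[bool], with None at broadcast positions. A hypothetical sketch (again using stride 0 as a broadcast proxy, which is an assumption):

```python
from typing import List, Optional

def contiguity_optional(sizes, strides) -> List[Optional[bool]]:
    """Rfactor-length contiguity vector: None at broadcast dimensions,
    real flags (vs. the next non-broadcast dim) everywhere else."""
    n = len(sizes)
    non_bcast = [i for i in range(n) if strides[i] != 0]  # proxy assumption
    flags: List[Optional[bool]] = [None] * n
    for j, i in enumerate(non_bcast):
        if j + 1 < len(non_bcast):
            k = non_bcast[j + 1]
            flags[i] = strides[i] == strides[k] * sizes[k]
        else:
            flags[i] = strides[i] == 1
    return flags

# torch.zeros(8, 1, 3).expand(8, 9, 3) has sizes (8, 9, 3), strides (3, 0, 1)
print(contiguity_optional((8, 9, 3), (3, 0, 1)))  # [True, None, True]
```

Reading a None where a bool is expected fails loudly, which preserves the hard-error property of the (true, true) design while keeping the vector aligned with the rfactor domain.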

naoyam commented 1 year ago

> Hahaha, I expected this answer:
>
> > And I don't like making contiguity storing redundant unused value
>
> @naoyam Indeed, instead of making contiguity storing redundant unused boolean, we can still make contiguity have the same size as rfactor domain by making contiguity a std::vector<c10::optional<bool>> instead of std::vector<bool>. So for the tensor torch.zeros(8, 1, 3).expand(8, 9, 3), the contiguity will be (true, None, true). I disliked (true, true, true) and (true, false, true) in favor of (true, true) because I wanted to see a hard error when trying to read the contiguity of a broadcast domain because this means a bug and because I don't want to recompile if the contiguity changed from (true, true, true) to (true, false, true). I think (true, None, true) has all the benefits of (true, true) and is more straightforward and more convenient. I was discussing with @jjsjann123 this morning about frontend design, and we both like this approach. Do you think this makes sense? I will write a PR to change it if it makes sense.

Sounds good to me.

jjsjann123 commented 1 year ago

Briefly brought up this conversation with the frontend team. There are some opinions on how we should expose the contiguity flag in the frontend. Pointing @kevinstephano here for visibility.

jjsjann123 commented 1 year ago

The change to the Python frontend is linked above in #2561.