Open kevinstephano opened 1 year ago
This is a great example as it's used in any multiclass classification problem or learned compression network. Should we focus on separate forward and backward or also look at the "turnaround" fusion? In this case the backward is just a softmax minus one-hot, and the softmax is much simpler if we hold on to the log_softmax from the forward computation.
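To make the "softmax minus one-hot" remark concrete, here is a minimal sketch in plain Python (lists instead of tensors, so it runs without PyTorch). The function names are illustrative only and are not part of any fuser API. It assumes a single example with an unreduced loss, and shows that the backward needs nothing beyond the log_softmax values saved from the forward.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def nll_loss(logits, label):
    """Cross-entropy for one example: -log_softmax(logits)[label]."""
    return -log_softmax(logits)[label]

def nll_grad(saved_log_probs, label):
    """Backward from the saved forward output: softmax - one_hot(label)."""
    return [math.exp(lp) - (1.0 if c == label else 0.0)
            for c, lp in enumerate(saved_log_probs)]

logits = [2.0, -1.0, 0.5, 0.0]
label = 2
# Gradient of the loss with respect to the logits, reusing the forward result.
grad = nll_grad(log_softmax(logits), label)
```

Recovering the softmax is a single exp per element of the saved log_softmax, with no second max or sum reduction.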
“gathering” it to a size of [64, 128, 1]
Is it always to a size of 1? Looks like so as it's followed by squeeze. Assuming yes, is it completely unknown which value is gathered? Could it be always, for example, the first value?
Asking as a generic implementation of gather would need to have conservative assumptions. For example, in this case the output of log softmax, [64, 128, 32768], would need to be saved to global memory, followed by a grid sync, and then the gather op would be safely executed. Obviously, this would make no sense for this case as it wouldn't save any memory write.
Yeah, always size 1, and it can be any of the values 0-32767. The index is the true label for an example, and the tensor we're indexing holds the predicted log probabilities (the log_softmax of the logits) for each class.
The fact that it's always size 1 would be valuable for code generation. If we could assume it, we could take a different code generation strategy. That information would need to be communicated to the backend somehow. What would be the best way? It seems to me that if a frontend could translate gather to size 1 + squeeze to select (which should be equivalent, right?), then the backend wouldn't need to be as conservative as to support generic gather semantics.
torch.gather(Z, 1, Y.unsqueeze(1)).squeeze() would be the common pattern, where Z here is torch.log_softmax(logits, 1), Z is of size [N, C] and Y is of size [N].
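As a rough illustration of what that pattern computes, the sketch below mirrors it with nested Python lists. It is not the PyTorch call itself, and gather_labels is a made-up name. Each row of Z contributes exactly one element, chosen by the matching entry of Y.

```python
def gather_labels(Z, Y):
    """Pick Z[n][Y[n]] for every row n; Z is [N, C], Y is [N], result is [N]."""
    return [row[y] for row, y in zip(Z, Y)]

# Two examples (N = 2) with three classes (C = 3); values stand in for log-probabilities.
Z = [[-0.1, -2.5, -3.0],
     [-1.2, -0.4, -2.9]]
Y = [0, 2]                    # true labels, one per example
picked = gather_labels(Z, Y)  # one log-probability per example
```

The index differs per row, which is why a single scalar index cannot express this.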
Hahaha, turns out I was dumb and totally speaking non-sense with how cross-entropy loss could be done with numpy.take..... :face_exhaling: Thanks to @jacobhinkle for kindly pointing it out.
It seems to me that if a frontend could translate gather to size 1 + squeeze to select (which should be equivalent, right?), then the backend wouldn't need to be as conservative as to support generic gather semantics.
Having said that, if I'm reading this ^^^ correctly
We don't need to expose this at the user-facing API. As long as we put this take_one_along_axis in codegen, we should be good. And that decision can be made at compile time, since we special case on size-1 / broadcast.
We can hide torch_gather behind take_along_axis. https://github.com/NVIDIA/Fuser/blob/03b02950161913c579fd053f694b4a668f8f6e99/csrc/ops/arith.cpp#L198-L237
TensorView* torch_gather(TensorView* inp, int dim, TensorView* index) {
  auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
  auto idx_domain = TensorDomain::noReductions(index->getMaybeRFactorDomain());
  ...
  // so the only
  if (idx_domain[dim]->isBroadcast()) {
    IrBuilder::create<TakeOneAlongAxis>(...);
  } else {
    IrBuilder::create<TorchGatherOp>(...);
  }
}
I think with this, we should be able to support numpy.take and numpy.take_along_axis efficiently. Meanwhile, with the two numpy operations, we can cover embedding as well as CrossEntropyLoss.
translate gather to size 1 + squeeze to select (which should be equivalent, right?)
Hmm, wait one second here. Are we referring to the existing select op in arith.cpp? select only supports a single scalar index, so that won't do.
jacob mentioned it here
Z is of size [N, C] and Y is of size [N].
So input is of shape [N, C], while index is of shape [N], and we want an output of shape [N]. (Note that cross-entropy loss actually supports inputs of arbitrary rank, but that's not really important in this discussion: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). I don't think this is what select can do.
I know we've had this conversation a few times in separate context and groups. Lol, let's go over this on Monday and make sure that we are on the same page regarding:
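Oh, I see. I thought the problem was much simpler. Looks like we do need to have a simplified version of torch.gather, which is:
index.size(d) == input.size(d) for all dimensions d != dim
whereas in the real torch.gather:
index.size(d) <= input.size(d) for all dimensions d != dim
https://pytorch.org/docs/stable/generated/torch.gather.html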
It may actually be relatively trivial to lift the second assumption, as I think it's just cutting off the final input.size(d) - index.size(d) elements, so we could just combine it with slice.
In any case, this isn't as trivial as I initially thought.
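A small sketch of the slice idea, again with plain Python lists and hypothetical helper names: for a 2-D input gathered along dim 1, torch.gather-style semantics give out[i][j] = inp[i][index[i][j]], so rows of the input beyond index.size(0) are never read. Slicing those rows off first and then applying a gather that insists on equal sizes should give the same result.

```python
def gather_dim1(inp, index):
    """torch.gather-style semantics for 2-D inputs with dim=1:
    out[i][j] = inp[i][index[i][j]]; only rows i < len(index) are touched."""
    return [[inp[i][src] for src in index_row]
            for i, index_row in enumerate(index)]

def strict_gather_dim1(inp, index):
    """Simplified variant: requires index.size(0) == inp.size(0)."""
    assert len(index) == len(inp), "sizes must match in non-gather dims"
    return gather_dim1(inp, index)

inp = [[10, 11, 12], [20, 21, 22], [30, 31, 32]]
index = [[2, 0], [1, 1]]  # 2 rows, fewer than the 3 rows of inp

general = gather_dim1(inp, index)
sliced = strict_gather_dim1(inp[:len(index)], index)  # slice, then gather
```

Whether the sliced form generates code as good as a native general gather is a separate question that this sketch does not address.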
One question on gathering from [64, 128, 32768] to [64, 128, 1]. A straightforward implementation would save the gather input to global memory, do a grid sync, and then do the gather. This could be relatively easily done, but the grid sync would require a cooperative launch. Is this a reasonable approach? Is there a better approach?
It might complicate things, but I'll note that this pattern when used in a loss is typically followed immediately by a reduction, so that the combo could be implemented by an iota() + where() composed with a sum. Would that allow us to avoid writing to global memory, and the cooperative launch?
Probably not. The reason we would need to use global memory is that the input to the gather op may be parallelized by blockIdx and threadIdx. In general, in order to allow ops after gather to be parallelized independently from the ops before the gather, the gather input needs to be accessible by any thread, which means writing it to global memory followed by a synchronization.
For example:
auto tv0 = [128]; // gather input
auto tv1 = [1]; // gather index
auto tv2 = gather(tv0, tv1);
auto tv3 = exp(tv2);
Let's say tv0 is scheduled as:
tv0->split(0, 32);
tv0->axis(0)->parallelize(ParallelType::BIDx);
tv0->axis(1)->parallelize(ParallelType::TIDx);
The sizes of tv2 and tv3 are just [1], so if we want to use just one thread, threadIdx.x == 0 && blockIdx.x == 0, then the code would look like:
if (threadIdx.x == 0 && blockIdx.x == 0) {
  tv2[0] = tv0[tv1[0]];
  tv3[0] = exp(tv2[0]);
}
The problem here is that tv0 is parallelized by threadIdx.x and blockIdx.x, so if it's stored in the register file, this is invalid. Furthermore, if it's stored in global memory, the writes to tv0 must be flushed and all threads must be synchronized.
Alternatively, the below would work without using global memory:
if (threadIdx.x + blockIdx.x * blockDim.x == tv1[0]) {
  tv2[0] = tv0[tv1[0]];
  tv3[0] = exp(tv2[0]);
}
Note that all dependent ops after tv2 also need to be predicated in the same way. To remove the predicate, some type of communication is needed to distribute the value to all other threads, for which one simple way is to save it to global memory and synchronize. So, if the ops following the gather are trivial and don't need to be parallelized, this latter approach might make sense.
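The predicated variant can be mimicked with a sequential Python loop standing in for the grid. This is only a simulation under the schedule assumed above (128 elements split by 32, so 4 blocks of 32 threads). The variable names are illustrative and no real GPU semantics are modeled. Each simulated thread only ever touches the value it produced itself.

```python
import math

BLOCK_DIM = 32
tv0 = [0.01 * i for i in range(128)]  # gather input, one element per thread
tv1 = [77]                            # gather index

tv3 = [None]
active = []  # (blockIdx.x, threadIdx.x) pairs that pass the predicate
for block_idx in range(len(tv0) // BLOCK_DIM):
    for thread_idx in range(BLOCK_DIM):
        global_id = thread_idx + block_idx * BLOCK_DIM
        register = tv0[global_id]  # value this thread computed locally
        if global_id == tv1[0]:    # same predicate as in the snippet above
            tv2 = register         # no read of another thread's data
            tv3[0] = math.exp(tv2)
            active.append((block_idx, thread_idx))
```

Exactly one thread passes the predicate, which is why everything downstream of tv2 is effectively serialized unless the value is broadcast.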
Interesting. In the special case where the gather output has a single use which is a reduction including the gather axis, rewriting the graph from
tv2 = torch_gather(tv1, tv_index, dim); // assume tv_index has size 1 in dimension dim
tv3 = sum(tv2, {dim});
to
tv4 = iota(tv1->axis(dim)->extent());
tv5 = broadcast(... // bcast along all dimensions other than dim
tv6 = where(eq(tv5, tv_index), tv1, zeroVal); // forms a "one-hot" TensorView
tv7 = sum(tv6, {dim});
seems like it would compute the right thing even though the reduction is trivial. In that case we know tv6 has only the one use in the following sum, so the dim axis can be parallelized the same way as tv1, as it's just a product of pointwise ops.
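The equivalence claimed by this rewrite can be checked on a toy case. The sketch below uses plain Python lists for a 2-D input with dim = 1. Both function names are made up for illustration and do not correspond to fuser ops.

```python
def gather_then_sum(tv1, tv_index):
    """Reference: pick tv1[n][tv_index[n]]; the sum over the size-1 axis is trivial."""
    return [row[i] for row, i in zip(tv1, tv_index)]

def one_hot_then_sum(tv1, tv_index, zero=0.0):
    """Rewrite: iota along dim, where(iota == index, tv1, 0), then sum over dim."""
    out = []
    for row, i in zip(tv1, tv_index):
        iota = range(len(row))
        # Pointwise select: keep the element at the indexed position, zero elsewhere.
        masked = [x if c == i else zero for c, x in zip(iota, row)]
        out.append(sum(masked))
    return out

tv1 = [[-0.5, -1.5, -2.5], [-3.0, -0.25, -1.0]]
tv_index = [1, 2]
```

Using a where-style select rather than multiplying by a 0/1 mask matters when the input can contain -inf, since -inf * 0 would produce NaN.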
A trivial question regarding this:
index.size(d) == input.size(d) for all dimensions d != dim
vs
index.size(d) <= input.size(d) for all dimensions d != dim
Can't the second part be handled with a Slice prior to the gather op? Would having torch_gather give us better code, compared to a slice + somewhat simplified gather?
Interesting. In the special case where the gather output has a single use which is a reduction including the gather axis, rewriting the graph from
tv2 = torch_gather(tv1, tv_index, dim); // assume tv_index has size 1 in dimension dim
tv3 = sum(tv2, {dim});
to
tv4 = iota(tv1->axis(dim)->extent());
tv5 = broadcast(... // bcast along all dimensions other than dim
tv6 = where(eq(tv5, tv_index), tv1, zeroVal); // forms a "one-hot" TensorView
tv7 = sum(tv6, {dim});
seems like it would compute the right thing even though the reduction is trivial. In that case we know tv6 has only the one use in the following sum, so the dim axis can be parallelized the same way as tv1, as it's just a product of pointwise ops.
Oh, that translation is interesting. The final sum reduction would still need to be a grid reduction, even though most of the contribution by each thread should be zero, but it's certainly more efficient than writing to global memory, memory flush, global sync, and reading it again.
We should definitely try this formulation as well.
A trivial question regarding this:
index.size(d) == input.size(d) for all dimensions d != dim
vs
index.size(d) <= input.size(d) for all dimensions d != dim
Can't the second part be handled with a Slice prior to the gather op? Would having torch_gather give us better code, compared to a slice + somewhat simplified gather?
Ideally, it should not, but in reality I'm not sure. We haven't looked at the performance of these ops, so I'm pretty sure there's a lot to consider.
🚀 The feature, motivation and pitch
The task is to fuse log_softmax + gather. Naoya said it depends on his resize function work. The idea is that the tensor output of log_softmax can be roughly [64, 128, 32768], which in float is ~1GB. It is expensive to re-read that tensor versus “gathering” it to a size of [64, 128, 1], which is a trivially sized tensor. There are some people working on index_select, gather, and scatter, but they have only been allowed to fuse them as the first operation of a fusion. The gather, in this instance, would be at the end of the fusion. CrossEntropyLoss forward includes a log_softmax followed by a gather operation. It is notably used in NLP networks like BERT, as seen here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1139-L1143
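For a sense of scale, the sizes quoted above can be checked with a few lines of arithmetic. This assumes 4-byte float32 elements, and tensor_bytes is just an illustrative helper.

```python
BYTES_PER_FLOAT32 = 4

def tensor_bytes(shape, bytes_per_elem=BYTES_PER_FLOAT32):
    """Total bytes of a dense tensor with the given shape."""
    n = 1
    for extent in shape:
        n *= extent
    return n * bytes_per_elem

full = tensor_bytes([64, 128, 32768])  # log_softmax output: 1 GiB
gathered = tensor_bytes([64, 128, 1])  # after the gather: 32 KiB
```

The gathered tensor is smaller by a factor equal to the class dimension, 32768 in this example.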
Code example:
How to view graph?
This section in particular is the log_softmax + gather. This is from printing out torch.compile's graph.