csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org
Other
26 stars 7 forks source link

Example of Gather/Scatter from CrossEntropyLoss #2556


kevinstephano commented 1 year ago

🚀 The feature, motivation and pitch

The task is to fuse log_softmax+gather. Naoya said it depends on his resize function work. The idea is that the output tensor of log_softmax can be roughly [64, 128, 32768], which in float is ~1 GB (64 · 128 · 32768 fp32 elements ≈ 1.07e9 bytes). It is expensive to re-read that tensor versus "gathering" it to a size of [64, 128, 1], which is a trivially sized tensor. Some people are working on index_select, gather, and scatter, but so far those ops are only allowed as the first operation of a fusion. The gather, in this instance, would be at the end of the fusion.

CrossEntropyLoss forward includes a log_softmax followed by a gather operation.

It is notably used in NLP networks like BERT, as seen here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1139-L1143

Code example:

import torch

class MyMod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        return self.loss_fn(inputs, targets)

cp_model = torch.compile(MyMod())

inputs = [
    torch.randn(512, 32768, device='cuda', requires_grad=True),  # logits [N, C]
    torch.randint(0, 32768, (512,), device='cuda'),               # int64 class indices [N]
]

for _ in range(5):
    out = cp_model(*inputs)  # scalar mean loss
    out.backward()

How to view the graph:

$ AOT_FX_GRAPHS=1 python test.py 
====== Forward graph 0 ======
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[512, 32768], primals_2: i64[512]):
        # File: /workspace/test.py:9, code: return self.loss_fn(inputs, targets)
        amax: f32[512, 1] = torch.ops.aten.amax.default(primals_1, [1], True)
        sub: f32[512, 32768] = torch.ops.aten.sub.Tensor(primals_1, amax);  primals_1 = amax = None
        exp: f32[512, 32768] = torch.ops.aten.exp.default(sub)
        sum_1: f32[512, 1] = torch.ops.aten.sum.dim_IntList(exp, [1], True);  exp = None
        log: f32[512, 1] = torch.ops.aten.log.default(sum_1);  sum_1 = None
        sub_1: f32[512, 32768] = torch.ops.aten.sub.Tensor(sub, log);  sub = log = None
        ne: b8[512] = torch.ops.aten.ne.Scalar(primals_2, -100)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        where: i64[512] = torch.ops.aten.where.self(ne, primals_2, scalar_tensor);  scalar_tensor = None
        unsqueeze: i64[512, 1] = torch.ops.aten.unsqueeze.default(where, 1);  where = None
        gather: f32[512, 1] = torch.ops.aten.gather.default(sub_1, 1, unsqueeze);  unsqueeze = None
        squeeze: f32[512] = torch.ops.aten.squeeze.dim(gather, 1);  gather = None
        neg: f32[512] = torch.ops.aten.neg.default(squeeze);  squeeze = None
        scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
        where_1: f32[512] = torch.ops.aten.where.self(ne, neg, scalar_tensor_1);  neg = scalar_tensor_1 = None
        sum_2: i64[] = torch.ops.aten.sum.default(ne);  ne = None
        convert_element_type: f32[] = torch.ops.prims.convert_element_type.default(sum_2, torch.float32);  sum_2 = None
        sum_3: f32[] = torch.ops.aten.sum.default(where_1);  where_1 = None
        div: f32[] = torch.ops.aten.div.Tensor(sum_3, convert_element_type);  sum_3 = None
        return [div, primals_2, sub_1, convert_element_type]

====== Backward graph 0 ======
class GraphModule(torch.nn.Module):
    def forward(self, primals_2: i64[512], sub_1: f32[512, 32768], convert_element_type: f32[], tangents_1: f32[]):
        # File: /workspace/test.py:9, code: return self.loss_fn(inputs, targets)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        scalar_tensor_1: f32[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
        div_1: f32[] = torch.ops.aten.div.Tensor(tangents_1, convert_element_type);  tangents_1 = convert_element_type = None
        unsqueeze_1: i64[512, 1] = torch.ops.aten.unsqueeze.default(primals_2, 1);  primals_2 = None
        ne_3: b8[512, 1] = torch.ops.aten.ne.Scalar(unsqueeze_1, -100)
        where_2: i64[512, 1] = torch.ops.aten.where.self(ne_3, unsqueeze_1, scalar_tensor);  unsqueeze_1 = scalar_tensor = None
        full_like: f32[512, 32768] = torch.ops.aten.full_like.default(sub_1, 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False, memory_format = torch.preserve_format)
        scatter: f32[512, 32768] = torch.ops.aten.scatter.value(full_like, 1, where_2, -1.0);  full_like = where_2 = None
        where_3: f32[512, 1] = torch.ops.aten.where.self(ne_3, div_1, scalar_tensor_1);  ne_3 = div_1 = scalar_tensor_1 = None
        mul: f32[512, 32768] = torch.ops.aten.mul.Tensor(scatter, where_3);  scatter = where_3 = None
        exp_1: f32[512, 32768] = torch.ops.aten.exp.default(sub_1);  sub_1 = None
        sum_4: f32[512, 1] = torch.ops.aten.sum.dim_IntList(mul, [1], True)
        mul_1: f32[512, 32768] = torch.ops.aten.mul.Tensor(exp_1, sum_4);  exp_1 = sum_4 = None
        sub_2: f32[512, 32768] = torch.ops.aten.sub.Tensor(mul, mul_1);  mul = mul_1 = None
        return [sub_2, None]

The section below, in particular, is the log_softmax + gather portion, taken from the torch.compile forward graph above.

        amax: f32[512, 1] = torch.ops.aten.amax.default(primals_1, [1], True)
        sub: f32[512, 32768] = torch.ops.aten.sub.Tensor(primals_1, amax);  primals_1 = amax = None
        exp: f32[512, 32768] = torch.ops.aten.exp.default(sub)
        sum_1: f32[512, 1] = torch.ops.aten.sum.dim_IntList(exp, [1], True);  exp = None
        log: f32[512, 1] = torch.ops.aten.log.default(sum_1);  sum_1 = None
        sub_1: f32[512, 32768] = torch.ops.aten.sub.Tensor(sub, log);  sub = log = None
        ne: b8[512] = torch.ops.aten.ne.Scalar(primals_2, -100)
        scalar_tensor: i64[] = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
        where: i64[512] = torch.ops.aten.where.self(ne, primals_2, scalar_tensor);  scalar_tensor = None
        unsqueeze: i64[512, 1] = torch.ops.aten.unsqueeze.default(where, 1);  where = None
        gather: f32[512, 1] = torch.ops.aten.gather.default(sub_1, 1, unsqueeze);  unsqueeze = None
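
In plain PyTorch terms, this excerpt computes roughly the following (a hand-written sketch for readability, with logits standing in for primals_1 and targets for primals_2; the ne/where pair handles CrossEntropyLoss's default ignore_index=-100):

log_probs = torch.log_softmax(logits, dim=1)                     # amax/sub/exp/sum/log/sub_1: [512, 32768]
valid = targets != -100                                          # ne: ignore_index mask
safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
picked = torch.gather(log_probs, 1, safe_targets.unsqueeze(1))   # gather: [512, 1]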
jacobhinkle commented 1 year ago

This is a great example, as it's used in any multiclass classification problem or learned compression network. Should we focus on separate forward and backward fusions, or also look at the "turnaround" fusion? In this case the backward is just softmax minus one-hot, and the softmax is much simpler if we hold on to the log_softmax from the forward computation.
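
For reference, a minimal eager-mode sketch of that backward, assuming the default mean reduction and no ignored targets, where log_probs is the saved log_softmax output from the forward and grad_loss is the incoming scalar tangent:

import torch.nn.functional as F

softmax = torch.exp(log_probs)                                           # [N, C]
one_hot = F.one_hot(targets, num_classes=log_probs.size(1)).to(softmax.dtype)
grad_logits = (softmax - one_hot) * grad_loss / targets.numel()          # matches the scatter/mul/sub in the backward graph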

naoyam commented 1 year ago

“gathering” it to a size of [64, 128, 1]

Is it always gathered to a size of 1? It looks like so, as it's followed by a squeeze. Assuming yes, is it completely unknown which value is gathered? Could it be, for example, always the first value?

Asking because a generic implementation of gather would need to make conservative assumptions: in this case, the output of log_softmax, [64, 128, 32768], would need to be saved to global memory, followed by a grid sync, and only then could the gather op be safely executed. Obviously, that would make no sense here, as it wouldn't save any memory writes.

jacobhinkle commented 1 year ago

Yeah, always size 1, and it can be any of the values 0-32767. The index is the true label for an example, and the tensor we're indexing is the predicted log-probabilities (the log_softmax of the logits) for each class.

naoyam commented 1 year ago

The fact that it's always size 1 would be valuable for code generation. If we could assume it, we could take a different code generation strategy. That information would need to be communicated to the backend somehow. What would be the best way? It seems to me that if a frontend could translate gather to size 1 + squeeze to select (which should be equivalent, right?), then the backend wouldn't need to be as conservative as to support generic gather semantics.

jacobhinkle commented 1 year ago

torch.gather(Z, 1, Y.unsqueeze(1)).squeeze() would be the common pattern, where Z = torch.log_softmax(logits, 1) is of size [N, C] and Y is of size [N].
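
To make the pattern concrete, a small self-contained check (a sketch; sizes chosen to match the example above) that this gather formulation agrees with the built-in loss:

import torch
import torch.nn.functional as F

N, C = 512, 32768
logits = torch.randn(N, C)
Y = torch.randint(0, C, (N,))

Z = torch.log_softmax(logits, 1)                              # [N, C]
per_example = -torch.gather(Z, 1, Y.unsqueeze(1)).squeeze(1)  # [N]

# Should agree with the unreduced built-in cross-entropy loss.
assert torch.allclose(per_example, F.cross_entropy(logits, Y, reduction='none'), atol=1e-5)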

jjsjann123 commented 1 year ago

Hahaha, turns out I was dumb and totally talking nonsense about how cross-entropy loss could be done with numpy.take... :face_exhaling: Thanks to @jacobhinkle for kindly pointing it out.

It seems to me that if a frontend could translate gather to size 1 + squeeze to select (which should be equivalent, right?), then the backend wouldn't need to be as conservative as to support generic gather semantics.

Having said that, if I'm reading this ^^^ correctly:

We don't need to expose this at the user-facing API level. As long as we put this take_one_along_axis in codegen, we should be good. And that decision can be made at compile time, since we special-case on size-1 / broadcast.

We can hide torch_gather behind take_along_axis. https://github.com/NVIDIA/Fuser/blob/03b02950161913c579fd053f694b4a668f8f6e99/csrc/ops/arith.cpp#L198-L237

TensorView* torch_gather(TensorView* inp, int dim, TensorView* index) {
  auto inp_domain = TensorDomain::noReductions(inp->getMaybeRFactorDomain());
  auto idx_domain = TensorDomain::noReductions(index->getMaybeRFactorDomain());

  ...

  // so the only change is dispatching on whether the index domain along dim is a broadcast (size-1) domain
  if (idx_domain[dim]->isBroadcast()) {
    IrBuilder::create<TakeOneAlongAxis>(...);  // hypothetical specialized op
  } else {
    IrBuilder::create<TorchGatherOp>(...);
  }
}

I think with this, we should be able to support numpy.take and numpy.take_along_axis efficiently. Meanwhile, with these two numpy operations we can cover embedding as well as CrossEntropyLoss.
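
For reference, a small numpy sketch of how those two ops map onto the two patterns (shapes picked arbitrarily for illustration):

import numpy as np

# Embedding lookup ~ numpy.take: index whole rows of a [V, D] table with token ids.
table = np.random.randn(1000, 64)               # vocab 1000, hidden 64
ids = np.random.randint(0, 1000, (8, 16))       # [batch, seq]
emb = np.take(table, ids, axis=0)               # [8, 16, 64]

# CrossEntropyLoss gather ~ numpy.take_along_axis: pick one log-prob per row.
log_probs = np.random.randn(8, 10)              # stand-in for log_softmax output
labels = np.random.randint(0, 10, (8,))
picked = np.take_along_axis(log_probs, labels[:, None], axis=1)  # [8, 1]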

jjsjann123 commented 1 year ago

translate gather to size 1 + squeeze to select (which should be equivalent, right?)

Hmmm, wait a sec here. Are we referring to the existing select op in arith.cpp? select only supports a single scalar index, so that won't do.

Jacob mentioned it here:

Z is of size [N, C] and Y is of size [N].

So the input is of shape [N, C], the index is of shape [N], and we want an output of shape [N]. (Note that CrossEntropyLoss actually supports arbitrary-rank inputs, but that's not really important in this discussion: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html.) I don't think this is what select can do.
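
To illustrate the difference in eager PyTorch (using torch.select as a stand-in for the scalar-index select op; names and sizes are just for illustration):

import torch

N, C = 4, 10
Z = torch.randn(N, C)
Y = torch.randint(0, C, (N,))

# select-style indexing: a single scalar index picks one slice for the whole tensor.
row0 = torch.select(Z, 0, 0)                              # [C]; the index cannot vary per row

# what CrossEntropyLoss needs: a different index for every row along dim 1.
picked = torch.gather(Z, 1, Y.unsqueeze(1)).squeeze(1)    # [N]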

I know we've had this conversation a few times in separate contexts and groups. Lol, let's go over this on Monday and make sure that we are on the same page regarding:

  1. what the backend/codegen needs and can do in order to support fusion of CrossEntropyLoss;
  2. whether we are OK with moving forward with numpy.take + numpy.take_along_axis.
naoyam commented 1 year ago

Oh, I see. I thought the problem was much simpler. Looks like we do need a simplified version of torch.gather, which assumes:

index.size(dim) == 1
index.size(d) == input.size(d) for all dimensions d != dim

whereas the real torch.gather only requires:

index.size(d) <= input.size(d) for all dimensions d != dim

https://pytorch.org/docs/stable/generated/torch.gather.html

It may actually be relatively trivial to lift the second assumption, as I think it's just cutting off the final input.size(d) - index.size(d) elements, so we could just combine it with slice.

In any case, this isn't as trivial as I initially thought.

One question on gathering from [64, 128, 32768] to [64, 128, 1]: a straightforward implementation would save the gather input to global memory, do a grid sync, and then do the gather. This could be done relatively easily, but the grid sync would require a cooperative launch. Is this a reasonable approach? Is there a better one?

jacobhinkle commented 1 year ago

It might complicate things, but I'll note that this pattern, when used in a loss, is typically followed immediately by a reduction, so the combo could be implemented as an iota() + where() composed with a sum. Would that allow us to avoid writing to global memory, and the cooperative launch?

naoyam commented 1 year ago

Probably not. The reason we would need to use global memory is that the input to the gather op may be parallelized by blockIdx and threadIdx. In general, in order to allow the ops after the gather to be parallelized independently from the ops before it, the gather input needs to be accessible by any thread, which means it has to go through global memory.

For example:

auto tv0 = [128]; // gather input
auto tv1 = [1]; // gather index
auto tv2 = gather(tv0, tv1);
auto tv3 = exp(tv2);

Let's say tv0 is scheduled as:

tv0->split(0, 32);
tv0->axis(0)->parallelize(ParallelType::BIDx);
tv0->axis(1)->parallelize(ParallelType::TIDx);

The sizes of tv2 and tv3 are just [1], so if we want to use just one thread, threadIdx.x == 0 && blockIdx.x == 0, the code would look like:

if (threadIdx.x == 0 && blockIdx.x == 0) {
  tv2[0] = tv0[tv1[0]];
  tv3[0] = exp(tv2[0]);
}

The problem here is that tv0 is parallelized by threadIdx.x and blockIdx.x, so if it's stored in the register file this is invalid; and even with tv0 in global memory, the writes to it must be flushed and all threads must be synchronized first.

Alternatively, the below would work without using global memory:

if (threadIdx.x + blockIdx.x * blockDim.x == tv1[0]) {
  tv2[0] = tv0[tv1[0]];
  tv3[0] = exp(tv2[0]);
}

Note that all dependent ops after tv2 also need to be predicated in the same way. To remove the predicate, some type of communication is needed to distribute the value to all other threads; one simple way is to save it to global memory and synchronize. So, if the ops following the gather are trivial and don't need to be parallelized, this latter approach might make sense.

jacobhinkle commented 1 year ago

Interesting. In the special case where the gather output has a single use which is a reduction including the gather axis, rewriting the graph from

tv2 = torch_gather(tv1, tv_index, dim); // assume tv_index has size 1 in dimension dim
tv3 = sum(tv2, {dim});

to

tv4 = iota(tv1->axis(dim)->extent());
tv5 = broadcast(... // bcast along all dimensions other than dim
tv6 = where(eq(tv5, tv_index), tv1, zeroVal);  // forms a "one-hot" TensorView
tv7 = sum(tv6, {dim});

seems like it would compute the right thing even though the reduction is trivial. In that case we know tv6 has only the one use in the following sum, so the dim axis can be parallelized the same way as tv1, as it's just a product of pointwise ops.
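
An eager-PyTorch analogue of that rewrite, just to illustrate the equivalence (not nvFuser code; dim = 1 and a size-1 index are assumed):

import torch

N, C = 512, 32768
Z = torch.randn(N, C)                   # gather input (e.g. log_softmax output)
Y = torch.randint(0, C, (N,))           # index; size 1 along the gather dim after unsqueeze

# Proposed rewrite: iota + broadcast compare + where + sum (tv4..tv7).
idx = torch.arange(C)                                          # iota over the gather axis
mask = idx.unsqueeze(0) == Y.unsqueeze(1)                      # [N, C] "one-hot" mask
rewritten = torch.where(mask, Z, torch.zeros_like(Z)).sum(1)   # [N]

# Original: gather followed by the (trivial) sum over the size-1 gathered dim (tv2, tv3).
original = torch.gather(Z, 1, Y.unsqueeze(1)).sum(1)           # [N]
assert torch.allclose(rewritten, original)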

jjsjann123 commented 1 year ago

A trivial question regarding this:

index.size(d) == input.size(d) for all dimensions d != dim

vs

index.size(d) <= input.size(d) for all dimensions d != dim

Can't the second part be handled with a slice prior to the gather op? Would having torch_gather give us better code, compared to a slice + somewhat simplified gather?

naoyam commented 1 year ago

Interesting. In the special case where the gather output has a single use which is a reduction including the gather axis, rewriting the graph from

tv2 = torch_gather(tv1, tv_index, dim); // assume tv_index has size 1 in dimension dim
tv3 = sum(tv2, {dim});

to

tv4 = iota(tv1->axis(dim)->extent());
tv5 = broadcast(... // bcast along all dimensions other than dim
tv6 = where(eq(tv5, tv_index), tv1, zeroVal);  // forms a "one-hot" TensorView
tv7 = sum(tv6, {dim});

seems like it would compute the right thing even though the reduction is trivial. In that case we know tv6 has only the one use in the following sum, so the dim axis can be parallelized the same way as tv1, as it's just a product of pointwise ops.

Oh, that translation is interesting. The final sum reduction would still need to be a grid reduction, even though most of each thread's contributions would be zero, but it's certainly more efficient than writing to global memory, flushing, doing a global sync, and reading it again.

We should definitely try this formulation as well.

naoyam commented 1 year ago

A trivial question regarding this:

index.size(d) == input.size(d) for all dimensions d != dim

vs

index.size(d) <= input.size(d) for all dimensions d != dim

Can't the second part be handled with a slice prior to the gather op? Would having torch_gather give us better code, compared to a slice + somewhat simplified gather?

Ideally, it should not make a difference, but in reality I'm not sure. We haven't looked at the performance of these ops yet, so I'm pretty sure there's a lot to consider.