We really want to avoid looping over tensors in C++ code wherever possible. Looping for a number of iterations determined by a compute_shape call is especially worrisome. This logic should live in eDSL code so that Tile and other MLIR optimizations can apply.
I think we could replace the core algorithm with something like the following to accomplish this:
    // Gather the slices of I selected by indices.
    auto I_gathered = gather(I, indices);
    if (with_weights) {
      // Add unit dims at axes 2..rank-1 so the weights broadcast against I_gathered.
      std::vector<int64_t> unsqueeze_axes;
      for (int64_t i = 2; i < I_gathered.rank(); i++) {
        unsqueeze_axes.push_back(i);
      }
      auto weights_expanded = op::unsqueeze(per_sample_weights, unsqueeze_axes);
      I_gathered = I_gathered * weights_expanded;
    }
    // Sum over axis 1.
    auto reduced = op::sum(I_gathered, edsl::make_tuple(1), false);
    return edsl::make_tuple(reduced);
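For checking expected outputs in a unit test, here is a plain C++ reference of what the gather / weight / sum sequence is meant to compute. This is only a sketch for test data, not something to put in the op itself. It assumes a 2-D table gathered along axis 0 and a rectangular `[batch][bag]` index array; `Matrix` and `packed_sum_reference` are names made up for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference helper, host-side only (for generating test expectations).
// table:   [num_rows][dim]
// indices: [batch][bag], each entry a row number in table
// per_sample_weights: optional, same shape as indices
using Matrix = std::vector<std::vector<float>>;

Matrix packed_sum_reference(const Matrix& table,
                            const std::vector<std::vector<int64_t>>& indices,
                            const Matrix* per_sample_weights = nullptr) {
  // When weights are given, there must be one row of weights per batch entry.
  assert(!per_sample_weights || per_sample_weights->size() == indices.size());
  const std::size_t dim = table.empty() ? 0 : table[0].size();
  Matrix out(indices.size(), std::vector<float>(dim, 0.0f));
  for (std::size_t b = 0; b < indices.size(); ++b) {
    for (std::size_t j = 0; j < indices[b].size(); ++j) {
      // A missing weight tensor behaves like a weight of 1 for every sample.
      const float w = per_sample_weights ? (*per_sample_weights)[b][j] : 1.0f;
      const auto& row = table[static_cast<std::size_t>(indices[b][j])];
      // Accumulate the weighted row into the output for this batch entry.
      for (std::size_t d = 0; d < dim; ++d) {
        out[b][d] += w * row[d];
      }
    }
  }
  return out;
}
```

The inner accumulation over `j` corresponds to the reduction over axis 1 in the eDSL version, and the multiplication by `w` corresponds to the broadcast multiply with the unsqueezed weights.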
Great! This code snippet is very concise and elegant! Thank you for pointing this out.
As for #100, there is an offsets parameter that defines variable-length chunks (or segments?), very similar to #91, which makes it a bit more complicated. I'm not sure whether looping over the number of batches, determined by offsets.size(), can be avoided there as well. I will update the code for review later.
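To make the offsets question concrete, here is a small sketch of how offsets could be expanded into one segment id per index. It assumes that `offsets[b]` is the start of bag `b` in a flat indices array, that `offsets[0]` is 0, and that the last bag runs to the end of the array; the thread does not spell out that layout, so treat it as an assumption. The helper name `offsets_to_segment_ids` is made up for illustration, and it only applies when the offsets values are known on the host at build time.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: expands bag start offsets into one segment id per index.
// Assumes offsets is non-decreasing and offsets[0] == 0 (assumption, not from the PR).
std::vector<int64_t> offsets_to_segment_ids(const std::vector<int64_t>& offsets,
                                            std::size_t num_indices) {
  std::vector<int64_t> segment_ids(num_indices, 0);
  for (std::size_t b = 0; b < offsets.size(); ++b) {
    const std::size_t begin = static_cast<std::size_t>(offsets[b]);
    // The last bag extends to the end of the flat indices array.
    const std::size_t end = (b + 1 < offsets.size())
                                ? static_cast<std::size_t>(offsets[b + 1])
                                : num_indices;
    assert(begin <= end && end <= num_indices);
    // Every index position in [begin, end) belongs to bag b; an empty bag writes nothing.
    for (std::size_t i = begin; i < end; ++i) {
      segment_ids[i] = static_cast<int64_t>(b);
    }
  }
  return segment_ids;
}
```

With one segment id per index, the variable-length case takes the same form as a segment-wise sum, which may be why it resembles #91. Whether that form can be written in eDSL without a per-batch loop is still the open question.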