Status: open — issue opened by manopapad 1 year ago.
Here's what the fixed function should look like. @manopapad, can you make a pull request for it and run it through the tests?
// Grow `tile` in place so that the padded tile -- (tile[d] + padding[d])
// elements of VAL along each dimension d -- uses as much of `max_size` bytes
// as the greedy strategy below can reach without exceeding it, and without
// growing any dimension beyond bounds[d]. The smallest growable dimension is
// always grown first so the tile stays as close to square (cubic, ...) as
// possible.
//
// Returns the size in bytes of the padded tile described by the updated
// `tile`. Extents are clamped to `bounds` but are otherwise never shrunk, so
// the result can exceed `max_size` if the incoming tile already did.
//
// NOTE(review): `skipdims` is a per-dimension bitmask, so DIM is assumed to be
// smaller than the bit width of `unsigned`.
template <typename VAL, int DIM>
static unsigned roundup_tile(Point<DIM>& tile,
                             const Point<DIM>& bounds,
                             const Point<DIM>& padding,
                             const unsigned max_size)
{
  if (DIM == 1) {
    // In this single case we can just solve for this directly
    unsigned elements = max_size / sizeof(VAL);
    assert(elements > padding[0]);
    if (tile[0] < (elements - padding[0])) tile[0] = elements - padding[0];
    // Clamp unconditionally, matching the multi-dimensional path below
    if (bounds[0] < tile[0]) tile[0] = bounds[0];
    return (tile[0] + padding[0]) * sizeof(VAL);
  } else {
    // Bytes of the padded tile with dimension `skip` left out; multiplying by
    // (tile[skip] + padding[skip]) gives the full padded tile size.
    auto pitch_without = [&](int skip) -> unsigned {
      unsigned pitch = sizeof(VAL);
      for (int d = 0; d < DIM; d++)
        if (d != skip) pitch *= (tile[d] + padding[d]);
      return pitch;
    };
    // Shrink the tile to the bounds if necessary and compute the initial size
    unsigned result = sizeof(VAL);
    for (int d = 0; d < DIM; d++) {
      if (bounds[d] < tile[d]) tile[d] = bounds[d];
      result *= (tile[d] + padding[d]);
    }
    // Bit d is set once dimension d can no longer grow, either because it
    // reached bounds[d] or because growing it would exceed max_size.
    unsigned skipdims = 0;
    // Find the two smallest growable dimensions and grow the smallest one
    // until it reaches the second smallest one, its bound, or max_size.
    // Every pass either sets a new bit in skipdims or strictly increases some
    // tile[d] (which is capped by bounds[d]), so this loop terminates.
    while (true) {
      int d1 = -1, d2 = -1;
      int t1 = 0, t2 = 0;
      // Only consider dimensions that can still grow. Scanning downward with
      // strict comparisons keeps ties on the highest dimension index.
      for (int d = DIM - 1; d >= 0; d--) {
        if (skipdims & (1 << d)) continue;
        // Skip any dimension that is at its bound
        if (tile[d] == bounds[d]) {
          skipdims |= (1 << d);
          continue;
        }
        if ((d1 < 0) || (tile[d] < t1)) {
          d2 = d1;
          t2 = t1;
          d1 = d;
          t1 = tile[d];
        } else if ((d2 < 0) || (tile[d] < t2)) {
          d2 = d;
          t2 = tile[d];
        }
      }
      // No dimension can grow any further
      if (d1 < 0) return result;
      if (d2 < 0) {
        // d1 is the only dimension that can still grow: solve for it directly
        const unsigned pitch    = pitch_without(d1);
        const unsigned elements = max_size / pitch;
        // Guard instead of asserting: the tile may already fill max_size
        if ((elements > padding[d1]) && (tile[d1] < (elements - padding[d1]))) {
          tile[d1] = elements - padding[d1];
          if (bounds[d1] < tile[d1]) tile[d1] = bounds[d1];
        }
        return pitch * (tile[d1] + padding[d1]);
      }
      // If we ever get two dimensions of the same size then see what dimension
      // has the next largest value. If we can't find one that is larger then
      // we know that there is no smallest dimension so we can march all the
      // dimensions together at this point
      if (t1 == t2) {
        d2 = -1;
        for (int d = 0; d < DIM; d++) {
          if (d == d1) continue;
          if (tile[d] <= tile[d1]) continue;
          if ((d2 == -1) || (tile[d] < tile[d2])) {
            d2 = d;
            t2 = tile[d];
          }
        }
        if (d2 == -1) break;
      }
      // Solve for the max we can walk d1 with the other dimensions fixed
      const unsigned pitch    = pitch_without(d1);
      const unsigned elements = max_size / pitch;
      if ((elements <= padding[d1]) || (t1 >= (elements - padding[d1]))) {
        // d1 cannot grow at all without exceeding max_size
        skipdims |= (1 << d1);
        continue;
      }
      const unsigned bound = elements - padding[d1];
      if (bounds[d1] < bound) {
        // d1 can reach its bound; the next pass will skip it
        tile[d1] = bounds[d1];
        result   = pitch * (tile[d1] + padding[d1]);
      } else if (bound < t2) {
        // The memory limit stops d1 before it catches up with d2: done
        tile[d1] = bound;
        return pitch * (bound + padding[d1]);
      } else {
        // Level d1 with d2 and look for the next smallest pair
        tile[d1] = t2;
        result   = pitch * (t2 + padding[d1]);
      }
    }
    // All growable dimensions have the same extent. Step them together until
    // we hit the memory upper bound we're targeting. This algorithm is in
    // theory slow, but the max memory sizes of caches are "small" and the
    // amount of memory grows polynomially in the number of dimensions, so it
    // should converge quickly.
    while (true) {
      unsigned next_size = sizeof(VAL);
      for (int d = 0; d < DIM; d++) {
        if (skipdims & (1 << d))
          next_size *= (tile[d] + padding[d]);
        else if (tile[d] == bounds[d]) {
          next_size *= (tile[d] + padding[d]);
          skipdims |= (1 << d);
        } else
          next_size *= (tile[d] + 1 + padding[d]);
      }
      // next_size == result means every dimension is now skipped
      if ((next_size > max_size) || (next_size == result)) break;
      result = next_size;
      // Bitwise test (not logical &&): grow exactly the non-skipped dimensions
      // that next_size was computed for, so result matches tile.
      for (int d = 0; d < DIM; d++) {
        if (skipdims & (1 << d)) continue;
        tile[d]++;
      }
    }
    return result;
  }
}
The following input (not included in this text) causes the code at https://github.com/nv-legate/cunumeric/blob/branch-22.12/src/cunumeric/convolution/convolve_template.inl#L202 to go into an infinite loop.
This situation comes up when running
`test_convolve.py`
with 8 GPUs on a DGX-1 box.