[Open] KrisThielemans opened this issue 1 week ago
This is due to #1549
@jit(parallel=True, nopython=True)
def kl_div_flat_mask(x, y, eta, mask):
    """Sum the elementwise Kullback-Leibler divergence terms over a mask.

    For every flat index ``i`` with ``mask.flat[i] > 0`` the term uses
    ``X = x.flat[i]`` and ``Y = y.flat[i] + eta.flat[i]``:

    * ``X * log(X / Y) - X + Y`` when ``X > 0`` and ``Y > 0``;
    * ``Y`` when ``X == 0`` and ``Y >= 0``;
    * ``inf`` otherwise (negative values, ``X > 0`` with ``Y == 0``, NaN).

    Parameters
    ----------
    x, y, eta, mask : arrays
        Read element-for-element through ``.flat``. Only ``x.size`` is used
        as the loop bound, so presumably all four have the same size --
        TODO confirm with callers.

    Returns
    -------
    float
        Sum of the masked terms, or ``inf`` if any masked element falls in
        the invalid branch (for finite inputs).
    """
    accumulator = 0.
    for i in prange(x.size):
        if mask.flat[i] > 0:
            X = x.flat[i]
            Y = y.flat[i] + eta.flat[i]
            if X > 0 and Y > 0:
                accumulator += X * numpy.log(X / Y) - X + Y
            elif X == 0 and Y >= 0:
                accumulator += Y
            else:
                # Do not `return numpy.inf` here: an early return gives the
                # prange loop more than one exit, so Numba will not run it in
                # parallel (and warns). Adding inf keeps this a plain `+`
                # reduction and, for finite inputs, still yields inf overall.
                accumulator += numpy.inf
    return accumulator
# fastmath is restricted to flags that do not assume finite / non-NaN values:
# with 'ninf' an infinite result becomes poison, and with 'nnan' the branch
# comparisons below may be folded, both of which break the inf return value.
# 'reassoc' is kept so the parallel reduction can still be reordered.
@njit(parallel=True, fastmath={'nsz', 'arcp', 'contract', 'afn', 'reassoc'})
def kl_div_ravel_mask(x, y, eta, ind):
    """Sum the elementwise Kullback-Leibler divergence terms at given indices.

    ``x``, ``y`` and ``eta`` are flattened with ``ravel()``. For every entry
    ``j`` of ``ind`` the term uses ``X = x[j]`` and ``Y = y[j] + eta[j]``:

    * ``X * log(X / Y) - X + Y`` when ``X > 0`` and ``Y > 0``;
    * ``Y`` when ``X == 0`` and ``Y >= 0``;
    * ``inf`` otherwise (negative values, ``X > 0`` with ``Y == 0``, NaN).

    Parameters
    ----------
    x, y, eta : arrays
        Flattened before indexing.
    ind : 1-D integer array
        Flat indices into the raveled arrays. Presumably the positions where
        a mask is set, e.g. ``numpy.flatnonzero(mask)`` -- TODO confirm.

    Returns
    -------
    float
        Sum of the selected terms, or ``inf`` if any selected element falls
        in the invalid branch (for finite inputs).
    """
    accumulator = 0.0
    x = x.ravel()
    y = y.ravel()
    eta = eta.ravel()
    for i in prange(ind.size):
        # Look up the element through the index array; indexing with `i`
        # directly would ignore `ind` and read the first ind.size elements.
        j = ind[i]
        X = x[j]
        Y = y[j] + eta[j]
        if X > 0 and Y > 0:
            accumulator += X * numpy.log(X / Y) - X + Y
        elif X == 0 and Y >= 0:
            accumulator += Y
        else:
            # Update via `+=` so this stays a recognised parallel reduction;
            # a plain assignment only overwrites one thread's partial sum.
            accumulator += numpy.inf
    return accumulator
The second implementation is faster and runs without any warnings. I can fix it if you want.
https://github.com/SyneRBI/SIRF-SuperBuild/actions/runs/9569076152/job/26380772313#step:14:1264