LouisDesdoigts / dLux

Differentiable optical models as parameterised neural networks in Jax using Zodiax
https://louisdesdoigts.github.io/dLux/
BSD 3-Clause "New" or "Revised" License

`dlu.nandiv` fails on float inputs #269

Open: LouisDesdoigts opened this issue 3 months ago

LouisDesdoigts commented 3 months ago
dlu.nandiv(1.0, 0.0, -1)

ZeroDivisionError                         Traceback (most recent call last)
Cell In[173], line 1
----> 1 dlu.nandiv(1.0, 0.0, -1)

File ~/mambaforge/envs/amigo/lib/python3.11/site-packages/dLux/utils/math.py:77, in nandiv(a, b, fill)
     59 def nandiv(a: Array, b: Array, fill: Any = np.inf) -> Array:
     60     """
     61     Divides two arrays, replacing any NaNs with a fill value.
     62
   (...)
     75         The result of the division.
     76     """
---> 77     return np.where(b == 0, fill, a / b)

ZeroDivisionError: float division by zero


whereas passing `b` as a JAX array returns the fill value as expected:

dlu.nandiv(1., np.array(0.), -1)

Array(-1., dtype=float64, weak_type=True)
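The difference comes from where the division is performed. When both `a` and `b` are plain Python scalars, `a / b` is evaluated by Python itself before `np.where` is ever called, and Python raises on division by zero. When either operand is a JAX array, JAX performs the division with IEEE 754 semantics and returns `inf`, which `np.where` then discards in favour of the fill value. A minimal check of this, assuming `np` refers to `jax.numpy` (as the `Array(...)` output above suggests):

```python
import jax.numpy as np

# Both operands are Python floats: Python does the division and raises,
# so np.where never gets a chance to substitute the fill value.
try:
    1.0 / 0.0
    python_raised = False
except ZeroDivisionError:
    python_raised = True

# One operand is a JAX array: JAX does the division and returns inf.
jax_result = 1.0 / np.array(0.0)

# np.where then selects the fill value wherever the divisor is zero.
b = np.array(0.0)
filled = np.where(b == 0, -1, 1.0 / b)
```

So the failure only occurs when both inputs are Python numbers; a single array operand is enough to route the division through JAX.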


It should probably also be renamed to `safe_div`.
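One possible fix, sketched here under the name `safe_div` suggested above (hypothetical; this is not the current dLux implementation), is to cast both inputs with `np.asarray` so the division always happens in JAX. Since dLux is built for differentiation, the sketch also uses the "double where" pattern: it divides by a dummy `1` wherever `b == 0`, so the discarded branch never evaluates to `inf`. Without this, the gradient of `np.where(b == 0, fill, a / b)` with respect to `b` at `b == 0` comes out as NaN, because the unselected branch's infinite derivative is multiplied by a zero cotangent.

```python
from typing import Any

import jax
import jax.numpy as np


def safe_div(a: Any, b: Any, fill: Any = np.inf) -> jax.Array:
    """Divide a by b, returning `fill` wherever b == 0.

    Hypothetical replacement for dlu.nandiv, shown as a sketch only.
    """
    # Cast Python scalars to JAX arrays so JAX, not Python, does the division.
    a = np.asarray(a)
    b = np.asarray(b)
    is_zero = b == 0
    # Replace zero divisors with 1 so the branch np.where discards stays
    # finite; this keeps NaNs out of gradients with respect to b.
    safe_b = np.where(is_zero, 1, b)
    return np.where(is_zero, fill, a / safe_b)
```

With this version, `safe_div(1.0, 0.0, -1)` and `safe_div(1.0, np.array(0.), -1)` should both return `-1`, and `jax.grad` through the zero-divisor case stays finite.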