Summary
Closes #245
Details and comments
I ended up putting the `carrier_freq == 0.0` check into a `try` block to solve this. It's not the prettiest solution, but it seemed to be the simplest and most concise. I originally had checks for JAX types and then checked whether `carrier_freq` was a tracer, but this turned out to be awkward to implement.

The current code is essentially equivalent to: if `carrier_freq` is not a JAX tracer object (or a tracer object inside of an `Array`), and `carrier_freq == 0.0`, then set `self._is_constant = True`.

The awkwardness with implementing this directly is that the code also needs to work when JAX isn't installed, so the condition "if `carrier_freq` is not a JAX tracer object" is itself awkward to check.

I've also added tests verifying that the original issue is resolved:

- A check that `carrier_freq` can be properly traced within a JAX function if the envelope is constant.
- A modification to an existing test to verify that `is_constant` is properly set within JAX tracing if `carrier_freq` is concretely `0`.
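For illustration, here is a minimal sketch of the `try`-based check, assuming a simplified stand-in class (`SketchSignal`, hypothetical, not the actual class changed in this PR) in which a non-callable envelope counts as constant. The exception type caught here (`TypeError`, which JAX's tracer-to-bool conversion errors subclass) is also an assumption; the description does not say which exception the real code catches, and the `Array` wrapper is not modeled.

```python
# Hypothetical sketch: SketchSignal is an illustrative stand-in, not the
# actual signal class modified by this PR.
class SketchSignal:
    """Minimal signal holding an envelope and a carrier frequency."""

    def __init__(self, envelope, carrier_freq=0.0):
        self.envelope = envelope
        self.carrier_freq = carrier_freq
        self._is_constant = False

        # Assumption for this sketch: a non-callable envelope is constant.
        if not callable(envelope):
            try:
                # For a concrete value this is an ordinary bool check.
                # For a JAX tracer, `carrier_freq == 0.0` is itself traced,
                # and converting it to bool raises a JAX concretization
                # error, which is a subclass of TypeError.
                if carrier_freq == 0.0:
                    self._is_constant = True
            except TypeError:
                # carrier_freq is not concrete (e.g. traced), so the signal
                # stays non-constant. No JAX import is needed for this path,
                # so the code also runs when JAX is not installed.
                pass

    @property
    def is_constant(self):
        return self._is_constant
```

Under `jax.jit`, constructing `SketchSignal(1.0, freq)` with a traced `freq` leaves `is_constant` as `False` instead of raising, while `SketchSignal(1.0, 0.0)` created inside the same traced function still reports `True`. The alternative of explicitly testing whether `carrier_freq` is a tracer would require a guarded JAX import, which is the awkwardness described above.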