tedwards2412 / ripple

Differentiable Gravitational Waveforms with JAX

Example in readme does not work on fresh clone #3

Closed: kazewong closed this issue 2 years ago

kazewong commented 2 years ago

Here is the error output:

In [3]: from math import pi
   ...: import jax.numpy as jnp
   ...: 
   ...: from ripple.waveforms import IMRPhenomD, IMRPhenomD_utils
   ...: import matplotlib.pyplot as plt
   ...: from ripple import ms_to_Mc_eta
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [4]: # Get a frequency domain waveform
   ...: # source parameters
   ...: 
   ...: m1_msun = 20.0 # In solar masses
   ...: m2_msun = 19.0
   ...: chi1 = 0.5 # Dimensionless spin
   ...: chi2 = -0.5
   ...: tc = 0.0 # Time of coalescence in seconds
   ...: phic = 0.0 # Phase of coalescence
   ...: dist_mpc = 440 # Distance to source in Mpc
   ...: inclination = 0.0 # Inclination Angle
   ...: polarization_angle = 0.2 # Polarization angle
   ...: 
   ...: # The PhenomD waveform model is parameterized with the chirp mass and symmetric mass ratio
   ...: Mc, eta = ms_to_Mc_eta(jnp.array([m1_msun, m2_msun]))
   ...: 
   ...: # These are the parameters that go into the waveform generator
   ...: # Note that JAX does not give index errors, so if you pass in
   ...: # the wrong array it will behave strangely
   ...: theta_ripple = jnp.array([Mc, eta, chi1, chi2, dist_mpc, tc, phic, inclination, polarization_angle])
   ...: 
   ...: # Now we need to generate the frequency grid
   ...: f_l = 24
   ...: f_u = 512
   ...: del_f = 0.01
   ...: fs = jnp.arange(f_l, f_u, del_f)
   ...: 
   ...: # And finally let's generate the waveform!
   ...: hp_ripple, hc_ripple = IMRPhenomD.gen_IMRPhenomD_polar(fs, theta_ripple)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-eb2cb0a333ad> in <module>
     27 
     28 # And finally lets generate the waveform!
---> 29 hp_ripple, hc_ripple = IMRPhenomD.gen_IMRPhenomD_polar(fs, theta_ripple)

    [... skipping hidden 13 frame]

~/Environment/GW/lib/python3.9/site-packages/ripple/waveforms/IMRPhenomD.py in gen_IMRPhenomD_polar(f, params)
    675     """
    676     l, psi = params[7], params[8]
--> 677     h0 = gen_IMRPhenomD(f, params)
    678 
    679     hp = h0 * (1 / 2 * (1 + jnp.cos(l) ** 2) * jnp.cos(2 * psi))

    [... skipping hidden 8 frame]

~/Environment/GW/lib/python3.9/site-packages/ripple/waveforms/IMRPhenomD.py in gen_IMRPhenomD(f, params)
    648 
    649     coeffs = get_coeffs(theta_intrinsic)
--> 650     h0 = _gen_IMRPhenomD(f, theta_intrinsic, theta_extrinsic, coeffs)
    651     return h0
    652 

    [... skipping hidden 8 frame]

~/Environment/GW/lib/python3.9/site-packages/ripple/waveforms/IMRPhenomD.py in _gen_IMRPhenomD(f, theta_intrinsic, theta_extrinsic, coeffs)
    599 
    600     # Shift phase so that peak amplitude matches t = 0
--> 601     _, _, _, f4, f_RD, f_damp = get_transition_frequencies(
    602         theta_intrinsic, coeffs[5], coeffs[6]
    603     )

~/Environment/GW/lib/python3.9/site-packages/ripple/waveforms/IMRPhenomD_utils.py in get_transition_frequencies(theta, gamma2, gamma3)
    103         f_RD_ + (f_damp_ * (-1 + jnp.sqrt(1 - (gamma2_) ** 2.0)) * gamma3_) / gamma2_
    104     )
--> 105     f4 = jax.lax.cond(
    106         gamma2 >= 1,
    107         f4_gammaneg_gtr_1,

    [... skipping hidden 1 frame]

~/Environment/GW/lib/python3.9/site-packages/jax/_src/lax/control_flow.py in cond(*args, **kwargs)
    758     return _cond_with_per_branch_args(*ba.args)
    759 
--> 760   return _cond(*args, **kwargs)
    761 
    762 def _cond_with_per_branch_args(pred,

TypeError: _cond() takes 4 positional arguments but 7 were given
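The README snippet above assembles theta_ripple by hand and warns that JAX will not raise index errors if the array is wrong. A minimal sketch of a more defensive way to build it, assuming the nine-parameter order shown in the snippet. THETA_ORDER, masses_to_Mc_eta and pack_theta are hypothetical names, not part of ripple's API; the conversion uses the standard definitions eta = m1 m2 / M^2 and Mc = M eta^(3/5) and is not necessarily how ripple's ms_to_Mc_eta is written.

```python
import jax.numpy as jnp

# Parameter order copied from the README snippet in this issue.
# All names below are hypothetical helpers, not ripple functions.
THETA_ORDER = (
    "Mc", "eta", "chi1", "chi2", "dist_mpc",
    "tc", "phic", "inclination", "polarization_angle",
)


def masses_to_Mc_eta(m1, m2):
    """Chirp mass and symmetric mass ratio from component masses."""
    M = m1 + m2
    eta = m1 * m2 / M**2
    Mc = M * eta ** (3.0 / 5.0)  # equivalent to (m1*m2)**(3/5) / M**(1/5)
    return Mc, eta


def pack_theta(**params):
    """Build the 9-element parameter array, failing loudly on wrong keys.

    The check runs in plain Python before anything reaches JAX, since the
    waveform code itself will not complain about a wrong array.
    """
    missing = [k for k in THETA_ORDER if k not in params]
    extra = [k for k in params if k not in THETA_ORDER]
    if missing or extra:
        raise ValueError(f"missing={missing}, unexpected={extra}")
    return jnp.array([params[k] for k in THETA_ORDER])


Mc, eta = masses_to_Mc_eta(20.0, 19.0)
theta = pack_theta(
    Mc=Mc, eta=eta, chi1=0.5, chi2=-0.5, dist_mpc=440.0,
    tc=0.0, phic=0.0, inclination=0.0, polarization_angle=0.2,
)
```

Passing parameters by keyword means a swapped or forgotten entry raises immediately instead of silently producing a strange waveform.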
kazewong commented 2 years ago

I retried on multiple machines and it works on most of them, so this seems to be a machine-specific issue. Closing the issue with this comment.
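When an example works on most machines but fails on one, comparing the installed JAX versions is a quick first check. The sketch below prints the version and probes, with a throwaway call, whether jax.lax.cond accepts several operands after the two branch functions, which is the call form that failed in the traceback above. The helper name cond_accepts_multiple_operands is hypothetical.

```python
import jax
import jax.numpy as jnp


def cond_accepts_multiple_operands() -> bool:
    """Hypothetical probe: does this JAX accept cond(pred, t, f, *operands)?"""
    try:
        jax.lax.cond(
            True,
            lambda a, b, c: a + b + c,
            lambda a, b, c: a - b - c,
            jnp.float32(1.0), jnp.float32(2.0), jnp.float32(3.0),
        )
    except TypeError:
        # Versions with the single-operand signature, such as the one in the
        # traceback, raise "_cond() takes 4 positional arguments but 6 were given".
        return False
    return True


print("jax", jax.__version__)
print("cond accepts multiple operands:", cond_accepts_multiple_operands())
```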

adam-coogan commented 2 years ago

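I think this is related to the JAX version you're using. jax.lax.cond used to require its fourth argument to be a single object containing all the operands for true_fun and false_fun, so passing the operands as separate arguments (pred, two branch functions and four operands, i.e. 7 positional arguments) fails with the TypeError above. That changed sometime in the past year. Anyway, this can stay closed!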
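A minimal sketch of a call pattern that should work under both signatures: pack every operand into one tuple and pass it as the single fourth argument, then unpack it inside each branch. The function names are hypothetical. The gamma2 < 1 branch reuses the expression visible in the traceback (IMRPhenomD_utils.py, lines 103-104); the gamma2 >= 1 branch body is not shown in the traceback, so the one below is an illustrative placeholder, not ripple's actual formula.

```python
import jax
import jax.numpy as jnp


def _f4_gamma2_lt_1(operands):
    # Expression taken from the branch visible in the traceback.
    f_RD, f_damp, gamma2, gamma3 = operands
    return f_RD + (f_damp * (-1 + jnp.sqrt(1 - gamma2**2.0)) * gamma3) / gamma2


def _f4_gamma2_ge_1(operands):
    # ASSUMPTION: illustrative placeholder for the gamma2 >= 1 branch.
    f_RD, f_damp, gamma2, gamma3 = operands
    return f_RD - f_damp * gamma3 / gamma2


@jax.jit
def get_f4(f_RD, f_damp, gamma2, gamma3):
    operands = (f_RD, f_damp, gamma2, gamma3)
    # One tuple as the single operand: valid for the old
    # cond(pred, true_fun, false_fun, operand) signature, and for newer
    # variadic versions, where the tuple is simply one operand.
    return jax.lax.cond(gamma2 >= 1, _f4_gamma2_ge_1, _f4_gamma2_lt_1, operands)


print(get_f4(0.1, 0.01, 0.5, 1.2))
```

The tuple costs one unpacking line per branch but avoids depending on which cond signature the installed JAX provides.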