HajimeKawahara / exojax

🐈 Automatic differentiable spectrum modeling of exoplanets/brown dwarfs using JAX, compatible with NumPyro and JAXopt
http://secondearths.sakura.ne.jp/exojax/
MIT License
56 stars 14 forks source link

The temperature derivative of the transmission spectra model becomes NaN #463

Closed sh-tada closed 9 months ago

sh-tada commented 9 months ago

I found that the derivative of the transmission spectrum model (ArtTransPure) with respect to temperature is NaN. Similarly, the derivatives with respect to gravity_btm, radius_btm, and the mean molecular weight are also NaN.

Sample code

import jax
from jax.config import config
import pandas as pd
import numpy as np
import jax.numpy as jnp
from exojax.utils.grids import wavenumber_grid
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.atmrt import ArtTransPure
from exojax.utils.constants import RJ, Rs
from exojax.spec.api import MdbHitran
from exojax.utils.astrofunc import gravity_jupiter

from exojax.spec.unitconvert import wav2nu
from exojax.spec.specop import SopRotation
from exojax.spec.specop import SopInstProfile
from exojax.utils.instfunc import resolution_to_gaussian_std

# Switch JAX to double precision at startup so arrays created afterwards
# default to float64 (needed for stable gradients in this example).
jax.config.update("jax_enable_x64", True)

def read_data(filename, wav_min=2.25, wav_max=2.6):
    """Read a transmission spectrum table and crop it to a wavelength window.

    The file must have a header row with the columns ``Wavelength[um]`` and
    ``Rp/Rs``, with fields separated by three spaces.

    Args:
        filename: path to the spectrum file.
        wav_min: lower wavelength bound in micron (exclusive).
        wav_max: upper wavelength bound in micron (exclusive).

    Returns:
        Tuple ``(wavelength, rprs)`` of pandas Series, restricted to rows with
        ``wav_min < wavelength < wav_max``.
    """
    # A multi-character separator is interpreted as a regular expression,
    # which only the python parser supports. Request it explicitly so pandas
    # does not emit a ParserWarning about falling back from the C engine.
    dat = pd.read_csv(filename, delimiter="   ", engine="python")
    wav = dat["Wavelength[um]"]
    mask = (wav > wav_min) & (wav < wav_max)
    return wav[mask], dat["Rp/Rs"][mask]

# Read data
filename = "/home/kawahara/exojax/tests/integration/comparison/transmission/spectrum/CO100percent_500K.dat"
wav, rprs = read_data(filename)
inst_nus = wav2nu(np.array(wav), "um")

# Model
# One temperature range shared by the ART grid and the PreMODIT opacity
# setup so the two settings cannot drift apart.
temperature_range = (490.0, 510.0)

Nx = 300000
# The model wavelength grid gets its own name so the observed ``wav`` above
# is not silently overwritten.
# NOTE(review): the data crop starts at 2.25 um, but this grid starts at
# 22900 AA (2.29 um), so some inst_nus lie outside nu_grid. Confirm how
# sop_inst.sampling treats those points.
nu_grid, wav_grid, res = wavenumber_grid(22900.0, 26000.0, Nx, unit="AA", xsmode="premodit")

art = ArtTransPure(pressure_top=1.0e-15, pressure_btm=1.0e1, nlayer=100)
art.change_temperature_range(*temperature_range)

mdb = MdbHitran("CO", nu_grid, gpu_transfer=True, isotope=1)
opa = OpaPremodit(
    mdb=mdb,
    nu_grid=nu_grid,
    auto_trange=list(temperature_range),
    dit_grid_resolution=1,
)

sop_inst = SopInstProfile(nu_grid, res, vrmax=100.0)

def model(params):
    """Forward model evaluated at the instrument wavenumbers ``inst_nus``.

    ``params`` is unpacked as
    ``(mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV)``.
    It builds flat temperature and mean-molecular-weight profiles on
    ``art.pressure`` and a constant CO mass mixing ratio profile. It then
    turns the ``opa`` cross sections into ``dtau`` and integrates with
    ``art.run`` to get ``Rp2``. ``Rp2`` is sampled at ``inst_nus`` using
    ``RV`` (presumably a radial velocity; confirm against SopInstProfile),
    and the function returns the square root of the sampled values.
    """
    mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV = params

    # All-ones template on the pressure grid, reused for both flat profiles.
    flat = np.ones_like(art.pressure)
    Tarr = T_fid * flat
    mmw = mu_fid * flat
    mmr_arr = art.constant_mmr_profile(mmr_CO)
    gravity = art.gravity_profile(Tarr, mmw, radius_btm, gravity_btm)

    # Cross-section matrix -> dtau -> Rp2 from the transmission RT solver.
    xsmatrix = opa.xsmatrix(Tarr, art.pressure)
    dtau = art.opacity_profile_xs(xsmatrix, mmr_arr, opa.mdb.molmass, gravity)
    Rp2 = art.run(dtau, Tarr, mmw, radius_btm, gravity_btm)

    return jnp.sqrt(sop_inst.sampling(Rp2, RV, inst_nus))

def objective(params):
    """Return the sum of squared residuals between data and ``model(params)``.

    The observed ``rprs`` series is reversed before the comparison. This
    presumably matches the wavenumber ordering of ``inst_nus`` produced by
    ``wav2nu`` (verify).
    """
    observed = np.array(rprs[::-1])
    residual = observed - model(params)
    return jnp.sum(residual**2)

# Gradient
# Differentiate the objective and evaluate it at a fiducial parameter vector.
grad = jax.grad(objective)
params = np.array(
    [
        1,  # mmr_CO
        28.00863,  # mu_fid
        500,  # T_fid
        gravity_jupiter(1.0, 1.0),  # gravity_btm
        RJ,  # radius_btm
        0,  # RV
    ]
)
print()
print("Parameters: mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV")
print("Gradient", grad(params))

Result

Parameters: mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV
Gradient [4.57663499e+00            nan            nan            nan            nan 3.38162608e-03]
HajimeKawahara commented 9 months ago

Thanks. It looks like atmprof.normalized_layer_height is not differentiable.