Closed: sh-tada closed this issue 9 months ago.
I found that the derivative of the transmission spectrum model (ArtTransPure) with respect to temperature is NaN. The derivatives with respect to gravity_btm, radius_btm, and the mean molecular weight are NaN as well.
Sample code
"""Reproduce NaN gradients of the ArtTransPure transmission-spectrum model.

jax.grad of a least-squares objective built on ArtTransPure returns NaN for
the mu_fid, T_fid, gravity_btm and radius_btm components. The mmr_CO and RV
components come out finite. The issue thread suspects
atmprof.normalized_layer_height is not differentiable (not verified here).
"""
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from exojax.spec.api import MdbHitran
from exojax.spec.atmrt import ArtTransPure
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.specop import SopInstProfile
from exojax.spec.specop import SopRotation
from exojax.spec.unitconvert import wav2nu
from exojax.utils.astrofunc import gravity_jupiter
from exojax.utils.constants import RJ, Rs
from exojax.utils.grids import wavenumber_grid
from exojax.utils.instfunc import resolution_to_gaussian_std

# `from jax.config import config` is deprecated in recent JAX releases;
# the attribute form below works across versions.
jax.config.update("jax_enable_x64", True)


def read_data(filename):
    """Load a space-delimited spectrum and keep the 2.25-2.6 um window.

    Returns the masked "Wavelength[um]" and "Rp/Rs" columns as pandas Series.
    """
    dat = pd.read_csv(filename, delimiter=" ")
    wav = dat["Wavelength[um]"]
    mask = (wav > 2.25) & (wav < 2.6)
    return wav[mask], dat["Rp/Rs"][mask]


# Read data
filename = "/home/kawahara/exojax/tests/integration/comparison/transmission/spectrum/CO100percent_500K.dat"
wav, rprs = read_data(filename)
inst_nus = wav2nu(np.array(wav), "um")
# Reversed so the observed Rp/Rs lines up with the order of inst_nus
# (presumably wav2nu returns wavenumbers in ascending order; confirm).
# Converted once here rather than on every objective() call.
rprs_target = jnp.asarray(rprs.to_numpy()[::-1])

# Model
Nx = 300000
# Named wav_grid so the observed wavelengths in `wav` are not overwritten.
nu_grid, wav_grid, res = wavenumber_grid(22900.0, 26000.0, Nx, unit="AA", xsmode="premodit")
art = ArtTransPure(pressure_top=1.0e-15, pressure_btm=1.0e1, nlayer=100)
art.change_temperature_range(490.0, 510.0)
mdb = MdbHitran("CO", nu_grid, gpu_transfer=True, isotope=1)
opa = OpaPremodit(
    mdb=mdb,
    nu_grid=nu_grid,
    auto_trange=[490, 510],
    dit_grid_resolution=1,
)
sop_inst = SopInstProfile(nu_grid, res, vrmax=100.0)


def model(params):
    """Return the model Rp/Rs sampled at inst_nus.

    params is (mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV). The
    temperature and mean molecular weight are taken isothermal and uniform
    over all layers.
    """
    mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV = params
    # jnp (not np) inside a traced function keeps everything on the JAX side.
    Tarr = T_fid * jnp.ones_like(art.pressure)
    mmr_arr = art.constant_mmr_profile(mmr_CO)
    mmw = mu_fid * jnp.ones_like(art.pressure)
    gravity = art.gravity_profile(Tarr, mmw, radius_btm, gravity_btm)
    xsmatrix = opa.xsmatrix(Tarr, art.pressure)
    dtau = art.opacity_profile_xs(xsmatrix, mmr_arr, opa.mdb.molmass, gravity)
    Rp2 = art.run(dtau, Tarr, mmw, radius_btm, gravity_btm)
    Rp2_sample = sop_inst.sampling(Rp2, RV, inst_nus)
    return jnp.sqrt(Rp2_sample)


def objective(params):
    """Sum of squared residuals between observed and model Rp/Rs."""
    return jnp.sum((rprs_target - model(params)) ** 2)


# Gradient
grad = jax.grad(objective)
# Explicit float dtype: jax.grad requires floating-point inputs.
params = np.array(
    [1, 28.00863, 500, gravity_jupiter(1.0, 1.0), RJ, 0], dtype=np.float64
)
print()
print("Parameters: mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV")
print("Gradient", grad(params))
Result
Parameters: mmr_CO, mu_fid, T_fid, gravity_btm, radius_btm, RV
Gradient [4.57663499e+00 nan nan nan nan 3.38162608e-03]
Thanks. It looks like atmprof.normalized_layer_height is not differentiable.
atmprof.normalized_layer_height
I found that the derivative of the transmission spectrum model (ArtTransPure) with respect to temperature is NaN. The derivatives with respect to gravity_btm, radius_btm, and the mean molecular weight are NaN as well.
Sample code
Result