eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Refactor Configuration and API change #25

Open eelregit opened 10 months ago

eelregit commented 10 months ago

The units, constants, dtype settings, transfer settings, growth settings, and linear variance settings in Configuration are more tightly coupled to Cosmology, and should be moved there. After that, boltzmann.py can also be independent of Configuration and, together with cosmology.py, function like jax-cosmo.
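As a rough illustration of the decoupling described above, the sketch below lets a cosmology object carry its own numerical settings, so that a Boltzmann-style routine needs no separate configuration argument. All names here (`CosmologySketch`, `float_dtype`, `transfer_k_num`, `growth_rtol`, `prime_sketch`) are hypothetical placeholders, not pmwd's actual fields or functions, and the "solve" is only a dummy wavenumber grid.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class CosmologySketch:
    """Hypothetical container: physical parameters plus the numerical
    settings that would otherwise live in a separate configuration."""
    Omega_m: float
    h: float
    # settings assumed to move in from Configuration (names are made up)
    float_dtype: str = "float64"
    transfer_k_num: int = 4
    growth_rtol: float = 1e-6
    # cached table, filled by the priming step; None until then
    transfer_k: Optional[Tuple[float, ...]] = None


def prime_sketch(cosmo: CosmologySketch) -> CosmologySketch:
    """Stand-in for a Boltzmann-like step that depends only on cosmo.

    It returns a new object instead of mutating the input, and only
    builds a placeholder wavenumber grid rather than solving anything.
    """
    k = tuple(10.0 ** (i - 2) for i in range(cosmo.transfer_k_num))
    return replace(cosmo, transfer_k=k)


cosmo = CosmologySketch(Omega_m=0.3, h=0.7)
primed = prime_sketch(cosmo)
```

Because the settings travel with the cosmology object, the routine's signature shrinks to a single argument, which is the kind of independence from Configuration the paragraph above is after.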

Likewise, I am also trying to make Particles more independent of conf, and I have a somewhat different dtype design in mind for Cosmology and Particles.

To accommodate new features, e.g., sCOLA-like parallelization and observables, the plan is to change the old API

conf = Configuration(...)
cosmo = Cosmology(conf, ...)
modes = white_noise(seed, conf)

def model(modes, cosmo, conf):
    cosmo = boltzmann(cosmo, conf)
    modes = linear_modes(modes, cosmo, conf)
    ptcl, obsvbl = lpt(modes, cosmo, conf)
    ptcl, obsvbl = nbody(ptcl, obsvbl, cosmo, conf)
    return obj(obsvbl)

to

def config(...):
    cosmo = Cosmology(...)

    solver = Solver(...)  # the old Configuration, hierarchical, sub-solvers with boundary conditions (BCs)
    modes = white_noise(seed, solver)
    ptcl = pre_init_cond(...)  # or None, to be generated from solver, as additional input to lpt
    obsvbl = Snapshots(...), Lightcone(...)  # observable settings, as additional input to lpt

    return modes, ptcl, obsvbl, cosmo, solver

def model(modes, ptcl, obsvbl, cosmo, solver):
    cosmo = cosmo.prime()  # changed from boltzmann.boltzmann()
    # traverse the solver pytree (ptcl and modes too?), depth- or breadth-first? (flatten is the former)
    return obj(traverse(solve(modes, ptcl, obsvbl, cosmo, solver)))

def solve(modes, ptcl, obsvbl, cosmo, solver):
    modes = linear_modes(None, modes, cosmo, solver)  # a=None dependence is moved to front and made explicit
    ptcl, obsvbl, solver = lpt(modes, ptcl, obsvbl, cosmo, solver)  # compute BCs for all sub-solvers
    ptcl, obsvbl, solver = nbody(ptcl, obsvbl, cosmo, solver)  # lpt and nbody also call obj_imdt -> obsvbl
    obsvbl, solver = obj(obsvbl, solver)  # compute current level obj -> obsvbl
    return obsvbl, solver
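The traversal-order question raised in `model` can be made concrete with a small tree of nested solvers. The sketch below is plain Python with a hypothetical `SolverNode` class (not pmwd's `Solver`) and compares depth-first pre-order with breadth-first visiting. Flattening a nested container with JAX's tree utilities yields leaves in depth-first order, which is presumably what the comment about flatten refers to.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class SolverNode:
    """Hypothetical hierarchical solver: a name plus nested sub-solvers."""
    name: str
    subs: List["SolverNode"] = field(default_factory=list)


def depth_first(node: SolverNode) -> Iterator[SolverNode]:
    """Visit a solver, then each sub-solver's whole subtree in turn."""
    yield node
    for sub in node.subs:
        yield from depth_first(sub)


def breadth_first(root: SolverNode) -> Iterator[SolverNode]:
    """Visit solvers level by level, starting from the top-level solver."""
    queue = deque([root])
    while queue:
        node = queue.popleft()
        yield node
        queue.extend(node.subs)


# made-up hierarchy: one top-level solver with two sub-solvers, one nested further
root = SolverNode("box", [
    SolverNode("tile0", [SolverNode("tile0a")]),
    SolverNode("tile1"),
])

dfs_order = [s.name for s in depth_first(root)]
bfs_order = [s.name for s in breadth_first(root)]
```

Both orders visit a parent before its sub-solvers, so boundary conditions computed at one level would be available to the level below either way. They differ in whether a whole level is finished before descending further.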

Discussions and comments are welcome.

eelregit commented 9 months ago

Distances added